123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import torch
- from ..lowering import register_lowering
- from ..select_algorithm import (
- autotune_select_algorithm,
- ExternKernelChoice,
- TritonTemplate,
- )
- from ..utils import ceildiv as cdiv, use_triton_template
- from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
# Short alias for the ATen operator namespace; the lowerings below are keyed
# on its ops (e.g. ``aten.bmm``).
aten = torch.ops.aten
def bmm_grid(b, m, n, meta):
    """Return the ``(x, y, z)`` launch grid used by ``bmm_template``.

    Axis 0 enumerates every ``BLOCK_M x BLOCK_N`` output tile of a single
    ``m x n`` result matrix (the kernel splits ``program_id(0)`` back into a
    tile row/column). Axis 1 enumerates the ``b`` batch entries (the kernel
    reads it as ``program_id(1)``). Axis 2 is always 1.

    NOTE(review): on CUDA the grid's y dimension is capped at 65535, so a very
    large ``b`` may fail to launch. Confirm whether callers guard against this.
    """
    tiles_along_m = cdiv(m, meta["BLOCK_M"])
    tiles_along_n = cdiv(n, meta["BLOCK_N"])
    return (tiles_along_m * tiles_along_n, b, 1)
# Triton template for batched matmul. For each batch index q it computes the
# (M, N) product of A[q] (M x K, sizes read from A) and B[q] (K x N, N read
# from B), accumulating in ACC_TYPE. The result is written through
# store_output under the (idx_q, idx_m, idx_n) index with an M/N bounds mask.
bmm_template = TritonTemplate(
    name="bmm",
    # Launch grid: axis 0 walks the (M, N) output tiles, axis 1 the batch.
    grid=bmm_grid,
    # Kernel source. The {{...}} placeholders (def_kernel, size, stride,
    # store_output) are expanded by TritonTemplate. The upper-case names
    # (BLOCK_M/BLOCK_N/BLOCK_K, GROUP_M, EVEN_K, ALLOW_TF32, ACC_TYPE) are not
    # defined here and are presumably supplied per config via mm_options; confirm.
    # When EVEN_K is false, the K-tail loads are masked and padded with 0.
    # Do not re-indent the string: it is emitted verbatim as Triton code.
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", -2)}}
    N = {{size("B", -1)}}
    K = {{size("A", -1)}}
    stride_aq = {{stride("A", 0)}}
    stride_am = {{stride("A", 1)}}
    stride_ak = {{stride("A", 2)}}
    stride_bq = {{stride("B", 0)}}
    stride_bk = {{stride("B", 1)}}
    stride_bn = {{stride("B", 2)}}
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    idx_q = tl.program_id(1) # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1) # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
""",
)
# ATen fallback kernels offered to the autotuner alongside the Triton template.
# Each pairs the eager torch op with a string naming its ATen ``out=`` variant
# (presumably used when emitting wrapper code; confirm).
aten_bmm, aten_baddbmm = (
    ExternKernelChoice(torch.bmm, "at::bmm_out"),
    ExternKernelChoice(torch.baddbmm, "at::baddbmm_out"),
)
@register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
    """Lower ``aten.bmm`` by autotuning over ATen and Triton candidates.

    Args:
        mat1: Left batched operand, indexed as ``(batch, M, K)`` by
            ``bmm_template``.
        mat2: Right batched operand, indexed as ``(batch, K, N)``.
        layout: Optional output layout. ``mm_args`` returns the layout that is
            actually used (presumably deriving one when ``None``; confirm).

    Returns:
        Whatever ``autotune_select_algorithm`` produces for the candidate list.
    """
    # Only K is needed to build per-config Triton options; M and N are unused.
    _, _, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)

    # The ATen extern kernel is always a candidate.
    choices = [aten_bmm.bind((mat1, mat2), layout)]

    # When Triton templates are allowed, add one instantiation per mm config.
    if use_triton_template(layout):
        choices.extend(
            bmm_template.generate(
                (mat1, mat2), layout, **mm_options(config, k, layout)
            )
            for config in mm_configs()
        )

    return autotune_select_algorithm(choices, [mat1, mat2], layout)
# Intentionally not registered: this lowering is slower than decomposing
# baddbmm.
# @register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """Lower ``aten.baddbmm`` by autotuning over ATen and Triton candidates.

    The Triton path reuses ``bmm_template``. ``inp`` is passed as a leading
    prefix argument (``prefix_args=1``) and is combined with the batched
    product through ``addmm_epilogue(layout.dtype, alpha, beta)``. This is
    presumably ``beta * inp + alpha * (mat1 @ mat2)``, matching
    ``torch.baddbmm``.

    Args:
        inp: Additive input. ``mm_args`` returns the version actually used
            (presumably broadcast to the output layout; confirm).
        mat1: Left batched operand.
        mat2: Right batched operand.
        alpha: Scale forwarded to the ATen kernel and to the Triton epilogue.
        beta: Scale forwarded to the ATen kernel and to the Triton epilogue.
        layout: Optional output layout, resolved through ``mm_args``.

    Returns:
        Whatever ``autotune_select_algorithm`` produces for the candidate list.
    """
    _, _, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)

    # The ATen extern kernel is always a candidate.
    choices = [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]

    # One Triton candidate per mm config. Each gets its own epilogue closure,
    # exactly as it would if built inside an explicit loop.
    if use_triton_template(layout):
        choices += [
            bmm_template.generate(
                (inp, mat1, mat2),
                layout,
                **mm_options(config, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )
            for config in mm_configs()
        ]

    return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)
|