import torch

from ..utils import has_triton

if has_triton():
    import triton
    import triton.language as tl

    from .autotune import conv_heuristics
    from .utils import _unpack

    @conv_heuristics()
    @triton.jit
    def _kernel_delta_x_hwc(
        x,
        w,
        y,
        # stride of tensor
        stride_xn,
        stride_xc,
        stride_xh,
        stride_xw,
        stride_wn,
        stride_wc,
        stride_wh,
        stride_ww,
        stride_yn,
        stride_yc,
        stride_yh,
        stride_yw,
        stride_biasn,
        # pointer inc for x
        delta_xh_ptr,
        delta_xw_ptr,
        delta_xc_ptr,
        # Tensor dimensions
        BATCH,
        IN_C,
        IN_H,
        IN_W,
        KERNEL_N,
        KERNEL_H,
        KERNEL_W,
        OUT_H,
        OUT_W,
        # parameters of conv
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
        # Metaparameters
        ACC_TYPE: tl.constexpr,
        CONV1X1_NHWC: tl.constexpr,
        # blocks in different dimensions
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        # reduction tiling parameter for matmul
        BLOCK_K: tl.constexpr,
        # Super-blocking for better L2 performance
        GROUP_H: tl.constexpr,
    ):
        """
        Each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W]
        block of y.
        """
        # -----------------------------------------------------------
        # Map program ids `pid` to the block of y it should compute.
        pid_nhw = tl.program_id(0)
        pid_k = tl.program_id(1)

        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # offset for the initial ptr for x
        off_x_n = off_y_n
        off_x_h = off_y_h * stride_h - padding_h
        off_x_w = off_y_w * stride_w - padding_w
        off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
        off_x_crs = tl.arange(0, BLOCK_K)

        CRS = IN_C * KERNEL_H * KERNEL_W
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_xh_ptrs = delta_xh_ptr + off_x_crs
            delta_xw_ptrs = delta_xw_ptr + off_x_crs
            delta_xc_ptrs = delta_xc_ptr + off_x_crs
            delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
            off_x_crs_unpacked = (
                delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
            )
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
            delta_xh = 0
            delta_xw = 0

        mask_x = (
            (off_x_n < BATCH)[:, None]
            & (off_x_crs < CRS)[None, :]
            & (off_x_h[:, None] + delta_xh[None, :] >= 0)
            & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
            & (off_x_w[:, None] + delta_xw[None, :] >= 0)
            & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
        )

        # offset for the initial ptr for w
        off_w_crs = tl.arange(0, BLOCK_K)
        off_w_k = off_y_k
        w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        # -----------------------------------------------------------
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):
            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------
            w_ptrs += BLOCK_K
            # load inc ptr of x, update x_ptrs
            off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
            if not CONV1X1_NHWC:
                delta_xh_ptrs += BLOCK_K
                delta_xw_ptrs += BLOCK_K
                delta_xc_ptrs += BLOCK_K
                delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
                delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
                delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
                off_x_crs_unpacked = (
                    delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
                )
                x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
            else:
                x_ptrs += BLOCK_K

            mask_x = (
                (off_x_n < BATCH)[:, None]
                & (off_x_crs < CRS)[None, :]
                & (off_x_h[:, None] + delta_xh[None, :] >= 0)
                & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
                & (off_x_w[:, None] + delta_xw[None, :] >= 0)
                & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
            )
            mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
            # ------ prefetch ------
            # ------ load x ------
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
            # ------ load w ------
            matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        acc = acc.to(y.dtype.element_ty)

        # rematerialize -- this saves some registers
        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        # consider output padding
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # y ptrs in the block of [BLOCK_M, BLOCK_N]
        y_ptrs = (
            y
            + off_y_n[:, None] * stride_yn
            + off_y_h[:, None] * stride_yh
            + off_y_w[:, None] * stride_yw
            + off_y_k[None, :] * stride_yc
        )

        # out-of-bounds check
        mask_y = (
            (off_y_n < BATCH)[:, None]
            & (off_y_h < OUT_H + output_padding_h)[:, None]
            & (off_y_w < OUT_W + output_padding_w)[:, None]
            & (off_y_k < KERNEL_N)[None, :]
        )

        tl.store(y_ptrs, acc, mask=mask_y)
        return
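
    # Both kernels in this file compute the convolution as an implicit GEMM of
    # shape [BATCH * OUT_H * OUT_W, CRS] x [CRS, KERNEL_N], where
    # CRS = IN_C * KERNEL_H * KERNEL_W indexes the sliding-window reduction.
    # _kernel_delta_x_hwc (above) keeps separate h / w / c pointer deltas so that
    # every window element can be bounds-checked against IN_H / IN_W, while
    # _kernel_delta_x (below) folds them into one precomputed offset and is used
    # when no per-element bounds check is needed (1x1 kernels or zero padding).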

    @conv_heuristics()
    @triton.jit
    def _kernel_delta_x(
        x,
        w,
        y,
        # stride of tensor
        stride_xn,
        stride_xc,
        stride_xh,
        stride_xw,
        stride_wn,
        stride_wc,
        stride_wh,
        stride_ww,
        stride_yn,
        stride_yc,
        stride_yh,
        stride_yw,
        stride_biasn,
        # pointer inc for x
        delta_x_ptr,
        # Tensor dimensions
        BATCH,
        IN_C,
        IN_H,
        IN_W,
        KERNEL_N,
        KERNEL_H,
        KERNEL_W,
        OUT_H,
        OUT_W,
        # parameters of conv
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
        # Metaparameters
        ACC_TYPE: tl.constexpr,
        CONV1X1_NHWC: tl.constexpr,
        # blocks in different dimensions
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        # reduction tiling parameter for matmul
        BLOCK_K: tl.constexpr,
        # Super-blocking for better L2 performance
        GROUP_H: tl.constexpr,
    ):
        """
        Each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W]
        block of y.
        """
        # -----------------------------------------------------------
        # Map program ids `pid` to the block of y it should compute.
        pid_nhw = tl.program_id(0)
        pid_k = tl.program_id(1)

        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # offset for the initial ptr for x
        off_x_n = off_y_n
        off_x_h = off_y_h * stride_h - padding_h
        off_x_w = off_y_w * stride_w - padding_w
        off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
        off_x_crs = tl.arange(0, BLOCK_K)

        CRS = IN_C * KERNEL_H * KERNEL_W
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_x_ptrs = delta_x_ptr + off_x_crs
            off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]

        mask_x = (
            (off_x_n < BATCH)
            & (off_x_h >= 0)
            & (off_x_h < IN_H)
            & (off_x_w >= 0)
            & (off_x_w < IN_W)
        )[:, None] & (off_x_crs < CRS)[None, :]

        # offset for the initial ptr for w
        off_w_crs = tl.arange(0, BLOCK_K)
        off_w_k = off_y_k
        w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        # -----------------------------------------------------------
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):
            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------
            w_ptrs += BLOCK_K
            # load inc ptr of x, update x_ptrs
            if not CONV1X1_NHWC:
                delta_x_ptrs += BLOCK_K
                off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
                off_x_crs_unpacked = tl.load(
                    delta_x_ptrs, mask=off_x_crs < CRS, other=0
                )
                x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
            else:
                off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
                x_ptrs += BLOCK_K

            mask_x = (
                (off_x_n < BATCH)
                & (off_x_h >= 0)
                & (off_x_h < IN_H)
                & (off_x_w >= 0)
                & (off_x_w < IN_W)
            )[:, None] & (off_x_crs < CRS)[None, :]
            mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
            # ------ prefetch ------
            # ------ load x ------
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
            # ------ load w ------
            matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        acc = acc.to(y.dtype.element_ty)

        # rematerialize -- this saves some registers
        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        # consider output padding
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # y ptrs in the block of [BLOCK_M, BLOCK_N]
        y_ptrs = (
            y
            + off_y_n[:, None] * stride_yn
            + off_y_h[:, None] * stride_yh
            + off_y_w[:, None] * stride_yw
            + off_y_k[None, :] * stride_yc
        )

        # out-of-bounds check
        mask_y = (
            (off_y_n < BATCH)[:, None]
            & (off_y_h < OUT_H + output_padding_h)[:, None]
            & (off_y_w < OUT_W + output_padding_w)[:, None]
            & (off_y_k < KERNEL_N)[None, :]
        )

        tl.store(y_ptrs, acc, mask=mask_y)
        return

    class _conv:
        kernel = _kernel_delta_x_hwc

        # for the contiguous order of w ptr, what are the corresponding
        # ptr changes for x in a sliding window
        @staticmethod
        def _delta_x_ptr_hwc(
            IN_C,
            KERNEL_H,
            KERNEL_W,
            dilation_h,
            dilation_w,
            stride_wc,
            stride_wh,
            stride_ww,
            stride_xc,
            stride_xh,
            stride_xw,
            device,
        ):
            # get the order of axes in w, innermost dimension outward
            stride_w_3d = [stride_wc, stride_wh, stride_ww]
            order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
            window_size = IN_C * KERNEL_H * KERNEL_W

            r_window = torch.arange(0, window_size, 1, device=device)
            window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
            window_unpack_c = window_unpack[order[0]]
            window_unpack_h = window_unpack[order[1]]
            window_unpack_w = window_unpack[order[2]]
            r_dilation_h = dilation_h * window_unpack_h
            r_dilation_w = dilation_w * window_unpack_w
            r_inc = window_unpack_c
            # delta_x = (
            #     r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
            # )
            # return delta_x
            return (
                r_dilation_h,
                r_dilation_w,
                r_inc,
            )

        @staticmethod
        def _delta_x_ptr(
            IN_C,
            KERNEL_H,
            KERNEL_W,
            dilation_h,
            dilation_w,
            stride_wc,
            stride_wh,
            stride_ww,
            stride_xc,
            stride_xh,
            stride_xw,
            device,
        ):
            # get the order of axes in w, innermost dimension outward
            stride_w_3d = [stride_wc, stride_wh, stride_ww]
            order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
            window_size = IN_C * KERNEL_H * KERNEL_W

            r_window = torch.arange(0, window_size, 1, device=device)
            window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
            window_unpack_c = window_unpack[order[0]]
            window_unpack_h = window_unpack[order[1]]
            window_unpack_w = window_unpack[order[2]]
            r_dilation_h = dilation_h * window_unpack_h
            r_dilation_w = dilation_w * window_unpack_w
            r_inc = window_unpack_c
            delta_x = (
                r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
            )
            return delta_x

        @staticmethod
        def _call(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ):
            # Q: should we check x, w, bias dtypes?
            device = x.device
            # input shapes
            shape_x = x.shape
            shape_w = w.shape
            shape_bias = bias.shape if bias is not None else None

            # indices for the layout
            xn, xc, xh, xw = 0, 1, 2, 3
            yn, yc, yh, yw = 0, 1, 2, 3
            wn, wc, wh, ww = 0, 1, 2, 3

            # out_channel, in_channel, kernel_height, kernel_width
            kernel_size = [shape_w[wh], shape_w[ww]]
            input_size = [shape_x[xh], shape_x[xw]]
            assert (
                not shape_bias or shape_bias[0] == shape_w[wn]
            ), f"bias shape did not match: {shape_bias} != {shape_w[wn]}"
            in_channel = shape_w[wc] * groups

            assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
            assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
            assert (
                shape_x[xc] == in_channel
            ), f"in_channel did not match {shape_x[xc]} != {in_channel}"

            assert (
                len(stride)
                == len(padding)
                == len(dilation)
                == len(output_padding)
                == len(kernel_size)
                == len(input_size)
            )

            # output shape
            shape_y = [0] * 4
            shape_y[yn] = shape_x[xn]
            shape_y[yc] = shape_w[wn]
            shape_y[yh] = (
                input_size[0]
                + 2 * padding[0]
                - dilation[0] * (kernel_size[0] - 1)
                - 1
                + stride[0]
            ) // stride[0] + 2 * output_padding[0]
            shape_y[yw] = (
                input_size[1]
                + 2 * padding[1]
                - dilation[1] * (kernel_size[1] - 1)
                - 1
                + stride[1]
            ) // stride[1] + 2 * output_padding[1]
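            # The two expressions above are the standard conv output-size formula
            #   OUT = (IN + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
            # with the trailing "+ 1" folded into the numerator as "+ stride",
            # plus 2 * output_padding extra rows / columns on each spatial dim.
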
            BATCH = shape_x[xn]
            IN_C = shape_x[xc]
            IN_H = shape_x[xh]
            IN_W = shape_x[xw]
            KERNEL_N = shape_w[wn]
            KERNEL_H = shape_w[wh]
            KERNEL_W = shape_w[ww]
            OUT_H = shape_y[yh]
            OUT_W = shape_y[yw]

            # allocate output
            y = torch.empty(shape_y, device=device, dtype=x.dtype)

            # get strides for tensors
            stride_x = x.stride()
            stride_w = w.stride()
            stride_bias = bias.stride() if shape_bias else None
            stride_biasn = stride_bias[0] if stride_bias else None

            # output layout should be the same as x
            if stride_x[xc] < stride_x[xh] and stride_x[xc] < stride_x[xw]:
                y = y.to(memory_format=torch.channels_last)
            stride_y = y.stride()

            # allocate tmp
            # WINDOW_SIZE = KERNEL_H * KERNEL_W * IN_C
            # tmp_x = torch.empty((BATCH * OUT_H * OUT_W, WINDOW_SIZE), device=device, dtype=x.dtype)
            # tmp_w = torch.empty((WINDOW_SIZE, KERNEL_N), device=device, dtype=w.dtype)

            # accumulator types
            ACC_TYPE = (
                tl.float32
                if x.dtype in [torch.float16, torch.bfloat16, torch.float32]
                else tl.int32
            )
            # if stride_x[xc] == 1 and stride_x > 1 and stride_y > 1:
            CONV1X1_NHWC = False
            if stride_x[xc] == 1 and KERNEL_H == 1 and KERNEL_W == 1:
                CONV1X1_NHWC = True
            # do we need delta x ptr for h, w, c dimension each or not
            DELTA_X_PTR_HWC = (
                False
                if (
                    (padding[0] == 0 and padding[1] == 0)
                    or (KERNEL_H == 1 and KERNEL_W == 1)
                )
                else True
            )
            if not CONV1X1_NHWC:
                if DELTA_X_PTR_HWC:
                    delta_xh, delta_xw, delta_xc = _conv._delta_x_ptr_hwc(
                        IN_C,
                        KERNEL_H,
                        KERNEL_W,
                        dilation[0],
                        dilation[1],
                        stride_w[wc],
                        stride_w[wh],
                        stride_w[ww],
                        stride_x[xc],
                        stride_x[xh],
                        stride_x[xw],
                        device,
                    )
                else:
                    delta_x = _conv._delta_x_ptr(
                        IN_C,
                        KERNEL_H,
                        KERNEL_W,
                        dilation[0],
                        dilation[1],
                        stride_w[wc],
                        stride_w[wh],
                        stride_w[ww],
                        stride_x[xc],
                        stride_x[xh],
                        stride_x[xw],
                        device,
                    )
            else:
                delta_x = None
                delta_xh, delta_xw, delta_xc = None, None, None

            # launch kernel, 2-dim, batch*h*w, kernel
            def grid(META):
                return (
                    triton.cdiv(BATCH * OUT_H * OUT_W, META["BLOCK_M"]),
                    triton.cdiv(KERNEL_N, META["BLOCK_N"]),
                )

            # conv1x1 or padding==0
            if CONV1X1_NHWC or not DELTA_X_PTR_HWC:
                _kernel_delta_x[grid](
                    x,
                    w,
                    y,
                    # stride nchw for x,w,y tensor
                    stride_x[xn],
                    stride_x[xc],
                    stride_x[xh],
                    stride_x[xw],
                    stride_w[wn],
                    stride_w[wc],
                    stride_w[wh],
                    stride_w[ww],
                    stride_y[yn],
                    stride_y[yc],
                    stride_y[yh],
                    stride_y[yw],
                    stride_biasn,
                    # pointer inc for x
                    delta_x,
                    # Tensor dimensions
                    BATCH,
                    IN_C,
                    IN_H,
                    IN_W,
                    KERNEL_N,
                    KERNEL_H,
                    KERNEL_W,
                    OUT_H,
                    OUT_W,
                    # conv parameters
                    stride[0],
                    stride[1],
                    padding[0],
                    padding[1],
                    dilation[0],
                    dilation[1],
                    output_padding[0],
                    output_padding[1],
                    groups,
                    # Metaparameters
                    ACC_TYPE=ACC_TYPE,
                    CONV1X1_NHWC=CONV1X1_NHWC,
                    # BLOCK_M=128,
                    # BLOCK_N=32,
                    # BLOCK_K=32,
                    GROUP_H=1,
                )
            # need to know ptr update for each dimension to check if
            # the sliding window is out of bounds
            else:
                # kernel = _kernel_delta_x_hwc
                _kernel_delta_x_hwc[grid](
                    x,
                    w,
                    y,
                    # stride nchw for x,w,y tensor
                    stride_x[xn],
                    stride_x[xc],
                    stride_x[xh],
                    stride_x[xw],
                    stride_w[wn],
                    stride_w[wc],
                    stride_w[wh],
                    stride_w[ww],
                    stride_y[yn],
                    stride_y[yc],
                    stride_y[yh],
                    stride_y[yw],
                    stride_biasn,
                    # pointer inc for x
                    delta_xh,
                    delta_xw,
                    delta_xc,
                    # Tensor dimensions
                    BATCH,
                    IN_C,
                    IN_H,
                    IN_W,
                    KERNEL_N,
                    KERNEL_H,
                    KERNEL_W,
                    OUT_H,
                    OUT_W,
                    # conv parameters
                    stride[0],
                    stride[1],
                    padding[0],
                    padding[1],
                    dilation[0],
                    dilation[1],
                    output_padding[0],
                    output_padding[1],
                    groups,
                    # Metaparameters
                    ACC_TYPE=ACC_TYPE,
                    CONV1X1_NHWC=CONV1X1_NHWC,
                    # BLOCK_M=128,
                    # BLOCK_N=32,
                    # BLOCK_K=32,
                    GROUP_H=1,
                )

            if bias is not None:
                if len(bias.shape) == 1:
                    bias = bias.reshape([1, bias.shape[0], 1, 1])
                y += bias
            return y

        @staticmethod
        def forward(
            x,
            w,
            bias,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            transposed=False,
            output_padding=(0, 0),
            groups=1,
        ):
            if groups != 1:
                print(f"Do not support groups = {groups}")
                return
            if transposed:
                print("Do not support transposed")
                return
            return _conv._call(
                x,
                w,
                bias,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )

    conv = _conv.forward
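
    # A minimal usage sketch, assuming a CUDA device with triton available; the
    # shapes, dtype, and tolerances below are illustrative, and the comparison
    # against torch.nn.functional.conv2d only covers the supported case of
    # groups == 1, non-transposed convolution:
    #
    #     import torch.nn.functional as F
    #
    #     x = torch.randn(2, 16, 32, 32, device="cuda", dtype=torch.float16).to(
    #         memory_format=torch.channels_last
    #     )
    #     w = torch.randn(32, 16, 3, 3, device="cuda", dtype=torch.float16).to(
    #         memory_format=torch.channels_last
    #     )
    #     b = torch.randn(32, device="cuda", dtype=torch.float16)
    #     y = conv(x, w, b, stride=(1, 1), padding=(1, 1))
    #     y_ref = F.conv2d(x, w, b, stride=(1, 1), padding=(1, 1))
    #     torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2)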