import torch

from ..utils import has_triton

if has_triton():
    import triton

    class _conv1x1:
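        # Computes a 1x1 convolution by flattening the spatial dims and running a
        # single GEMM via triton.ops.matmul (no im2col is needed for a 1x1 kernel).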
        @staticmethod
        def _call(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ):
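            # Note: dilation has no effect on a 1x1 kernel, and forward() rejects
            # transposed convolutions and groups != 1 before calling this.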
            # Q: should we check x, w, bias dtypes?
            device = x.device
            # input shapes
            shape_x = x.shape
            shape_w = w.shape
            shape_bias = bias.shape if bias is not None else None

            # indices for the layout
            xn, xc, xh, xw = 0, 1, 2, 3
            yn, yc, yh, yw = 0, 1, 2, 3
            wn, wc, wh, ww = 0, 1, 2, 3

            # out_channel, in_channel, kernel_height, kernel_width
            kernel_size = [shape_w[wh], shape_w[ww]]
            input_size = [shape_x[xh], shape_x[xw]]
            assert (
                not shape_bias or shape_bias[0] == shape_w[wn]
            ), f"bias shape did not match: {shape_bias} != {shape_w[wn]}"
            in_channel = shape_w[wc] * groups

            assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
            assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
            assert (
                shape_x[xc] == in_channel
            ), f"in_channel did not match: {shape_x[xc]} != {in_channel}"
            assert (
                len(stride)
                == len(padding)
                == len(dilation)
                == len(output_padding)
                == len(kernel_size)
                == len(input_size)
            ), "stride, padding, dilation and output_padding must all be 2d"
            # output shape
            shape_y = [0] * 4
            shape_y[yn] = shape_x[xn]
            shape_y[yc] = shape_w[wn]
            shape_y[yh] = (
                input_size[0]
                + 2 * padding[0]
                - dilation[0] * (kernel_size[0] - 1)
                - 1
                + stride[0]
            ) // stride[0] + 2 * output_padding[0]
            shape_y[yw] = (
                input_size[1]
                + 2 * padding[1]
                - dilation[1] * (kernel_size[1] - 1)
                - 1
                + stride[1]
            ) // stride[1] + 2 * output_padding[1]
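            # The expressions above are the standard conv output-size formula,
            #   floor((H_in + 2*pad - dilation*(k - 1) - 1) / stride) + 1,
            # with the trailing "+ 1" folded in as "+ stride" inside the floor division;
            # output_padding only matters for the (unsupported) transposed case.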
            BATCH = shape_x[xn]
            IN_C = shape_x[xc]
            # IN_H = shape_x[xh]
            # IN_W = shape_x[xw]
            KERNEL_N = shape_w[wn]
            KERNEL_H = shape_w[wh]
            KERNEL_W = shape_w[ww]
            OUT_H = shape_y[yh]
            OUT_W = shape_y[yw]

            assert KERNEL_H == 1 and KERNEL_W == 1, "only support 1x1 conv"
            channels_last = x.stride()[1] == 1
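            # Two code paths below: with zero padding the whole input is flattened
            # into one matmul; with padding > 0 the matmul covers only the interior
            # and the padded border of the output is filled separately.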
            if padding == (0, 0):
                # nchw -> nhwc
                x = x.permute(0, 2, 3, 1)
                # select every stride's element (for stride > 1)
                x = x[:, :: stride[0], :: stride[1], :]
                # 2d matrix
                mat_x = x.reshape(-1, IN_C)
                # 2d matrix
                mat_w = w.view(KERNEL_N, IN_C)
                mat_w = mat_w.permute(1, 0)
                # 2d matrix y, (BATCH * OUT_H * OUT_W, KERNEL_N)
                mat_y = triton.ops.matmul(mat_x, mat_w)
                # mat_y = torch.empty((BATCH * OUT_H * OUT_W, KERNEL_N), device=device, dtype=x.dtype,)
                y = mat_y.view(BATCH, OUT_H, OUT_W, KERNEL_N)
                if bias is not None:
                    y += bias
                # convert back to the original layout of y
                # nhwc -> nchw
                y = y.permute(0, 3, 1, 2)
                if not channels_last:
                    y = y.to(memory_format=torch.contiguous_format)
                return y
            else:
                y = torch.empty(
                    (shape_y[yn], shape_y[yh], shape_y[yw], shape_y[yc]),
                    device=device,
                    dtype=x.dtype,
                )
                if channels_last:
                    y = y.to(memory_format=torch.channels_last)
                # y = bias.repeat((shape_y[yn], shape_y[yh], shape_y[yw], 1)).to(device).type(x.dtype)
                # convert x to channels-last layout;
                # don't care about the w layout since the kernel size is 1
                x = x.permute(0, 2, 3, 1)
                # select every stride's element (for stride > 1)
                x = x[:, :: stride[0], :: stride[1], :]
                # 2d matrix (reshape, not view: the strided slice above may be non-contiguous)
                mat_x = x.reshape(-1, IN_C)
                # 2d matrix
                mat_w = w.view(KERNEL_N, IN_C)
                mat_w = mat_w.permute(1, 0)
                # 2d matrix y, (BATCH * (OUT_H - 2*padding) * (OUT_W - 2*padding), KERNEL_N)
                mat_y = triton.ops.matmul(mat_x, mat_w)
                mat_y = mat_y.view(
                    BATCH, OUT_H - 2 * padding[0], OUT_W - 2 * padding[1], KERNEL_N
                )
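                # The matmul produced only the unpadded interior of the output; write it
                # into the center of y and fill the padded border with bias (or zeros),
                # which matches zero-padding the input and adding bias everywhere.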
                # consider padding > 0
                if bias is not None:
                    y[
                        :,
                        padding[0] : OUT_H - padding[0],
                        padding[1] : OUT_W - padding[1],
                        :,
                    ] = (
                        mat_y + bias
                    )
                    y[:, : padding[0], :, :] = bias
                    y[:, :, : padding[1], :] = bias
                    y[:, OUT_H - padding[0] :, :, :] = bias
                    y[:, :, OUT_W - padding[1] :, :] = bias
                else:
                    y[
                        :,
                        padding[0] : OUT_H - padding[0],
                        padding[1] : OUT_W - padding[1],
                        :,
                    ] = mat_y
                    y[:, : padding[0], :, :] = 0
                    y[:, :, : padding[1], :] = 0
                    y[:, OUT_H - padding[0] :, :, :] = 0
                    y[:, :, OUT_W - padding[1] :, :] = 0
                # convert back to the original layout of y
                # nhwc -> nchw
                y = y.permute(0, 3, 1, 2)
                return y

        @staticmethod
        def forward(
            x,
            w,
            bias,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            transposed=False,
            output_padding=(0, 0),
            groups=1,
        ):
            # unsupported configurations return None rather than computing a wrong result
            if groups != 1:
                print(f"Do not support groups = {groups}")
                return
            if transposed:
                print("Do not support transposed")
                return
            return _conv1x1._call(
                x,
                w,
                bias,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )

    conv1x1 = _conv1x1.forward
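    # Usage sketch (assumes a CUDA device with triton installed; shapes are illustrative):
    #   x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.float16)
    #   w = torch.randn(128, 64, 1, 1, device="cuda", dtype=torch.float16)
    #   b = torch.randn(128, device="cuda", dtype=torch.float16)
    #   y = conv1x1(x, w, b, stride=(1, 1), padding=(0, 0))
    #   # y should match torch.nn.functional.conv2d(x, w, b) up to matmul rounding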
|