import math
from typing import List, Optional, Union

import torch
import torch._prims_common as utils
from torch import Tensor
from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
from torch._ops import OpOverload
from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
from torch._prims_common import (
    check,
    corresponding_complex_dtype,
    corresponding_real_dtype,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    IntLike,
    make_contiguous_strides_for,
)

from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes

from torch._subclasses.fake_tensor import check_no_bool_index_tensors
from torch.utils._pytree import tree_map

aten = torch.ops.aten

_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")


def register_meta(op):
    """Return a decorator that records ``fn`` in ``meta_table`` for ``op``.

    ``op`` may be a single operator or any pytree (e.g. a list) of operators;
    ``tree_map`` applies the registration to every leaf. The decorated
    function is returned unchanged.
    """

    def wrapper(fn):
        def register(op):
            _add_op_to_registry(meta_table, op, fn)

        tree_map(register, op)
        return fn

    return wrapper


def toRealValueType(dtype):
    """Map a complex dtype to its real counterpart; other dtypes pass through."""
    from_complex = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    return from_complex.get(dtype, dtype)


@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    """Complex-to-complex FFT: the output has the same size as the input."""
    assert self.dtype.is_complex
    return self.new_empty(self.size())


@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Real-to-complex FFT.

    When ``onesided`` is set, the last transformed dimension (``dim[-1]``)
    shrinks to ``n // 2 + 1``. The result uses the complex dtype
    corresponding to the input dtype.
    """
    assert self.dtype.is_floating_point
    output_sizes = list(self.size())

    if onesided:
        last_dim = dim[-1]
        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
        output_sizes[last_dim] = last_dim_halfsize

    return self.new_empty(
        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
    )


@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    """Validate that ``out`` is 1-D with ``n`` elements and return it."""
    assert out.ndim == 1 and out.size(0) == n
    return out


@register_meta(aten.randint.default)
def meta_randint(
    high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """Return an empty tensor of the requested ``size`` and options."""
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta(aten.randint.low)
def meta_randint_low(
    low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """Return an empty tensor of the requested ``size`` and options."""
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta(aten.rand.default)
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return an empty tensor of the requested ``size`` and options."""
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    """Complex-to-real FFT: ``dim[-1]`` gets size ``lastdim``; dtype becomes real."""
    assert self.dtype.is_complex
    output_sizes = list(self.size())
    output_sizes[dim[-1]] = lastdim
    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))


@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    """In-place copy: the destination is returned unchanged."""
    return self


def inferUnsqueezeGeometry(tensor, dim):
    """Compute sizes and strides of ``tensor`` with a size-1 dim inserted at ``dim``.

    The new stride is 1 when inserting past the last dimension, otherwise
    ``size[dim] * stride[dim]`` of the dimension currently at ``dim``.

    Returns:
        ``(sizes, strides)`` as two lists.
    """
    result_sizes = list(tensor.size())
    result_strides = list(tensor.stride())
    new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
    result_sizes.insert(dim, 1)
    result_strides.insert(dim, new_stride)
    return result_sizes, result_strides


@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    """In-place unsqueeze implemented by re-striding ``self``."""
    # NOTE(review): `maybe_wrap_dim` is not among the names imported at the top
    # of this file - confirm it is defined or imported elsewhere.
    dim = maybe_wrap_dim(dim, self.dim() + 1)
    g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
    self.as_strided_(g_sizes, g_strides)
    return self


# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    """Output has ``self``'s shape with ``dim`` replaced by ``index.numel()``.

    A 0-dim ``self`` keeps its (empty) shape.
    """
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)


@register_meta(aten.index_select.out)
def meta_index_select_out(self, dim, index, out):
    """``out=`` variant of index_select.

    ``out`` is resized to the shape of the selection result (not to the shape
    of ``self``: the two differ whenever ``index.numel() != self.size(dim)``)
    and then receives the result.
    """
    result = torch.index_select(self, dim, index)
    torch._resize_output_(out, result.size(), self.device)
    return out.copy_(result)
@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    """Full reduction: the result is a 0-dim tensor."""
    return self.new_empty(())


@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    """Reduce over ``dim``; returns ``(values, indices)`` with int64 indices."""
    dim = utils.reduction_dims(self.shape, (dim,))
    output_shape = _compute_reduction_shape(self, dim, keepdim)
    return (
        self.new_empty(output_shape),
        self.new_empty(output_shape, dtype=torch.long),
    )


@register_meta([aten.min.default])
def meta_min(self):
    """Full reduction: the result is a 0-dim tensor."""
    return self.new_empty(())


@register_meta(aten.angle.default)
def meta_angle(self):
    """Result dtype is the real counterpart for complex inputs, otherwise the
    INT_TO_FLOAT-promoted dtype."""
    if self.is_complex():
        result_dtype = corresponding_real_dtype(self.dtype)
    else:
        _, result_dtype = elementwise_dtypes(
            self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )
    return torch.empty_like(self, dtype=result_dtype)


@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    """``out=`` variant: resize ``out`` to ``self``'s size and copy the result in."""
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.angle(self))


# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
    """Assert that ``self`` is a batch of square matrices (rank >= 2, last two
    dims equal). ``f_name`` prefixes the error message."""
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    assert self.size(-1) == self.size(
        -2
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"


# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkFloatingOrComplex(
    t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
):
    """Check that ``t`` has a floating point or complex dtype.

    Args:
        t: tensor to validate.
        f_name: operator name used as the error-message prefix.
        allow_low_precision_dtypes: when False, additionally require the
            dtype to be one of float, double, cfloat or cdouble.
    """
    dtype = t.dtype
    check(
        t.is_floating_point() or t.is_complex(),
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    # The restriction applies only when low precision dtypes are NOT allowed;
    # the condition used to be inverted, so callers passing False were never
    # checked and callers relying on the default were wrongly restricted.
    if not allow_low_precision_dtypes:
        check(
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
        )


# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
    """Check that ``A`` has at least 2 dimensions."""
    check(
        A.dim() >= 2,
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
    )


def checkUplo(uplo: str):
    """Assert that ``uplo`` is a single character equal to 'U' or 'L' (any case)."""
    uplo_uppercase = uplo.upper()
    # Parenthesised so the length requirement visibly applies to both letters.
    assert len(uplo) == 1 and (
        uplo_uppercase == "U" or uplo_uppercase == "L"
    ), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"


# @register_meta(aten.linalg_eigh.default)
def meta_linalg_eigh(self, uplo="L"):
    """Meta function for linalg_eigh (currently not registered, see above)."""
    squareCheckInputs(self, "linalg_eigh")
    checkUplo(uplo)
    real_dtype = toRealValueType(self.dtype)
    assert self.dim() >= 2
    # NOTE(review): `values` gets the full matrix shape with the real dtype and
    # `vectors` gets shape[:-1] with the input dtype; this looks swapped
    # relative to eigenvalues/eigenvectors - confirm before registering.
    values = self.new_empty(self.shape, dtype=real_dtype)
    values.transpose_(-2, -1)
    vectors = self.new_empty(self.shape[:-1])
    return (values, vectors)


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    """Returns ``(L, infos)``: ``L`` has ``A``'s shape with non-row-major
    strides; ``infos`` is int32 over the batch dimensions."""
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")

    A_shape = A.shape
    ndim = len(A_shape)

    # L
    L_strides = make_contiguous_strides_for(A_shape, False)
    L = A.new_empty(A_shape)
    L.as_strided_(A_shape, L_strides)

    # infos
    infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
    return L, infos


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    """Returns ``(inverse, infos)``; low precision dtypes are rejected."""
    squareCheckInputs(A, "linalg.inv_ex")
    checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)

    L = A.new_empty(A.shape)
    L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
    return L, infos


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
    A: Tensor,
    full_matrices: bool = False,
    compute_uv: bool = True,
    driver: Optional[str] = None,
):
    """Returns ``(U, S, V)`` for an SVD of ``A`` (``... x m x n``).

    With ``k = min(m, n)``: ``U`` is ``... x m x (m or k)``, ``V`` is
    ``... x (n or k) x n`` depending on ``full_matrices``, and ``S`` is
    ``... x k``. When ``compute_uv`` is False, ``U`` and ``V`` are
    placeholder tensors of shape ``[0]``.
    """
    checkIsMatrix(A, "linalg.svd")
    checkFloatingOrComplex(A, "linalg.svd")

    batch_dims = list(A.shape[:-2])
    m = A.shape[-2]
    n = A.shape[-1]
    k = min(m, n)

    if compute_uv:
        U_shape = batch_dims + [m, m if full_matrices else k]
        U = A.new_empty(U_shape)
        U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))

        V_shape = batch_dims + [n if full_matrices else k, n]
        V = A.new_empty(V_shape)
        # TODO: need to distinguish cuSOLVER case? (see original code)
        V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=False))
    else:
        # doesn't matter
        U = A.new_empty([0])
        V = A.new_empty([0])

    # S is always real, even when A is complex.
    S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
    return U, S, V


# From aten/src/ATen/native/LinearAlgebra.cpp
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    """Returns ``(det, LU, pivots)``: ``det`` over the batch dims, ``LU`` with
    ``A``'s shape and non-row-major strides, ``pivots`` int32 of ``shape[:-1]``."""
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    det = A.new_empty(A.shape[:-2])

    LU = A.new_empty(A.shape)
    LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
    return det, LU, pivots


# From aten/src/ATen/native/ReflectionPad.cpp
@register_meta(
    [aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default]
)
def meta_pad2d_backward(grad_output, self, padding):
    """Check that ``grad_output`` matches the padded size of ``self`` and
    return an empty gradient with ``self``'s shape."""
    dim_w = 2
    dim_h = 1
    dim_plane = 0
    nbatch = 1

    self_shape = self.shape
    if self.dim() == 4:
        # Batched input: every dimension index shifts by one.
        nbatch = self_shape[0]
        dim_w += 1
        dim_h += 1
        dim_plane += 1

    pad_l = padding[0]
    pad_r = padding[1]
    pad_t = padding[2]
    pad_b = padding[3]

    nplane = self_shape[dim_plane]
    input_h = self_shape[dim_h]
    input_w = self_shape[dim_w]
    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    check(
        output_w == grad_output.shape[dim_w],
        lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
    )
    check(
        output_h == grad_output.shape[dim_h],
        lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
    )
    return self.new_empty(self.shape)


@register_meta(aten.reflection_pad2d.default)
def meta_pad2d(self, padding):
    """Output keeps batch/plane dims and grows height/width by the padding.

    ``padding`` is ``(left, right, top, bottom)``.
    """
    valid_dims = self.size(1) != 0 and self.size(2) != 0
    check(
        (self.ndim == 3 and valid_dims)
        or (self.ndim == 4 and valid_dims and self.size(3) != 0),
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )
    if self.ndim == 4:
        nbatch, nplane, input_h, input_w = self.shape
    else:
        nbatch = 1
        nplane, input_h, input_w = self.shape

    pad_l, pad_r, pad_t, pad_b = padding

    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    if self.ndim == 3:
        return self.new_empty((nplane, output_h, output_w))
    else:
        return self.new_empty((nbatch, nplane, output_h, output_w))


@register_meta([aten.bernoulli.default, aten.bernoulli.out])
@out_wrapper()
def meta_bernoulli(self, *, generator=None):
    """Returns a contiguous empty tensor shaped like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    return torch.empty_like(self).contiguous()


@register_meta(aten.bernoulli_.float)
def meta_bernoulli_(self, p=0.5, generator=None):
    """In-place op: returns ``self``."""
    return self


@register_meta(aten.bernoulli.p)
def meta_bernoulli_p(self, p=0.5, generator=None):
    """Returns a contiguous empty tensor shaped like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    return torch.empty_like(self).contiguous()


@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
def meta__fused_moving_avg_obs_fq_helper(
    self,
    observer_on,
    fake_quant_on,
    running_min,
    running_max,
    scale,
    zero_point,
    averaging_const,
    quant_min,
    quant_max,
    ch_axis,
    per_row_fake_quant=False,
    symmetric_quant=False,
):
    """Returns ``(output, mask)``, both shaped like ``self``; ``mask`` is bool."""
    check(
        ch_axis < self.dim(),
        lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
    )
    mask = torch.empty_like(self, dtype=torch.bool)
    return (torch.empty_like(self), mask)


def dot_check(self, other):
    """Check that both operands are 1-D."""
    check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )


@register_meta(aten.dot.default)
def meta_dot(self, tensor):
    """Dot product of two 1-D tensors: 0-dim result."""
    dot_check(self, tensor)
    return self.new_empty(())


@register_meta([aten.mm.default])
def meta_mm(a, b):
    """Matrix multiply of ``(N, M)`` by ``(M, P)``: result is ``(N, P)``."""
    check(a.dim() == 2, lambda: "a must be 2D")
    check(b.dim() == 2, lambda: "b must be 2D")
    N, M1 = a.shape
    M2, P = b.shape
    check(M1 == M2, lambda: "a and b must have same reduction dim")
    return a.new_empty(N, P)


def _compute_reduction_shape(self, dims, keepdim):
    """Shape after reducing ``dims``: reduced dims become 1 when ``keepdim``,
    otherwise they are removed."""
    if keepdim:
        return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))

    return utils.compute_reduction_output_shape(self.shape, dims)


# FakeTensors (meta tensors with a device) will report device as meta
# when running meta kernels. Here, access the "fake device" of FakeTensor if it
# exists so meta kernels which diverge per device will be more
# accurate when run with FakeTensors
def device_hint(tensor) -> "str":
    """Return the fake device type for a FakeTensor, otherwise ``"cuda"``."""
    if isinstance(tensor, torch._subclasses.FakeTensor):
        return tensor.fake_device.type
    else:
        return "cuda"  # default to cuda


def calc_conv_nd_return_shape(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    stride: Union[List[int], int],
    padding: Union[List[int], int],
    dilation: Union[List[int], int],
    is_transposed: bool,
    groups: int,
    output_padding: Optional[Union[List[int], int]] = None,
):
    """Compute the output shape of an N-d (optionally transposed) convolution.

    ``stride``, ``padding``, ``dilation`` and ``output_padding`` may each be a
    single int or a list; ints and one-element lists are broadcast to every
    spatial dimension. The transposed formula is used only when
    ``output_padding`` is provided and non-empty.

    Returns:
        ``[batch, out_channels, *spatial]`` as a list.

    Raises:
        RuntimeError: for a non-transposed convolution whose weight input
            channels times ``groups`` does not equal the input channels.
    """

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output

        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
        if transposed convolution is used.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim

        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]
    if is_transposed:
        out_channels = groups * weight.shape[1]
    else:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")

    ret_shape = [input_tensor.shape[0], out_channels]
    if isinstance(stride, IntLike):
        stride = [stride] * len(dims)
    elif len(stride) == 1:
        stride = [stride[0]] * len(dims)

    if isinstance(padding, IntLike):
        padding = [padding] * len(dims)
    elif len(padding) == 1:
        padding = [padding[0]] * len(dims)

    if isinstance(dilation, IntLike):
        dilation = [dilation] * len(dims)
    elif len(dilation) == 1:
        dilation = [dilation[0]] * len(dims)

    output_padding_list: Optional[List[int]] = None
    if output_padding:
        if isinstance(output_padding, IntLike):
            output_padding_list = [output_padding] * len(dims)
        elif len(output_padding) == 1:
            output_padding_list = [output_padding[0]] * len(dims)
        else:
            output_padding_list = output_padding

    for i in range(len(dims)):
        # If output_padding is present, we are dealing with a transposed convolution
        if output_padding_list:
            ret_shape.append(
                _formula_transposed(
                    dims[i],
                    padding[i],
                    dilation[i],
                    kernel_size[i],
                    stride[i],
                    output_padding_list[i],
                )
            )
        else:
            ret_shape.append(
                _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
            )

    return ret_shape


def is_channels_last(ten):
    """True when the suggested memory format of ``ten`` is channels_last."""
    return torch._prims_common.suggest_memory_format(ten) == torch.channels_last


@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Convolution: computes the output shape and picks its memory format."""

    def pick_memory_format():
        # On "cuda" either operand being channels_last selects channels_last;
        # on other devices only the input is considered.
        if device_hint(input_tensor) == "cuda":
            if is_channels_last(input_tensor) or is_channels_last(weight):
                return torch.channels_last
        else:
            if is_channels_last(input_tensor):
                return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    shape_out = calc_conv_nd_return_shape(
        input_tensor,
        weight,
        stride,
        padding,
        dilation,
        is_transposed,
        groups,
        output_padding if is_transposed else None,
    )

    out = input_tensor.new_empty(shape_out)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out


if torch._C.has_mkldnn:
    _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
        "mkldnn", "IMPL", "Meta"
    )

    def pick_mkldnn_conv_memory_format(input_tensor, weight):
        """Memory format for an mkldnn convolution output: channels_last when
        the weight is mkldnn or either operand is channels_last."""
        if weight.is_mkldnn:
            return torch.channels_last
        if is_channels_last(input_tensor) or is_channels_last(weight):
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
    def meta_mkldnn_convolution_default(
        input_tensor,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    ):
        """Fused mkldnn convolution: output is always channels_last."""
        shape_out = calc_conv_nd_return_shape(
            input_tensor, weight, stride, padding, dilation, False, groups, []
        )
        out = input_tensor.new_empty(shape_out)
        out_memory_format = torch.channels_last
        out = out.to(memory_format=out_memory_format)  # type: ignore[call-overload]
        return out

    @register_meta(torch.ops.mkldnn._convolution_pointwise.binary)
    def meta_mkldnn_convolution_binary(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        """Fused conv + binary op: output takes ``other``'s size, channels_last."""
        out = input_tensor.new_empty(other.size())
        out = out.to(memory_format=torch.channels_last)  # type: ignore[call-overload]
        return out

    @register_meta(torch.ops.mkldnn._convolution_pointwise_.binary)
    def meta_mkldnn_convolution_binary_inplace(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        """In-place fused conv + binary op: returns ``other``."""
        return other

    @register_meta(torch.ops.mkldnn._linear_pointwise.default)
    def meta_linear_pointwise_default(
        input_tensor, weight, bias, attr, scalars, algorithm
    ):
        """Fused linear: last dim becomes ``weight.shape[0]``."""
        return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))

    @register_meta(torch.ops.mkldnn._linear_pointwise.binary)
    def meta_linear_pointwise_binary(input_tensor, other, weight, bias, attr):
        """Fused linear + binary op: output takes ``other``'s size."""
        out = input_tensor.new_empty(other.size())
        return out

    if torch._C.has_mkl:
        _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
            "mkl", "IMPL", "Meta"
        )

        @register_meta(torch.ops.mkl._mkl_linear)
        def meta_mkl_linear(
            input_tensor,
            packed_weight,
            orig_weight,
            bias,
            batch_size,
        ):
            """MKL linear: last dim becomes ``orig_weight.shape[0]``."""
            return input_tensor.new_empty(
                (*input_tensor.shape[:-1], orig_weight.shape[0])
            )


# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
    """Check that ``tensor`` has rank ``dim`` and extent ``size`` along
    dimension index ``dim_size``."""

    def _describe_mismatch():
        expected = f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
        actual = f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}"
        return expected + actual

    rank_and_extent_ok = tensor.dim() == dim and tensor.shape[dim_size] == size
    check(rank_and_extent_ok, _describe_mismatch)


@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Validate avg_pool2d arguments and return an empty output tensor.

    ``kernel_size`` and ``padding`` hold one or two ints; ``stride`` may also
    be empty, in which case it defaults to the kernel size. The result keeps
    the input's dtype, device and suggested memory format.
    """

    def _height_width(arg_name, values):
        # A single value is used for both height and width.
        check(
            len(values) in [1, 2],
            lambda: f"avg_pool2d: {arg_name} must either be a single int, or a tuple of two ints",
        )
        height = values[0]
        width = values[1] if len(values) != 1 else height
        return height, width

    kernel_h, kernel_w = _height_width("kernel_size", kernel_size)

    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        # Omitted stride falls back to the kernel size.
        stride_h, stride_w = kernel_h, kernel_w
    else:
        stride_h, stride_w = _height_width("stride", stride)

    pad_h, pad_w = _height_width("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    batch_count = input.size(-4) if input.dim() == 4 else 1
    plane_count = input.size(-3)
    in_height = input.size(-2)
    in_width = input.size(-1)

    out_height = pooling_output_shape(in_height, kernel_h, pad_h, stride_h, 1, ceil_mode)
    out_width = pooling_output_shape(in_width, kernel_w, pad_w, stride_w, 1, ceil_mode)

    fmt = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        1,
        1,
        plane_count,
        in_height,
        in_width,
        out_height,
        out_width,
        fmt,
    )

    out_size = [plane_count, out_height, out_width]
    if input.dim() != 3:
        # Batched input: prepend the batch dimension.
        out_size = [batch_count] + out_size

    return torch.empty(
        out_size, dtype=input.dtype, device=input.device, memory_format=fmt
    )


# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check( input, gradOutput, nbatch, kH, kW, dH, dW, padH, padW, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, mem_format, ): pool2d_shape_check( input, kH, kW, dH, dW, padH, padW, 1, 1, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, mem_format, ) ndim = input.dim() nOutputPlane = nInputPlane check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane) check_dim_size(gradOutput, ndim, ndim - 2, outputHeight) check_dim_size(gradOutput, ndim, ndim - 1, outputWidth) # Don't override the C++ registration. @register_meta(aten.avg_pool2d_backward.default) def meta_avg_pool2d_backward( gradOutput_, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, ): # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func. check( len(kernel_size) == 1 or len(kernel_size) == 2, lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints", ) kH = kernel_size[0] kW = kH if len(kernel_size) == 1 else kernel_size[1] check( len(stride) == 0 or len(stride) == 1 or len(stride) == 2, lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", ) dH = kH if len(stride) == 0 else stride[0] dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1] check( len(padding) == 1 or len(padding) == 2, lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints", ) padH = padding[0] padW = padH if len(padding) == 1 else padding[1] check( divisor_override is None or divisor_override != 0, lambda: "divisor must be not zero", ) input_size = input.shape nbatch = input_size[-4] if input.dim() == 4 else 1 nInputPlane = input_size[-3] inputHeight = input_size[-2] inputWidth = input_size[-1] outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode) outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode) mem_format = utils.suggest_memory_format(input) 
avg_pool2d_backward_shape_check( input, gradOutput_, nbatch, kH, kW, dH, dW, padH, padW, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, mem_format, ) return torch.empty( input_size, dtype=input.dtype, device=input.device, memory_format=mem_format ) @register_meta(aten._adaptive_avg_pool2d.default) def meta_adaptive_avg_pool2d(self, output_size): check( self.ndim == 3 or self.ndim == 4, lambda: f"Expected 3D or 4D tensor, but got {self.shape}", ) output_shape = self.shape[:-2] + tuple(output_size) memory_format = utils.suggest_memory_format(self) # need to set memory_format to preserve the memory format of the input # channel last input should have channel last output return torch.empty( output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format ) @register_meta(aten._adaptive_avg_pool3d.default) def meta_adaptive_avg_pool3d(self, output_size): check( self.ndim == 4 or self.ndim == 5, lambda: f"Expected 4D or 5D tensor, but got {self.shape}", ) return self.new_empty(self.shape[:-3] + tuple(output_size)) @register_meta(aten._adaptive_avg_pool2d_backward.default) def meta__adaptive_avg_pool2d_backward(grad_out, self): ndim = grad_out.ndim for i in range(1, ndim): check( grad_out.size(i) > 0, lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", ) check( ndim == 3 or ndim == 4, lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", ) check( self.dtype == grad_out.dtype, lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", ) return self.new_empty(self.shape) @register_meta(aten.repeat_interleave.Tensor) def meta_repeat_interleave_Tensor(repeats, output_size=None): if output_size is None: raise RuntimeError("cannot repeat_interleave a meta tensor without output_size") return repeats.new_empty(output_size) @register_meta([aten.complex.default, 
aten.complex.out]) @out_wrapper() def meta_complex(real, imag): assert real.dtype.is_floating_point assert imag.dtype.is_floating_point out_shape = _broadcast_shapes(real.shape, imag.shape) return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype)) @register_meta(aten.vdot.default) def vdot(self, other): if not self.is_complex: return torch.dot(self, other) if self.is_conj(): if other.is_conj(): return torch.vdot(other.conj(), self.conj()) else: return torch.dot(self.conj(), other) elif other.is_conj(): return torch.dot(self, other.conj()).conj() dot_check(self, other) return self.new_empty(()) # Leaving this function around because a python implementation # of indexing shape inference is useful, # but not registering it to the dispatcher because we already # get shape inference through structured kernels @register_meta(aten.index.Tensor) def meta_index_Tensor(self, indices): check_no_bool_index_tensors(aten.index.Tensor, self, indices) check(indices, lambda: "at least one index must be provided") # aten::index is the internal advanced indexing implementation # checkIndexTensorTypes and expandTensors result: List[Optional[Tensor]] = [] for i, index in enumerate(indices): if index is not None: check( index.dtype in [torch.long, torch.int, torch.int8, torch.bool], lambda: "tensors used as indices must be long, int, byte or bool tensors", ) if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) check( k + index.ndim <= self.ndim, lambda: f"too many indices for tensor of dimension {self.ndim}", IndexError, ) for j in range(index.ndim): check( index.shape[j] == self.shape[k + j], lambda: f"The shape of the mask {index.shape} at index {i} " f"does not match the shape of the indexed tensor {self.shape} at index {k + j}", IndexError, ) result.append(nonzero.select(1, j)) else: result.append(index) else: result.append(index) indices = result check( len(indices) <= self.ndim, lambda: f"too many indices for tensor of 
dimension {self.ndim} (got {len(indices)})", ) # expand_outplace import torch._refs as refs # avoid import cycle in mypy indices = list(refs._maybe_broadcast(*indices)) # add missing null tensors while len(indices) < self.ndim: indices.append(None) # hasContiguousSubspace # true if all non-null tensors are adjacent # See: # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency state = 0 has_contiguous_subspace = False for index in indices: if state == 0: if index is not None: state = 1 elif state == 1: if index is None: state = 2 else: if index is not None: break else: has_contiguous_subspace = True # transposeToFront # This is the logic that causes the newly inserted dimensions to show up # at the beginning of the tensor, if they're not contiguous if not has_contiguous_subspace: dims = [] transposed_indices = [] for i, index in enumerate(indices): if index is not None: dims.append(i) transposed_indices.append(index) for i, index in enumerate(indices): if index is None: dims.append(i) transposed_indices.append(index) self = self.permute(dims) indices = transposed_indices # AdvancedIndex::AdvancedIndex # Now we can assume the indices have contiguous subspace # This is simplified from AdvancedIndex which goes to more effort # to put the input and indices in a form so that TensorIterator can # take them. 
If we write a ref for this, probably that logic should # get implemented before_shape: List[int] = [] after_shape: List[int] = [] replacement_shape: List[int] = [] for dim, index in enumerate(indices): if index is None: if replacement_shape: after_shape.append(self.shape[dim]) else: before_shape.append(self.shape[dim]) else: replacement_shape = list(index.shape) return self.new_empty(before_shape + replacement_shape + after_shape) @register_meta([aten.convolution_backward.default]) def meta_convolution_backward( grad_output_, input_, weight_, bias_sizes_opt, stride, padding, dilation, transposed, output_padding, groups, output_mask, ): # High level logic taken from slow_conv3d_backward_cpu which should # be representative of all convolution_backward impls backend_grad_input = None backend_grad_weight = None backend_grad_bias = None if output_mask[0]: backend_grad_input = grad_output_.new_empty(input_.size()) if output_mask[1]: backend_grad_weight = grad_output_.new_empty(weight_.size()) if output_mask[2]: backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) return (backend_grad_input, backend_grad_weight, backend_grad_bias) @register_meta([aten.addbmm.default, aten.addbmm.out]) @out_wrapper() def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): dim1 = batch1.size(1) dim2 = batch2.size(2) self = self.expand((dim1, dim2)) check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") check( batch1.size(0) == batch2.size(0), lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}", ) check( batch1.size(2) == batch2.size(1), lambda: ( f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} " f"and {batch2.size(1)}x{batch2.size(2)})" ), ) check( self.size(0) == dim1 and self.size(1) == dim2, lambda: "self tensor does not match matmul output shape", ) return self.new_empty(self.size()) @register_meta(aten._cdist_forward.default) def 
meta_cdist_forward(x1, x2, p, compute_mode): check( x1.dim() >= 2, lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D", ) check( x2.dim() >= 2, lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D", ) check( x1.size(-1) == x2.size(-1), lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}", ) check( utils.is_float_dtype(x1.dtype), lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}", ) check( utils.is_float_dtype(x2.dtype), lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}", ) check(p >= 0, lambda: "cdist only supports non-negative p values") check( compute_mode in (None, 1, 2), lambda: f"possible modes: None, 1, 2, but was: {compute_mode}", ) r1 = x1.size(-2) r2 = x2.size(-2) batch_tensor1 = x1.shape[:-2] batch_tensor2 = x2.shape[:-2] output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) output_shape.extend([r1, r2]) return x1.new_empty(output_shape) @register_meta(aten._embedding_bag.default) def meta_embedding_bag( weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=-1, ): check( indices.dtype in (torch.long, torch.int), lambda: f"expected indices to be long or int, got {indices.dtype}", ) check( offsets.dtype in (torch.long, torch.int), lambda: f"expected offsets to be long or int, got {offsets.dtype}", ) check( utils.is_float_dtype(weight.dtype), lambda: f"expected weight to be floating point type, got {weight.dtype}", ) num_bags = offsets.size(0) if include_last_offset: check( num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1" ) num_bags -= 1 output = weight.new_empty(num_bags, weight.size(1)) MODE_SUM, MODE_MEAN, MODE_MAX = range(3) if per_sample_weights is not None: check( mode == MODE_SUM, lambda: "embedding_bag: per_sample_weights only supported with mode='sum'", ) check( per_sample_weights.dtype == 
weight.dtype, lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype", ) check( per_sample_weights.ndim == 1, lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D", ) check( per_sample_weights.numel() == indices.numel(), lambda: ( f"expected per_sample_weights.numel() ({per_sample_weights.numel()} " f"to be the same as indices.numel() ({indices.numel()})" ), ) def is_fast_path_index_select_scale(src, scale, output, padding_idx): return ( is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1 ) def is_fast_path_index_select(src, output, padding_idx): return ( (src.dtype == torch.float or src.dtype == torch.half) and src.stride(1) == 1 and output.stride(1) == 1 and padding_idx < 0 ) def is_fast_path(src, scale, output, padding_idx): if scale is not None: return is_fast_path_index_select_scale(src, scale, output, padding_idx) else: return is_fast_path_index_select(src, output, padding_idx) if device_hint(offsets) != "cpu": offset2bag = indices.new_empty(indices.size(0)) bag_size = indices.new_empty(offsets.size()) if mode == MODE_MAX: max_indices = indices.new_empty(num_bags, weight.size(1)) else: max_indices = indices.new_empty(0) else: fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx) if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum: offset2bag = offsets.new_empty(indices.size(0)) else: offset2bag = offsets.new_empty(0) bag_size = offsets.new_empty(num_bags) # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp numBags = offsets.shape[0] if mode == MODE_MAX: if include_last_offset: check( numBags >= 1, lambda: "include_last_offset: numBags should be at least 1", ) numBags -= 1 max_indices = offsets.new_empty(numBags, weight.shape[1]) else: max_indices = offsets.new_empty(bag_size.size()) return output, offset2bag, bag_size, max_indices @register_meta(aten._embedding_bag_forward_only.default) 
def meta_embedding_bag_forward_only(weight, indices, offsets, *args): output, offset2bag, bag_size, max_indices = meta_embedding_bag( weight, indices, offsets, *args ) if device_hint(offsets) == "cpu": bag_size = offsets.new_empty(offsets.size()) return output, offset2bag, bag_size, max_indices def _get_reduction_dtype(input, dtype, promote_int_to_long=True): # if specified, dtype takes precedence if dtype: return dtype if input.dtype.is_floating_point or input.dtype.is_complex: return input.dtype elif promote_int_to_long: return torch.long return input.dtype @register_meta([aten.nansum.default, aten.nansum.out]) @out_wrapper() def meta_nansum(input, dims=None, keepdim=False, *, dtype=None): output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True) dims = utils.reduction_dims(input.shape, dims) output_shape = _compute_reduction_shape(input, dims, keepdim) return input.new_empty(output_shape, dtype=output_dtype) @register_meta(aten.nanmedian.default) def meta_nanmedian(input): output_shape = utils.compute_reduction_output_shape( input.shape, tuple(range(input.dim())) ) return input.new_empty(output_shape) @register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values]) @out_wrapper("values", "indices") def meta_nanmedian_dim(input, dim=-1, keepdim=False): dim = utils.reduction_dims(input.shape, (dim,)) output_shape = _compute_reduction_shape(input, dim, keepdim) return ( input.new_empty(output_shape), input.new_empty(output_shape, dtype=torch.long), ) @register_meta(aten.logical_not_.default) def meta_logical_not_(self): return self @register_meta(aten.repeat.default) def meta_repeat(self, repeats): check( len(repeats) >= self.dim(), lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", ) # Add new leading dimensions to the tensor if the # number of target dimensions is larger than the # number of source dimensions. 
@register_meta([aten.round.default, aten.round.decimals])
def meta_round(self, **kwargs):
    """Meta kernel shared by both ``round`` overloads.

    Any keyword arguments are accepted and ignored here; the result is
    whatever ``_elementwise_meta`` produces for ``self`` under the DEFAULT
    prim type-promotion kind.
    """
    promotion_kind = ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    return _elementwise_meta(self, type_promotion=promotion_kind)
def pool2d_shape_check(
    input,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    dilationH,
    dilationW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    memory_format,
):
    """Validate the geometry of a 2D pooling call.

    Checks, in order: positive kernel size, positive stride, positive
    dilation, an input rank/extent compatible with ``memory_format``,
    padding no larger than half the kernel, and an output of at least 1x1.
    Each failed condition is reported through ``check`` with a message that
    includes the offending values. Returns ``None`` when every check passes.
    """
    ndim = input.dim()
    nOutputPlane = nInputPlane

    # NOTE: every message below that embeds values must be an f-string;
    # a plain literal would show the raw "{name}" placeholders to the user.
    check(
        kW > 0 and kH > 0,
        lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
    )
    check(
        dW > 0 and dH > 0,
        lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
    )
    check(
        dilationH > 0 and dilationW > 0,
        lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
    )

    valid_dims = input.size(1) != 0 and input.size(2) != 0

    if memory_format == torch.channels_last:
        check(
            ndim == 4 and valid_dims and input.size(3) != 0,
            lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
            f" with optional 0 dim batch size for input, but got: {input.size()}",
        )
    else:
        check(
            (ndim == 3 and input.size(0) != 0 and valid_dims)
            or (ndim == 4 and valid_dims and input.size(3) != 0),
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
        )

    check(
        kW // 2 >= padW and kH // 2 >= padH,
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
    )

    check(
        outputWidth >= 1 and outputHeight >= 1,
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
        "Output size is too small",
    )
@register_meta(aten.max_pool2d_with_indices_backward.default)
def meta_max_pool2d_with_indices_backward(
    grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    """Meta kernel for the backward of ``max_pool2d_with_indices``.

    Re-runs the forward shape computation on ``self``, verifies that
    ``grad_output`` has the same dtype as ``self`` and that both
    ``grad_output`` and ``indices`` have the expected
    (planes, outputHeight, outputWidth) trailing extents, then returns an
    uninitialized tensor with ``self``'s shape, dtype, device and suggested
    memory format.
    """
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        self, kernel_size, stride, padding, dilation, ceil_mode
    )

    check(
        self.dtype == grad_output.dtype,
        # Must be an f-string, otherwise the placeholders are shown verbatim.
        lambda: f"expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
    )

    nOutputPlane = nInputPlane
    ndim = self.ndim

    def _check_dim_size(t):
        # The last three dimensions must be (planes, height, width) of the output.
        check_dim_size(t, ndim, ndim - 3, nOutputPlane)
        check_dim_size(t, ndim, ndim - 2, outputHeight)
        check_dim_size(t, ndim, ndim - 1, outputWidth)

    _check_dim_size(grad_output)
    _check_dim_size(indices)

    memory_format = utils.suggest_memory_format(self)
    return torch.empty(
        self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format
    )
# zeros_like is special cased to work for sparse
@register_meta(aten.zeros_like.default)
def zeros_like(
    self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    """Meta kernel for ``zeros_like`` with a dedicated sparse-COO path.

    For any layout other than ``torch.sparse_coo`` this simply defers to
    ``aten.empty_like``. For sparse COO a zero-element tensor is created,
    resized to ``self``'s size with no stored values, marked as coalesced
    and returned; ``memory_format`` must be ``None`` in that case.
    """
    if layout != torch.sparse_coo:
        return aten.empty_like.default(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    check(
        memory_format is None,
        lambda: "memory format option is only supported by strided tensors",
    )

    result_dtype = dtype if dtype is not None else self.dtype
    result_device = device if device is not None else self.device
    res = torch.empty(
        0,
        dtype=result_dtype,
        layout=layout,
        device=result_device,
        pin_memory=pin_memory,
    )

    # A dense source contributes all of its dimensions as sparse dimensions.
    if self.is_sparse:
        sparse_dims, dense_dims = self.sparse_dim(), self.dense_dim()
    else:
        sparse_dims, dense_dims = self.dim(), 0
    res.sparse_resize_and_clear_(self.size(), sparse_dims, dense_dims)

    res._coalesced_(True)
    return res
# TODO: Deduplicate this with canonicalize_dim
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    """Map a possibly negative ``dim`` into ``[0, dim_post_expr)``.

    A non-positive ``dim_post_expr`` is treated as 1, which is only permitted
    when ``wrap_scalar`` is true. An AssertionError is raised when ``dim``
    falls outside ``[-dim_post_expr, dim_post_expr - 1]``.
    """
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1

    lower, upper = -dim_post_expr, dim_post_expr - 1
    assert lower <= dim <= upper, f"dim {dim} out of bounds ({lower}, {upper})"

    return dim + dim_post_expr if dim < 0 else dim
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_shape_check(self, dim, index, src_opt=None):
    """Validate the shapes of a scatter-style call.

    Nothing is checked when ``index`` has no elements. Otherwise:

    * ``index`` must have the same number of dimensions as ``self``
      (0-dim tensors count as 1-dim);
    * ``index.size(d) <= self.size(d)`` for every ``d != dim``;
    * if ``src_opt`` is given, ``index.size(d) <= src_opt.size(d)`` for
      every ``d``, and ``src_opt`` must have the same number of dimensions
      as ``index``.

    Violations are reported through ``check``.
    """
    if index.numel() == 0:
        return

    check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )

    is_wrong_shape = False
    self_dims = ensure_nonempty_dim(self.dim())

    # Check: index.size(d) <= self.size(d) for all d != dim
    for d in range(self_dims):
        index_d_size = ensure_nonempty_size(index, d)
        if d == dim:
            continue
        if index_d_size > ensure_nonempty_size(self, d):
            is_wrong_shape = True
            break

    # Check: index.size(d) <= src.size(d) for all d if src is Tensor
    if not is_wrong_shape and src_opt is not None:
        for d in range(self_dims):
            index_d_size = ensure_nonempty_size(index, d)
            if index_d_size > ensure_nonempty_size(src_opt, d):
                is_wrong_shape = True
                break

    if src_opt is not None:
        # This must compare src (not self) against index: the self/index
        # rank was already validated above, so repeating it here would let
        # a src with a different number of dimensions slip through.
        check(
            ensure_nonempty_dim(src_opt.dim()) == ensure_nonempty_dim(index.dim()),
            lambda: "Index tensor must have the same number of dimensions as src tensor",
        )
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
        )
    else:
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )
return_debug_mask: bool = False, ): # [Note] SDPA_flash's meta function returns incorrect Philox seed and offset: # We have added logic to torch/_dynamo/variables/torch.py # We need to check if scaled_dot_product_attention will run the flash attention # kernel and if dropout is != 0.0. If that is the case then we want dynamo # to graph break. The derivative calculation for _scaled_dot_product_flash_attention # does not function correctly with cuda graphs because the full philox state is not captured # the forward's return values. Another reason to graph break is that the the meta function # returns the wrong outputs for philox seed and offset and these values get baked into the # inductor fallback calls to the eager kernels. check( dropout_p == 0.0, lambda: f"Can only trace _scaled_dot_product_flash_attention when dropout is set to 0 but got a dropout_p of {dropout_p}.", ) batch_size = query.size(0) num_heads = query.size(1) max_seqlen_batch_q = query.size(2) head_dim = query.size(3) max_seqlen_batch_k = key.size(2) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) Nnz_q = batch_size * max_seqlen_batch_q output = torch.empty( (Nnz_q, num_heads, head_dim), dtype=query.dtype, device=query.device ) output = output.view(batch_size, max_seqlen_batch_q, num_heads, head_dim).transpose( 1, 2 ) max_seqlen_q = math.ceil(max_seqlen_batch_q / 16) * 16 logsumexp = torch.empty( (batch_size, num_heads, max_seqlen_q), dtype=torch.float, device=query.device, ) cumulative_sequence_length_q = torch.empty( batch_size + 1, dtype=torch.int32, device="meta" ) cumulative_sequence_length_k = torch.empty( batch_size + 1, dtype=torch.int32, device="meta" ) if return_debug_mask: blocksize_c = 128 if head_dim > 64 else 256 max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) if max_seqlen_batch_k <= 128: max_seqlen_k = 128 elif max_seqlen_batch_k <= 256: max_seqlen_k = 256 debug_mask = torch.empty( (batch_size, num_heads, max_seqlen_q, max_seqlen_k), 
@register_meta(
    [
        aten._scaled_dot_product_flash_attention_backward,
    ]
)
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: int,
    philox_offset: int,
):
    """Meta kernel for the flash-attention backward pass.

    Returns uninitialized ``(grad_q, grad_k, grad_v)``. Each gradient is
    allocated like the corresponding input after it has been transposed on
    dims (1, 2) and flattened to ``(batch * seq_len, num_heads, head_dim)``,
    then viewed back to four dimensions and transposed on dims (1, 2) again.
    Only ``query``, ``key``, ``value``, ``max_q`` and ``max_k`` are read.
    """
    batch_size = query.size(0)
    num_heads = query.size(1)
    head_dim = query.size(3)

    def _empty_grad(tensor, seq_len):
        # Same op sequence for q, k and v: transpose -> reshape ->
        # empty_like -> view -> transpose.
        packed = tensor.transpose(1, 2).reshape(
            batch_size * seq_len, num_heads, head_dim
        )
        grad = torch.empty_like(packed)
        return grad.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

    grad_q = _empty_grad(query, max_q)
    grad_k = _empty_grad(key, max_k)
    grad_v = _empty_grad(value, max_k)
    return grad_q, grad_k, grad_v
def multiply_integers(vs):
    """Return the product of every element of ``vs`` (1 when ``vs`` is empty)."""
    return math.prod(vs)
@register_meta(aten.upsample_nearest2d.default)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
    """Meta kernel for ``upsample_nearest2d``.

    Rejects inputs that are empty in every dimension after the first,
    validates sizes through ``upsample_common_check`` and returns an
    uninitialized tensor of the full output size. The output uses the memory
    format suggested for ``input``, except on CUDA with fewer than 4
    channels, where the contiguous format is used. ``scales_h`` and
    ``scales_w`` are not read.
    """
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        # Must be an f-string, otherwise "{input.size()}" is printed verbatim.
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    output = input.new_empty(full_output_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    output = output.contiguous(memory_format=memory_format)

    return output
def rnn_cell_checkSizes(
    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
    """Validate the inputs of a fused RNN cell.

    Checks that ``input_gates`` is 2D and shaped like ``hidden_gates``; that
    the biases, when ``input_bias`` is given, are 1D, match the gate width
    and match each other; that ``prev_hidden`` is 2D with
    ``input_gates.size(0) * input_gates.size(1) // factor`` elements; and
    that all provided tensors live on the same device as ``input_gates``.
    Biases passed as ``None`` are skipped. Violations are reported through
    ``check``.
    """
    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
    check(
        input_gates.shape == hidden_gates.shape,
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
    )
    gates_size = input_gates.size(1)
    if input_bias is not None:
        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
        check(
            input_bias.numel() == gates_size,
            lambda: f"{input_bias.numel()} != {gates_size}",
        )
        # Report a missing hidden_bias explicitly instead of failing on
        # attribute access of None in the shape comparison below.
        check(
            hidden_bias is not None,
            lambda: "expected hidden_bias to be defined when input_bias is defined",
        )
        check(
            input_bias.shape == hidden_bias.shape,
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
        )
    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
    check(
        prev_hidden.numel() == expected_prev_hidden_numel,
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
    )
    # Biases are optional, so only tensors that were actually provided take
    # part in the device comparison.
    provided = [
        x
        for x in (hidden_gates, input_bias, hidden_bias, prev_hidden)
        if x is not None
    ]
    check(
        all(x.device == input_gates.device for x in provided),
        lambda: "expected all inputs to be same device",
    )
@register_meta(aten._cudnn_rnn.default)
def _cudnn_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    """Meta kernel for ``_cudnn_rnn``.

    Returns ``(output, hy, cy, reserve, weight_buf)`` where the first four
    are uninitialized. A non-empty ``batch_sizes`` is treated as packed
    input: the output is then 2D and the mini-batch is ``batch_sizes[0]``.
    Otherwise the output is 3D, laid out according to ``batch_first``.
    ``cy`` is a zero-element tensor when ``cx`` is ``None``; ``reserve`` is
    always a zero-element uint8 tensor.
    """
    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    out_features = out_size * num_directions

    if len(batch_sizes) != 0:
        # Packed input: the leading dimension of `input` becomes the leading
        # dimension of the output.
        mini_batch = batch_sizes[0]
        out_shape = [input.shape[0], out_features]
    elif batch_first:
        mini_batch, seq_length = input.shape[0], input.shape[1]
        out_shape = [mini_batch, seq_length, out_features]
    else:
        seq_length, mini_batch = input.shape[0], input.shape[1]
        out_shape = [seq_length, mini_batch, out_features]
    output = input.new_empty(out_shape)

    stacked_layers = num_layers * num_directions
    cy = (
        torch.empty(0, device=input.device)
        if cx is None
        else cx.new_empty([stacked_layers, mini_batch, hidden_size])
    )
    hy = hx.new_empty([stacked_layers, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    # Until then the reserve buffer has size 0 whether or not `train` is set.
    reserve = input.new_empty(0, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    """Meta kernel for ``topk``.

    Requires ``0 <= k <= size of self along dim`` (a 0-dim tensor counts as
    size 1). Returns uninitialized ``(values, indices)`` shaped like ``self``
    with dimension ``dim`` replaced by ``k``; ``indices`` is int64.
    ``largest`` and ``sorted`` are not read.
    """
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    slice_len = self.size(dim) if self.dim() > 0 else 1
    k_in_range = 0 <= k <= slice_len
    check(k_in_range, lambda: "selected index k out of range")
    check(k_in_range, lambda: "k not in range for dimension")

    result_shape = list(self.shape)
    if result_shape:
        result_shape[dim] = k

    values = self.new_empty(result_shape)
    indices = self.new_empty(result_shape, dtype=torch.int64)
    return values, indices
legacy_contiguous_memory_format = torch.contiguous_format


# From aten/src/ATen/native/cuda/RNN.cu
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
    """Validate the tensor sizes passed to the fused LSTM cell backward.

    The first non-``None`` gradient among ``grad_hy``/``grad_cy`` defines the
    expected 2-D size. Every other provided gradient, ``cx`` and ``cy`` must
    match it, and ``workspace`` must be 2-D with four times as many elements.
    At least one of the two gradients has to be provided.

    Raises:
        Whatever ``check`` raises on a failed condition, now with a message
        naming the offending tensor and sizes.
    """
    fn_name = "checkLSTMBackwardSizes"
    defined_grad = grad_hy if grad_hy is not None else grad_cy
    check(
        defined_grad.dim() == 2,
        lambda: f"{fn_name}: expected a 2-D gradient but got {defined_grad.dim()} dimension(s)",
    )
    exp_size = defined_grad.size()
    if grad_hy is not None:
        check(
            grad_hy.size() == exp_size,
            lambda: f"{fn_name}: expected grad_hy of size {exp_size} but got {grad_hy.size()}",
        )
    if grad_cy is not None:
        check(
            grad_cy.size() == exp_size,
            lambda: f"{fn_name}: expected grad_cy of size {exp_size} but got {grad_cy.size()}",
        )
    check(
        cx.size() == exp_size,
        lambda: f"{fn_name}: expected cx of size {exp_size} but got {cx.size()}",
    )
    check(
        cy.size() == exp_size,
        lambda: f"{fn_name}: expected cy of size {exp_size} but got {cy.size()}",
    )
    check(
        workspace.dim() == 2,
        lambda: f"{fn_name}: expected a 2-D workspace but got {workspace.dim()} dimension(s)",
    )
    expected_numel = exp_size[0] * exp_size[1] * 4
    check(
        workspace.numel() == expected_numel,
        lambda: f"{fn_name}: expected workspace with {expected_numel} elements "
        f"but got {workspace.numel()}",
    )


# From aten/src/ATen/native/cuda/RNN.cu
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    """Meta kernel for ``aten._thnn_fused_lstm_cell_backward_impl``.

    Returns:
        ``(grad_gates, grad_cx, grad_bias)``. All three are ``None`` when both
        incoming gradients are ``None``. Otherwise ``grad_gates`` mirrors
        ``workspace``, ``grad_cx`` mirrors ``cx``, and ``grad_bias`` is
        ``grad_gates`` summed over dim 0 (``None`` when ``has_bias`` is false).
    """
    if grad_hy is None and grad_cy is None:
        return None, None, None
    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
    grad_gates = torch.empty_like(
        workspace, memory_format=legacy_contiguous_memory_format
    )
    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
    return grad_gates, grad_cx, grad_bias


@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    """Meta kernel for ``aten.pixel_shuffle``.

    The input needs more than two dimensions and a third-from-last dimension
    divisible by ``upscale_factor ** 2``. The output divides that dimension by
    ``upscale_factor ** 2`` and multiplies the last two by ``upscale_factor``.
    """
    assert (
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format():
        # Channels-last input yields contiguous output on "cuda" and
        # channels-last output elsewhere.
        if is_channels_last(self):
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            else:
                return torch.channels_last
        elif self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        # NOTE(review): when no branch matches this returns None implicitly;
        # confirm whether that case is reachable.

    C = self.shape[-3] // (upscale_factor * upscale_factor)
    Hr = self.shape[-2] * upscale_factor
    Wr = self.shape[-1] * upscale_factor
    out_shape = (*self.shape[:-3], C, Hr, Wr)

    out = self.new_empty(out_shape)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out


@register_meta(aten.mkldnn_rnn_layer_backward.default)
def mkldnn_rnn_layer_backward(
    input,
    weight0,
    weight1,
    weight2,
    weight3,
    hx_,
    cx_tmp,
    output,
    hy_,
    cy_,
    grad_output_r_opt,
    grad_hy_r_opt,
    grad_cy_r_opt,
    reverse,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    train,
    bidirectional,
    batch_sizes,
    batch_first,
    workspace,
):
    """Meta kernel for ``aten.mkldnn_rnn_layer_backward``.

    Returns:
        ``(diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx)``,
        shaped like ``input``, ``weight0``, ``weight1``, ``weight2`` (the same
        tensor object fills both bias slots), ``hx_`` and ``cx_tmp``.
    """
    diff_x = input.new_empty(input.shape)
    diff_hx = hx_.new_empty(hx_.shape)
    diff_cx = cx_tmp.new_empty(cx_tmp.shape)
    diff_w1 = weight0.new_empty(weight0.shape)
    diff_w2 = weight1.new_empty(weight1.shape)
    diff_b = weight2.new_empty(weight2.shape)
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx


@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
@out_wrapper()
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    """Meta kernel for ``aten.bucketize``.

    Returns a contiguous tensor shaped like ``self`` with dtype ``int32`` when
    ``out_int32`` is set and ``int64`` otherwise.
    """
    return torch.empty_like(
        self, dtype=torch.int32 if out_int32 else torch.int64
    ).contiguous()


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
import torch._refs.nn.functional
import torch._refs.special


def activate_meta():
    """Install the collected Python meta functions as Meta dispatch-key kernels.

    One function is chosen per operator overload from
    ``global_decomposition_table``. Each is attached through ``py_impl`` and,
    unless the overload falls into one of the skipped categories below, also
    registered with the matching ``Meta`` library.

    Raises:
        RuntimeError: if a function in the ``"meta"`` table targets an
            operator that has a ``CompositeImplicitAutograd`` kernel.
    """
    activate_meta_table = {}

    # For a given op, we pick the most specific decomp function from
    # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
    for table_kind in ["meta", "post_autograd", "pre_autograd"]:
        registry = global_decomposition_table[table_kind]
        for opo, decomp_fn in registry.items():
            # The first table that defines an overload wins.
            activate_meta_table.setdefault(opo, decomp_fn)

    for op_overload, fn in activate_meta_table.items():
        assert isinstance(op_overload, OpOverload)

        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        op_name = op_overload.name()
        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_name, "CompositeImplicitAutograd"
        ):
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
        elif op_overload.is_view:
            # Attempting to register a python meta kernel for a view operator.
            # We shouldn't do this, because the output will report as not having aliased storages.
            # All view ops have meta kernels in C++ today, so we should use those instead.
            pass
        elif op_name in {
            "aten::empty_strided",  # causing infinite recursion, test_meta.py
            "aten::clone",  # causing infinite recursion
            "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
            "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
            "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
            "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
            "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
        }:
            pass
        else:
            # Route the kernel to the library that owns the op's namespace.
            if "mkldnn::" in op_name:
                _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
            elif "mkl::" in op_name:
                _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
            else:
                _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


activate_meta()