_shape_functions.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139
  1. from typing import List, Any, Optional, Union, Dict, Callable, Tuple
  2. import math
  3. number = Union[int, float]
  4. # flake8: noqa
  5. ###
  6. # There are generated files that depend on this file
  7. # To re-generate, please run from the root of the repo:
  8. # python torchgen/shape_functions/gen_jit_shape_functions.py
  9. # How to test:
  10. # After regenerating files, compile PyTorch.
  11. # Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
  12. # If you have enabled opinfo testing for the op, also run:
  13. # python test/test_ops_jit.py TestJitCPU::test_variant_consistency_jit_[FAILING_OP]_cpu_float32
  14. # to reproduce errors from opinfo tests.
  15. # Example PR: https://github.com/pytorch/pytorch/pull/80860/files
  16. ####
  17. import torch
  18. def broadcast(a: List[int], b: List[int]):
  19. dimsA = len(a)
  20. dimsB = len(b)
  21. ndim = max(dimsA, dimsB)
  22. expandedSizes: List[int] = []
  23. for i in range(ndim):
  24. offset = ndim - 1 - i
  25. dimA = dimsA - 1 - offset
  26. dimB = dimsB - 1 - offset
  27. sizeA = a[dimA] if (dimA >= 0) else 1
  28. sizeB = b[dimB] if (dimB >= 0) else 1
  29. if sizeA != sizeB and sizeA != 1 and sizeB != 1:
  30. # TODO: only assertion error is bound in C++ compilation right now
  31. raise AssertionError(
  32. "The size of tensor a {} must match the size of tensor b ("
  33. "{}) at non-singleton dimension {}".format(sizeA, sizeB, i)
  34. )
  35. expandedSizes.append(sizeB if sizeA == 1 else sizeA)
  36. return expandedSizes
  37. def broadcast_three(a: List[int], b: List[int], c: List[int]):
  38. return broadcast(broadcast(a, b), c)
  39. def broadcast_one_three(a: List[int], b: Any, c: List[int]):
  40. return broadcast(a, c)
  41. def adaptive_avg_pool2d(self: List[int], out: List[int]):
  42. assert len(out) == 2
  43. assert len(self) == 3 or len(self) == 4
  44. for i in range(1, len(self)):
  45. assert self[i] != 0
  46. shape: List[int] = []
  47. for i in range(0, len(self) - 2):
  48. shape.append(self[i])
  49. for elem in out:
  50. shape.append(elem)
  51. return shape
  52. def _copy(self: List[int]):
  53. out: List[int] = []
  54. for elem in self:
  55. out.append(elem)
  56. return out
  57. def unary(self: List[int]):
  58. return _copy(self)
  59. def broadcast_inplace(a: List[int], b: List[int]):
  60. dimsA = len(a)
  61. dimsB = len(b)
  62. if dimsB > dimsA:
  63. raise AssertionError(
  64. "The dims of tensor b ({}) must be less than or equal to"
  65. "the dims of tensor a ({}) ".format(dimsB, dimsA)
  66. )
  67. for dimA in range(dimsA):
  68. dimB = dimsB - dimsA + dimA
  69. sizeA = a[dimA]
  70. sizeB = b[dimB] if (dimB >= 0) else 1
  71. if sizeA != sizeB and sizeB != 1:
  72. # TODO: only assertion error is bound in C++ compilation right now
  73. raise AssertionError(
  74. "The size of tensor a {} must match the size of tensor b ("
  75. "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA)
  76. )
  77. return _copy(a)
  78. def expand(self: List[int], sizes: List[int]):
  79. assert len(sizes) >= len(self)
  80. ndim = len(sizes)
  81. tensor_dim = len(self)
  82. if ndim == 0:
  83. return _copy(sizes)
  84. out: List[int] = []
  85. for i in range(ndim):
  86. offset = ndim - 1 - i
  87. dim = tensor_dim - 1 - offset
  88. size = self[dim] if dim >= 0 else 1
  89. targetSize = sizes[i]
  90. if targetSize == -1:
  91. assert dim >= 0
  92. targetSize = size
  93. if size != targetSize:
  94. assert size == 1
  95. size = targetSize
  96. out.append(size)
  97. return out
  98. def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
  99. return expand(self, sizes)
  100. def infer_size_impl(shape: List[int], numel: int) -> List[int]:
  101. newsize = 1
  102. infer_dim: Optional[int] = None
  103. for dim in range(len(shape)):
  104. if shape[dim] == -1:
  105. if infer_dim is not None:
  106. raise AssertionError("only one dimension can be inferred")
  107. infer_dim = dim
  108. elif shape[dim] >= 0:
  109. newsize *= shape[dim]
  110. else:
  111. raise AssertionError("invalid shape dimensions")
  112. if not (
  113. numel == newsize
  114. or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
  115. ):
  116. raise AssertionError("invalid shape")
  117. out = _copy(shape)
  118. if infer_dim is not None:
  119. out[infer_dim] = numel // newsize
  120. return out
  121. def numel(sizes: List[int]):
  122. numel = 1
  123. for elem in sizes:
  124. numel *= elem
  125. return numel
  126. def view(self: List[int], sizes: List[int]):
  127. return infer_size_impl(sizes, numel(self))
  128. def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
  129. return view(self, sizes)
  130. def sum_mean_dim(self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any):
  131. out: List[int] = []
  132. if opt_dims is None or len(opt_dims) == 0:
  133. dims: List[int] = list(range(len(self)))
  134. else:
  135. dims = opt_dims
  136. for idx in range(len(self)):
  137. is_mean_dim: bool = False
  138. for reduce_dim in dims:
  139. if idx == maybe_wrap_dim(reduce_dim, len(self)):
  140. is_mean_dim = True
  141. if is_mean_dim:
  142. if keep_dim:
  143. out.append(1)
  144. else:
  145. out.append(self[idx])
  146. return out
  147. def max_dim(self: List[int], dim: int, keep_dim: bool):
  148. out = sum_mean_dim(self, [dim], keep_dim, None)
  149. return out, out
  150. # note: python already rounds down towards negative infinity on integer division, special arithmetic not needed
  151. def div_rtn(x: int, y: int):
  152. return x // y
  153. def pooling_output_shape_pad_lr(
  154. inputSize: int,
  155. kernelSize: int,
  156. pad_l: int,
  157. pad_r: int,
  158. stride: int,
  159. dilation: int,
  160. ceil_mode: bool,
  161. ):
  162. outputSize = (
  163. div_rtn(
  164. inputSize
  165. + pad_l
  166. + pad_r
  167. - dilation * (kernelSize - 1)
  168. - 1
  169. + (stride - 1 if ceil_mode else 0),
  170. stride,
  171. )
  172. + 1
  173. )
  174. if ceil_mode:
  175. if (outputSize - 1) * stride >= inputSize + pad_l:
  176. outputSize = outputSize - 1
  177. return outputSize
  178. def pooling_output_shape(
  179. inputSize: int,
  180. kernelSize: int,
  181. pad_l: int,
  182. stride: int,
  183. dilation: int,
  184. ceil_mode: bool,
  185. ):
  186. assert stride != 0, "stride should not be zeero"
  187. return pooling_output_shape_pad_lr(
  188. inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
  189. )
  190. def pool2d_shape_check(
  191. input: List[int],
  192. kH: int,
  193. kW: int,
  194. dH: int,
  195. dW: int,
  196. padH: int,
  197. padW: int,
  198. dilationH: int,
  199. dilationW: int,
  200. nInputPlane: int,
  201. inputHeight: int,
  202. inputWidth: int,
  203. outputHeight: int,
  204. outputWidth: int,
  205. ):
  206. ndim = len(input)
  207. nOutputPlane = nInputPlane
  208. assert kW > 0 and kH > 0
  209. assert dW > 0 and dH > 0
  210. assert dilationH > 0 and dilationW > 0
  211. valid_dims = input[1] != 0 and input[2] != 0
  212. assert (
  213. ndim == 3
  214. and input[0] != 0
  215. and valid_dims
  216. or (ndim == 4 and valid_dims and input[3] != 0)
  217. )
  218. assert kW // 2 >= padW and kH // 2 >= padH
  219. assert outputWidth >= 1 and outputHeight >= 1
  220. def max_pool2d(
  221. input: List[int],
  222. kernel_size: List[int],
  223. stride: List[int],
  224. padding: List[int],
  225. dilation: List[int],
  226. ceil_mode: bool,
  227. ):
  228. assert (
  229. len(kernel_size) == 1 or len(kernel_size) == 2
  230. ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
  231. kH = kernel_size[0]
  232. kW = kH if len(kernel_size) == 1 else kernel_size[1]
  233. assert (
  234. len(stride) == 0 or len(stride) == 1 or len(stride) == 2
  235. ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
  236. dH = kH if len(stride) == 0 else stride[0]
  237. if len(stride) == 0:
  238. dW = kW
  239. elif len(stride) == 1:
  240. dW = dH
  241. else:
  242. dW = stride[1]
  243. assert (
  244. len(padding) == 1 or len(padding) == 2
  245. ), "max_pool2d: padding must be either be a single int, or a tuple of two ints"
  246. padH = padding[0]
  247. padW = padH if len(padding) == 1 else padding[1]
  248. assert (
  249. len(dilation) == 1 or len(dilation) == 2
  250. ), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
  251. dilationH = dilation[0]
  252. dilationW = dilationH if len(dilation) == 1 else dilation[1]
  253. assert len(input) == 3 or len(input) == 4
  254. nbatch = input[-4] if len(input) == 4 else 1
  255. nInputPlane = input[-3]
  256. inputHeight = input[-2]
  257. inputWidth = input[-1]
  258. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
  259. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
  260. pool2d_shape_check(
  261. input,
  262. kH,
  263. kW,
  264. dH,
  265. dW,
  266. padH,
  267. padW,
  268. dilationH,
  269. dilationW,
  270. nInputPlane,
  271. inputHeight,
  272. inputWidth,
  273. outputHeight,
  274. outputWidth,
  275. )
  276. if len(input) == 3:
  277. return [nInputPlane, outputHeight, outputWidth]
  278. else:
  279. return [nbatch, nInputPlane, outputHeight, outputWidth]
  280. def max_pool2d_with_indices(
  281. input: List[int],
  282. kernel_size: List[int],
  283. stride: List[int],
  284. padding: List[int],
  285. dilation: List[int],
  286. ceil_mode: bool,
  287. ):
  288. out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  289. return (out, out)
  290. def upsample_nearest2d(
  291. input: List[int],
  292. output_size: Optional[List[int]],
  293. scale_factors: Optional[List[float]],
  294. ):
  295. out: List[int] = []
  296. out.append(input[0])
  297. out.append(input[1])
  298. if (scale_factors is None and output_size is None):
  299. assert 0, "Either output_size or scale_factors must be presented"
  300. if output_size is not None:
  301. assert (
  302. scale_factors is None
  303. ), "Must specify exactly one of output_size and scale_factors"
  304. assert len(output_size) == 2
  305. out.append(output_size[0])
  306. out.append(output_size[1])
  307. if scale_factors is not None:
  308. assert (
  309. output_size is None
  310. ), "Must specify exactly one of output_size and scale_factors"
  311. assert len(scale_factors) == 2
  312. out.append(int(input[2] * scale_factors[0]))
  313. out.append(int(input[3] * scale_factors[1]))
  314. return out
  315. def mm(self: List[int], mat2: List[int]):
  316. assert len(self) == 2, "self must be a matrix"
  317. assert len(mat2) == 2, "mat2 must be a matrix"
  318. assert self[1] == mat2[0]
  319. return [self[0], mat2[1]]
  320. def dot(self: List[int], tensor: List[int]):
  321. assert len(self) == 1 and len(tensor) == 1
  322. assert self[0] == tensor[0]
  323. out: List[int] = []
  324. return out
  325. def mv(self: List[int], vec: List[int]):
  326. assert len(self) == 2 and len(vec) == 1
  327. assert self[1] == vec[0]
  328. # TODO: return self
  329. return [self[0]]
  330. def unsqueeze(li: List[int], dim: int):
  331. dim = maybe_wrap_dim(dim, len(li) + 1)
  332. out = _copy(li)
  333. out.insert(dim, 1)
  334. return out
  335. def squeeze_nodim(li: List[int]):
  336. out: List[int] = []
  337. for i in range(len(li)):
  338. if li[i] != 1:
  339. out.append(li[i])
  340. return out
  341. def squeeze(li: List[int], dim: int):
  342. out: List[int] = []
  343. wrapped_dim = maybe_wrap_dim(dim, len(li))
  344. for i in range(len(li)):
  345. if i == wrapped_dim:
  346. if li[i] != 1:
  347. out.append(li[i])
  348. else:
  349. out.append(li[i])
  350. return out
  351. def index_select(self: List[int], dim: int, index: List[int]):
  352. dim = maybe_wrap_dim(dim, len(self))
  353. numel = multiply_integers(index)
  354. assert len(index) <= 1
  355. assert dim == 0 or dim < len(self)
  356. result_size: List[int] = []
  357. for i in range(len(self)):
  358. if dim == i:
  359. result_size.append(numel)
  360. else:
  361. result_size.append(self[i])
  362. return result_size
  363. def embedding(
  364. weight: List[int],
  365. indices: List[int],
  366. padding_idx: int = -1,
  367. scale_grad_by_freq: bool = False,
  368. sparse: bool = False,
  369. ):
  370. assert len(weight) == 2
  371. if len(indices) == 1:
  372. return index_select(weight, 0, indices)
  373. size = _copy(indices)
  374. size.append(weight[1])
  375. return size
  376. def max_int():
  377. return 9223372036854775807
  378. def slice(
  379. self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
  380. ):
  381. ndim = len(self)
  382. assert ndim != 0
  383. dim = maybe_wrap_dim(dim, ndim)
  384. start_val = start if start is not None else 0
  385. end_val = end if end is not None else max_int()
  386. assert step > 0
  387. if start_val == max_int():
  388. start_val = 0
  389. if start_val < 0:
  390. start_val += self[dim]
  391. if end_val < 0:
  392. end_val += self[dim]
  393. if start_val < 0:
  394. start_val = 0
  395. elif start_val > self[dim]:
  396. start_val = self[dim]
  397. if end_val < start_val:
  398. end_val = start_val
  399. elif end_val >= self[dim]:
  400. end_val = self[dim]
  401. slice_len = end_val - start_val
  402. out = _copy(self)
  403. out[dim] = (slice_len + step - 1) // step
  404. return out
  405. def check_cat_no_zero_dim(tensors: List[List[int]]):
  406. for tensor in tensors:
  407. assert len(tensor) > 0
  408. def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
  409. out_dim: Optional[int] = None
  410. for size in tensor_sizes:
  411. if not (len(size) == 1 and size[0] == 0):
  412. if out_dim is None:
  413. out_dim = maybe_wrap_dim(dim, len(size))
  414. if out_dim is None:
  415. out_dim = dim
  416. return out_dim
  417. def should_skip(tensor: List[int]):
  418. return numel(tensor) == 0 and len(tensor) == 1
  419. def check_cat_shape_except_dim(
  420. first: List[int], second: List[int], dimension: int, index: int
  421. ):
  422. first_dims = len(first)
  423. second_dims = len(second)
  424. assert first_dims == second_dims, "Tensors must have same number of dimensions"
  425. for dim in range(0, first_dims):
  426. if dim != dimension:
  427. assert (
  428. first[dim] == second[dim]
  429. ), "Sizes of tensors must match except in dimension"
  430. def cat(tensors: List[List[int]], dim: int):
  431. check_cat_no_zero_dim(tensors)
  432. dim = legacy_cat_wrap_dim(dim, tensors)
  433. assert len(tensors) > 0
  434. not_skipped_tensor: Optional[List[int]] = None
  435. for tensor in tensors:
  436. if not should_skip(tensor):
  437. not_skipped_tensor = tensor
  438. if not_skipped_tensor is None:
  439. return [0]
  440. cat_dim_size = 0
  441. for i in range(len(tensors)):
  442. tensor = tensors[i]
  443. if not should_skip(tensor):
  444. check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
  445. cat_dim_size = cat_dim_size + tensor[dim]
  446. result_size = _copy(not_skipped_tensor)
  447. result_size[dim] = cat_dim_size
  448. return result_size
  449. def select(self: List[int], dim: int, index: int):
  450. ndim = len(self)
  451. assert ndim != 0
  452. dim = maybe_wrap_dim(dim, ndim)
  453. size = self[dim]
  454. assert not (index < -size or index >= size)
  455. if index < 0:
  456. index += size
  457. out: List[int] = []
  458. for i in range(ndim):
  459. if i != dim:
  460. out.append(self[i])
  461. return out
  462. def matmul(tensor1: List[int], tensor2: List[int]):
  463. dim_tensor1 = len(tensor1)
  464. dim_tensor2 = len(tensor2)
  465. if dim_tensor1 == 1 and dim_tensor2 == 1:
  466. return dot(tensor1, tensor2)
  467. elif dim_tensor1 == 2 and dim_tensor2 == 1:
  468. return mv(tensor1, tensor2)
  469. elif dim_tensor1 == 1 and dim_tensor2 == 2:
  470. return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
  471. elif dim_tensor1 == 2 and dim_tensor2 == 2:
  472. return mm(tensor1, tensor2)
  473. elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
  474. # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
  475. # we track m1 vs m2 separately even though they must match for nicer error messages
  476. n = tensor1[-2] if dim_tensor1 > 1 else 1
  477. m1 = tensor1[-1]
  478. batch_tensor1: List[int] = []
  479. # TODO: handling of slice
  480. for i in range(dim_tensor1 - 2):
  481. batch_tensor1.append(tensor1[i])
  482. m2 = tensor2[-1] if dim_tensor2 > 1 else 1
  483. p = tensor2[-1]
  484. batch_tensor2: List[int] = []
  485. # TODO: handling of slice
  486. for i in range(dim_tensor2 - 2):
  487. batch_tensor2.append(tensor2[i])
  488. # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
  489. expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)
  490. # todo: copy ?
  491. output_shape = expand_batch_portion
  492. if dim_tensor1 > 1:
  493. output_shape.append(n)
  494. if dim_tensor2 > 1:
  495. output_shape.append(p)
  496. return output_shape
  497. else:
  498. assert False, "both arguments to matmul need to be at least 1D"
  499. def t(self: List[int]):
  500. assert len(self) <= 2
  501. self_len = len(self)
  502. if self_len == 0:
  503. out: List[int] = []
  504. return out
  505. elif self_len == 1:
  506. return [self[0]]
  507. else:
  508. return [self[1], self[0]]
  509. def transpose(self: List[int], dim0: int, dim1: int):
  510. ndims = len(self)
  511. dim0 = maybe_wrap_dim(dim0, ndims)
  512. dim1 = maybe_wrap_dim(dim1, ndims)
  513. if dim0 == dim1:
  514. return _copy(self)
  515. out: List[int] = []
  516. for i in range(ndims):
  517. if i == dim0:
  518. out.append(self[dim1])
  519. elif i == dim1:
  520. out.append(self[dim0])
  521. else:
  522. out.append(self[i])
  523. return out
  524. def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
  525. out = matmul(input, t(weight))
  526. if bias is not None:
  527. assert broadcast(bias, out) == out
  528. return out
  529. def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
  530. return broadcast(self, mm(mat1, mat2))
  531. def check_non_negative(array: List[int]) -> bool:
  532. # TODO: look into rewriting with early return and getting loop unrolling to fire
  533. non_negative = False
  534. for val in array:
  535. if val < 0:
  536. non_negative = True
  537. return non_negative
  538. def check_shape_forward(
  539. input: List[int],
  540. weight_sizes: List[int],
  541. bias: Optional[List[int]],
  542. stride: List[int],
  543. padding: List[int],
  544. dilation: List[int],
  545. groups: int,
  546. ):
  547. k = len(input)
  548. weight_dim = len(weight_sizes)
  549. # TODO: assertions could be expanded with the error messages
  550. assert not check_non_negative(padding)
  551. assert not check_non_negative(stride)
  552. assert weight_dim == k
  553. assert weight_sizes[0] >= groups
  554. assert (weight_sizes[0] % groups) == 0
  555. # only handling not transposed
  556. assert input[1] == weight_sizes[1] * groups
  557. assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0])
  558. for i in range(2, k):
  559. assert (input[i] + 2 * padding[i - 2]) >= (
  560. dilation[i - 2] * (weight_sizes[i] - 1) + 1
  561. )
  562. # this is not handling transposed convolution yet
  563. def conv_output_size(
  564. input_size: List[int],
  565. weight_size: List[int],
  566. bias: Optional[List[int]],
  567. stride: List[int],
  568. padding: List[int],
  569. dilation: List[int],
  570. groups: int,
  571. ):
  572. check_shape_forward(
  573. input_size, weight_size, bias, stride, padding, dilation, groups
  574. )
  575. has_dilation = len(dilation) > 0
  576. dim = len(input_size)
  577. output_size: List[int] = []
  578. input_batch_size_dim = 0
  579. weight_output_channels_dim = 0
  580. output_size.append(input_size[input_batch_size_dim])
  581. output_size.append(weight_size[weight_output_channels_dim])
  582. for d in range(2, dim):
  583. dilation_ = dilation[d - 2] if has_dilation else 1
  584. kernel = dilation_ * (weight_size[d] - 1) + 1
  585. output_size.append(
  586. (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
  587. )
  588. return output_size
  589. def conv1d(
  590. input: List[int],
  591. weight: List[int],
  592. bias: Optional[List[int]],
  593. stride: List[int],
  594. padding: List[int],
  595. dilation: List[int],
  596. groups: int,
  597. ):
  598. assert len(weight) == 3
  599. assert len(input) == 3
  600. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  601. def conv2d(
  602. input: List[int],
  603. weight: List[int],
  604. bias: Optional[List[int]],
  605. stride: List[int],
  606. padding: List[int],
  607. dilation: List[int],
  608. groups: int,
  609. ):
  610. assert len(weight) == 4
  611. assert len(input) == 4
  612. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  613. def conv_backwards(grad_output: List[int], input:List[int], weight:List[int], biases:Optional[List[int]]):
  614. # Bias gradient is always generated regardess of if biases is supplied
  615. return _copy(input), _copy(weight), [grad_output[1]]
  616. def conv_transpose2d_input(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: Optional[List[int]] = None, padding: Optional[List[int]] = None, output_padding: Optional[List[int]] = None, groups: int = 1, dilation: Optional[List[int]] = None) -> List[int]:
  617. if stride is None:
  618. stride = [1, 1]
  619. if padding is None:
  620. padding = [0, 0]
  621. if output_padding is None:
  622. output_padding = [0, 0]
  623. if dilation is None:
  624. dilation = [1, 1]
  625. has_dilation = len(dilation) > 0
  626. dim = len(input)
  627. output_size: List[int] = []
  628. input_batch_size_dim = 0
  629. weight_output_channels_dim = 1
  630. output_size.append(input[input_batch_size_dim])
  631. output_size.append(weight[weight_output_channels_dim])
  632. for d in range(2, dim):
  633. dilation_ = dilation[d - 2] if has_dilation else 1
  634. kernel = dilation_ * (weight[d] - 1)
  635. output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1)
  636. return output_size
  637. def conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
  638. has_dilation = len(dilation) > 0
  639. dim = len(input)
  640. output_size: List[int] = []
  641. input_batch_size_dim = 0
  642. weight_output_channels_dim = 1 if transposed else 0
  643. output_size.append(input[input_batch_size_dim])
  644. output_size.append(weight[weight_output_channels_dim])
  645. for d in range(2, dim):
  646. dilation_ = dilation[d - 2] if has_dilation else 1
  647. if transposed:
  648. kernel = dilation_ * (weight[d] - 1)
  649. output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1)
  650. else:
  651. kernel = dilation_ * (weight[d] - 1) + 1
  652. output_size.append((input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1)
  653. return output_size
  654. def batch_norm(
  655. input: List[int],
  656. weight: Optional[List[int]],
  657. bias: Optional[List[int]],
  658. running_mean: Optional[List[int]],
  659. running_var: Optional[List[int]],
  660. training: bool,
  661. momentum: float,
  662. eps: float,
  663. cudnn_enabled: bool,
  664. ):
  665. out: List[int] = []
  666. for elem in input:
  667. out.append(elem)
  668. return out
  669. def conv3d(
  670. input: List[int],
  671. weight: List[int],
  672. bias: Optional[List[int]],
  673. stride: List[int],
  674. padding: List[int],
  675. dilation: List[int],
  676. groups: int,
  677. ):
  678. assert len(weight) == 5
  679. assert len(input) == 5
  680. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  681. def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
  682. if dim_post_expr <= 0:
  683. assert wrap_scalar
  684. dim_post_expr = 1
  685. min = -dim_post_expr
  686. max = dim_post_expr - 1
  687. assert not (dim < min or dim > max)
  688. if dim < 0:
  689. dim += dim_post_expr
  690. return dim
  691. def zero_dim_tensor(input: Any):
  692. out: List[int] = []
  693. return out
  694. def multiply_integers(li: List[int]):
  695. out = 1
  696. for elem in li:
  697. out = out * elem
  698. return out
  699. def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
  700. assert end >= 0
  701. return [int(math.ceil(end))]
  702. def arange_start(
  703. start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
  704. ):
  705. assert end >= 0
  706. assert end >= start
  707. return [int(math.ceil(end - start))]
  708. def arange_start_step(
  709. start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
  710. ):
  711. assert step != 0
  712. if step < 0:
  713. assert start >= end
  714. else:
  715. assert end >= start
  716. return [int(math.ceil((end - start) / step))]
  717. def permute(input: List[int], dims: List[int]):
  718. assert len(input) == len(dims)
  719. ndim = len(dims)
  720. seen_dims: List[int] = []
  721. newSizes: List[int] = []
  722. for i in range(ndim):
  723. dim = maybe_wrap_dim(dims[i], ndim)
  724. seen_dims.append(dim)
  725. newSizes.append(input[dim])
  726. for i in range(1, ndim):
  727. for j in range(i):
  728. assert seen_dims[i] != seen_dims[j]
  729. return newSizes
  730. def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
  731. self_dim = len(self)
  732. if self_dim <= 1:
  733. return self
  734. normalized_src : List[int] = []
  735. normalized_dst : List[int] = []
  736. for i in range(len(source)):
  737. normalized_src.append(maybe_wrap_dim(source[i], self_dim))
  738. normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
  739. order = [-1 for i in range(self_dim)]
  740. src_dims = [i for i in range(self_dim)]
  741. dst_dims = [i for i in range(self_dim)]
  742. for i in range(len(source)):
  743. order[normalized_dst[i]] = normalized_src[i]
  744. src_dims[normalized_src[i]] = -1
  745. dst_dims[normalized_dst[i]] = -1
  746. source_dims : List[int] = []
  747. destination_dims : List[int] = []
  748. for ele in src_dims:
  749. if ele != -1:
  750. source_dims.append(ele)
  751. for ele in dst_dims:
  752. if ele != -1:
  753. destination_dims.append(ele)
  754. rest_dim = self_dim - len(source)
  755. for i in range(rest_dim):
  756. order[destination_dims[i]] = source_dims[i]
  757. return permute(self, order)
  758. def flatten(input: List[int], start_dim: int, end_dim: int):
  759. start_dim = maybe_wrap_dim(start_dim, len(input))
  760. end_dim = maybe_wrap_dim(end_dim, len(input))
  761. assert start_dim <= end_dim
  762. if len(input) == 0:
  763. return [1]
  764. if start_dim == end_dim:
  765. # TODO: return self
  766. out: List[int] = []
  767. for elem in input:
  768. out.append(elem)
  769. return out
  770. slice_numel = 1
  771. for i in range(start_dim, end_dim + 1):
  772. slice_numel *= input[i]
  773. # TODO: use slicing when slice optimization has landed
  774. # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
  775. shape: List[int] = []
  776. for i in range(start_dim):
  777. shape.append(input[i])
  778. shape.append(slice_numel)
  779. for i in range(end_dim + 1, len(input)):
  780. shape.append(input[i])
  781. return shape
  782. def nonzero_lower_bound(input: List[int]):
  783. return [0, len(input)]
  784. def nonzero_upper_bound(input: List[int]):
  785. return [numel(input), len(input)]
  786. def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
  787. dim = maybe_wrap_dim(dim, len(self))
  788. out: List[int] = []
  789. for i, self_dim in enumerate(self):
  790. if i == dim:
  791. if keepdim:
  792. out.append(1)
  793. else:
  794. out.append(self_dim)
  795. return out
  796. def argmax(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
  797. if dim is None:
  798. return []
  799. return _reduce_along_dim(self, dim, keepdim)
  800. def bmm(self: List[int], mat2: List[int]) -> List[int]:
  801. assert len(self) == 3, "bmm only supports 3D tensors"
  802. assert len(mat2) == 3, "bmm only supports 3D tensors"
  803. assert self[0] == mat2[0], "mismatching batch dimension"
  804. assert self[2] == mat2[1], "mismatching contracting dimension"
  805. return [self[0], self[1], mat2[2]]
  806. def _shape_as_tensor(self: List[int]) -> List[int]:
  807. return [len(self)]
  808. def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
  809. if len(self) == 0:
  810. result: List[int] = []
  811. else:
  812. assert k <= self[dim], f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
  813. result = _copy(self)
  814. result[dim] = k
  815. return result, result
  816. def nll_loss_forward(self: List[int], target: List[int], weight: Optional[List[int]], reduction: int) -> Tuple[List[int], List[int]]:
  817. # This is taken shamelessly from the meta function in LossNLL.cpp
  818. self_dim = len(self)
  819. target_dim = len(target)
  820. assert 0 < self_dim <= 2
  821. assert target_dim <= 1
  822. no_batch_dim = self_dim == 1 and target_dim == 0
  823. assert no_batch_dim or (self[0] == target[0])
  824. n_classes = self[-1]
  825. scalar_shape: List[int] = []
  826. assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
  827. if reduction == 0 and self_dim == 2:
  828. reduction_shape = [self[0]]
  829. else:
  830. reduction_shape = scalar_shape
  831. return reduction_shape, scalar_shape
  832. def native_layer_norm(input: List[int], normalized_shape: List[int]) -> Tuple[List[int], List[int], List[int]]:
  833. reduction_shape: List[int] = []
  834. num_unreduced_dimensions = len(input) - len(normalized_shape)
  835. assert num_unreduced_dimensions >= 0
  836. for i in range(num_unreduced_dimensions):
  837. reduction_shape.append(input[i])
  838. for i in range(num_unreduced_dimensions, len(input)):
  839. reduction_shape.append(1)
  840. return _copy(input), reduction_shape, reduction_shape
  841. def native_batch_norm(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool) -> Tuple[List[int], List[int], List[int]]:
  842. if training:
  843. _size = [input[1]]
  844. else:
  845. _size = [0]
  846. return _copy(input), _size, _size
  847. """
Currently deferring the enabling of this, as part of the proposal to suspend
adding ops.
There are currently cases in the tests where this is being called
in the SSA opinfo tests with unexpected values (e.g. a list of two ints; see the first
opinfo test). The behavior of index is significantly dependent on the inputs.
This could be an error with how we are matching up shape functions, or this
function may simply need to implement everything.
  855. def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
  856. assert len(indices) <= len(self), "More indices than dimensions to index"
  857. broadcasted_shape: List[int] = []
  858. for index_tensor_shape in indices:
  859. if index_tensor_shape is not None:
  860. broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
  861. return broadcasted_shape
  862. """
  863. ScriptFn = torch._C.ScriptFunction
  864. shape_compute_graph_mapping : Dict[str, ScriptFn ] = {}
  865. bounded_compute_graph_mapping : Dict[str, Tuple[ScriptFn, ScriptFn]] = {}
  866. script_func_map: Dict[Callable, ScriptFn] = {}
  867. def process_func(func: Callable):
  868. if func not in script_func_map:
  869. scripted_func = torch.jit.script(func)
  870. torch._C._jit_pass_inline(scripted_func.graph)
  871. for _ in range(2):
  872. torch._C._jit_pass_peephole(scripted_func.graph)
  873. torch._C._jit_pass_constant_propagation(scripted_func.graph)
  874. script_func_map[func] = scripted_func
  875. return script_func_map[func]
  876. def add_shape_compute_mapping(operator_schema: str, func: Callable):
  877. global shape_compute_graph_mapping
  878. shape_compute_graph_mapping[operator_schema] = process_func(func)
  879. def add_bounded_compute_mapping(operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable):
  880. # Adds a shape compute function for both upper and lower bounds
  881. fns = (process_func(lower_bound_func), process_func(upper_bound_func))
  882. bounded_compute_graph_mapping[operator_schema] = fns
# ---------------------------------------------------------------------------
# Registration table: each call maps a full operator schema string to the
# shape compute function used for that operator.
# NOTE(review): the schema strings are presumably matched verbatim by the
# consumer of shape_compute_graph_mapping — confirm before editing any of them.
# Many ops below reuse `unary` / `zero_dim_tensor` (defined earlier in the
# file); `unary` presumably returns a copy of its first argument's shape.
# ---------------------------------------------------------------------------
add_shape_compute_mapping("aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", unary)
# NOTE(review): this rsub.Tensor schema lists `Scalar other` and a positional
# alpha, which looks like the rsub.Scalar signature — confirm against
# native_functions.yaml.
add_shape_compute_mapping("aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary)
add_shape_compute_mapping("aten::dropout(Tensor input, float p, bool train) -> Tensor", unary)
add_shape_compute_mapping("aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", adaptive_avg_pool2d)
add_shape_compute_mapping("prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping("aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", unary)
add_shape_compute_mapping("aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", unary)
add_shape_compute_mapping("aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", arange_end)
add_shape_compute_mapping("aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", arange_start)
add_shape_compute_mapping("aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", arange_start_step)
add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim)
add_shape_compute_mapping("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze)
add_shape_compute_mapping("aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze)
add_shape_compute_mapping("aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", slice)
add_shape_compute_mapping("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select)
add_shape_compute_mapping("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select)
add_shape_compute_mapping("aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
                          "float eps=1e-05, bool cudnn_enable=True) -> Tensor", unary)
add_shape_compute_mapping("aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary)
add_shape_compute_mapping("aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", unary)
add_shape_compute_mapping("aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", unary)
add_shape_compute_mapping("aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", embedding)
add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm)
add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot)
add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv)
add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul)
add_shape_compute_mapping("aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear)
add_shape_compute_mapping("aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", max_pool2d)
add_shape_compute_mapping("aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", max_pool2d_with_indices)
add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t)
add_shape_compute_mapping("aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose)
add_shape_compute_mapping("aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", conv1d)
add_shape_compute_mapping("aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", conv2d)
add_shape_compute_mapping("aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", batch_norm)
add_shape_compute_mapping("aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", conv3d)
add_shape_compute_mapping("aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", conv_backwards)
add_shape_compute_mapping("aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", conv_forwards)
add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input)
add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute)
add_shape_compute_mapping("aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", movedim)
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand)
add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused)
add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", sum_mean_dim)
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", sum_mean_dim)
add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim)
add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", addmm)
add_shape_compute_mapping("aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)", upsample_nearest2d)
add_shape_compute_mapping("aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", unary)
add_shape_compute_mapping("aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", unary)
add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary)
add_shape_compute_mapping("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", broadcast)
add_shape_compute_mapping("aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax)
add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm)
add_shape_compute_mapping("aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor)
add_shape_compute_mapping("aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", topk)
add_shape_compute_mapping("aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", nll_loss_forward)
add_shape_compute_mapping("aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", native_layer_norm)
# The same native_batch_norm shape function is reused for the
# _native_batch_norm_legit overloads below.
add_shape_compute_mapping("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
add_shape_compute_mapping("aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
# NOTE(review): the `.no_stats` overload is expected to take no running_mean /
# running_var arguments; confirm this schema string matches the registered
# operator, since native_batch_norm reads its arguments positionally.
add_shape_compute_mapping("aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
# Disabled: see the note in the string literal containing index_Tensor above.
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)
# TODO: migrate over all of symbolic_shape_registry_util.cpp
# These are duplicated here so that the functions will be serialized
add_shape_compute_mapping("aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", broadcast_three)
add_shape_compute_mapping("aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", broadcast_one_three)
add_shape_compute_mapping("aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", broadcast_inplace)
# quantized_conv_prepack TODO
# Shape Compute Fn with upper and lower bounds
# nonzero's row count depends on the data, so only a lower bound ([0, rank])
# and an upper bound ([numel, rank]) are registered rather than an exact shape.
add_bounded_compute_mapping("aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound)