# decomposition.py

import functools
import logging
import math
import numbers

import torch
import torch._decomp as decomp
from torch import Tensor
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp.decompositions import pw_cast_for_opmath
from torch.utils._mode_utils import no_dispatch

from . import config, utils

log = logging.getLogger(__name__)
aten = torch.ops.aten

inductor_decompositions = get_decompositions(
    [
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.flip,
        aten.lcm,
        aten.linalg_vector_norm,
        aten.sin_,
        aten.sqrt_,
        aten.std,
        aten.std_mean,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.unsafe_split,
    ]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}


def register_decomposition(ops):
    for op in [ops] if callable(ops) else ops:
        if op in decompositions:
            log.warning(f"duplicate decomp: {op}")
    return decomp.register_decomposition(ops, decompositions)


@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
    if min is not None:
        x = x.clamp_min(min)
    if max is not None:
        x = x.clamp_max(max)
    return x


# TorchInductor-only decomposition. It should not be taken to core.
# See https://github.com/pytorch/torchdynamo/pull/1120
@register_decomposition([aten.floor_divide.default])
def floordiv(a, b):
    return aten.div.Tensor_mode(a, b, rounding_mode="floor")


def get_alignment_size(x):
    # Alignment (in elements) used when padding matmul dimensions; a value
    # of 0 means tensors of this dtype are never padded.
    if x.dtype in (torch.float16, torch.bfloat16):  # torch.half is an alias of float16
        return 8
    elif x.dtype == torch.float32:  # torch.float is an alias of float32
        return 4
    else:
        return 0


def check_device(a: Tensor, b: Tensor):
    return a.is_cuda and b.is_cuda


def get_padded_length(x, alignment_size):
    if alignment_size == 0 or x % alignment_size == 0:
        return 0
    return int((x // alignment_size + 1) * alignment_size) - x
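
# Worked examples of the rounding arithmetic above (illustrative comments,
# not executed):
#   get_padded_length(30, 8) -> 2   # next multiple of 8 above 30 is 32
#   get_padded_length(32, 8) -> 0   # already aligned
#   get_padded_length(7, 0)  -> 0   # alignment_size == 0 disables padding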


def pad_dim(x, padded_length, dim):
    if padded_length == 0:
        return x
    pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
    return torch.cat([x, pad], dim=dim)
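
# Shape sketch for pad_dim (illustrative, not executed at import time):
#   x = torch.randn(30, 5)
#   pad_dim(x, 2, 0).shape -> torch.Size([32, 5])  # zeros appended along dim 0
#   pad_dim(x, 0, 0) is x                          # no-op when padded_length == 0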


@register_decomposition([aten.addmm])
def addmm(input, mat1, mat2, *, beta=1, alpha=1):
    if (
        config.shape_padding
        and check_device(mat1, mat2)
        and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)
    ):
        m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
        k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
        n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
        if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
            return pad_addmm(
                input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length
            )
    return NotImplemented  # go directly to lowering


def pad_addmm(input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # The addmm decomposition re-enters pad_addmm once per dimension that
    # needs padding: each call pads one of k, n, m (in that priority order),
    # and the inner aten.addmm call is decomposed again for the rest.
    if k_padded_length != 0:
        mat1 = pad_dim(mat1, k_padded_length, 1)
        mat2 = pad_dim(mat2, k_padded_length, 0)
    elif n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 1)
    elif m_padded_length != 0:
        mat1 = pad_dim(mat1, m_padded_length, 0)
    # The bias only needs padding when an output dimension changes; padding
    # k leaves the output shape (m, n) untouched.
    if input is not None and k_padded_length == 0:
        if n_padded_length != 0:
            if input.dim() == 2:
                input = pad_dim(input, n_padded_length, 1)
            elif input.dim() == 1:
                input = pad_dim(input, n_padded_length, 0)
        elif m_padded_length != 0 and input.dim() == 2:
            input = pad_dim(input, m_padded_length, 0)
    if k_padded_length != 0:
        return torch.ops.aten.addmm(input, mat1, mat2)
    elif n_padded_length != 0:
        return torch.ops.aten.addmm(input, mat1, mat2)[:, :-n_padded_length]
    else:
        return torch.ops.aten.addmm(input, mat1, mat2)[:-m_padded_length, :]
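
# Illustrative trace of the re-entry behavior (assuming fp16 inputs with
# mat1: (30, 30) and mat2: (30, 32), so m and k each need 2 elements of
# padding to reach a multiple of 8, and assuming should_pad_bench keeps
# favoring padding on the inner call):
#   first pass pads k -> mat1 (30, 32), mat2 (32, 32), calls aten.addmm;
#   the inner addmm decomposes again, pads m -> mat1 (32, 32), and slices
#   the result back to (30, 32) via [:-m_padded_length, :].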


def should_pad_bench(mat1, mat2, op, input=None):
    assert utils.has_triton()
    from triton.testing import do_bench

    with no_dispatch():
        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
            m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
            k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
            n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
        elif op is torch.ops.aten.bmm:
            m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
            k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
            n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
        else:
            return False
        if m_padded_length == k_padded_length == n_padded_length == 0:
            return False
        mat1 = torch.randn_like(mat1)
        mat2 = torch.randn_like(mat2)
        warmup = 5
        rep = 100
        if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
            ori_time = do_bench(
                lambda: op(mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
            )[0]
        else:
            if input is not None:
                input = torch.randn_like(input)
            ori_time = do_bench(
                lambda: op(input, mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
            )[0]
        mat1_pad = torch.randn_like(mat1)
        mat2_pad = torch.randn_like(mat2)
        if op is torch.ops.aten.addmm:
            input_pad = None
            if input is not None and input.is_cuda:
                input_pad = torch.randn_like(input)
            pad_time = do_bench(
                lambda: pad_addmm(
                    input_pad,
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                ),
                warmup=warmup,
                rep=rep,
                fast_flush=True,
            )[0]
        elif op is torch.ops.aten.mm:
            pad_time = do_bench(
                lambda: pad_mm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                ),
                warmup=warmup,
                rep=rep,
                fast_flush=True,
            )[0]
        else:
            pad_time = do_bench(
                lambda: pad_bmm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                ),
                warmup=warmup,
                rep=rep,
                fast_flush=True,
            )[0]
        # Shape padding introduces additional memory ops. Based on microbenchmarks,
        # 1.1x represents a reasonable tradeoff between the performance improvement
        # from shape padding and the overhead of those additional memory ops.
        # TODO: Build a learned model, which would be better than this heuristic.
        return ori_time > pad_time * 1.1
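
# Example of the 1.1x threshold with illustrative numbers: if the unpadded op
# benchmarks at ori_time = 1.00 ms, padding is chosen only when pad_time is
# below ~0.91 ms (1.00 > pad_time * 1.1), so the extra pad/slice memory
# traffic must be more than paid for by the faster aligned kernel.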


@register_decomposition([aten.mm])
def mm_decomp(mat1, mat2):
    if (
        config.shape_padding
        and check_device(mat1, mat2)
        and should_pad_bench(mat1, mat2, torch.ops.aten.mm)
    ):
        m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
        k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
        n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
        if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
            return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)
    return NotImplemented  # go directly to lowering


def pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # mm_decomp re-enters pad_mm once per dimension that needs padding; k is
    # handled first because zero-padding k leaves the (m, n) output values
    # unchanged, so no slicing is needed.
    if k_padded_length != 0:
        mat1 = pad_dim(mat1, k_padded_length, 1)
        mat2 = pad_dim(mat2, k_padded_length, 0)
        return torch.ops.aten.mm(mat1, mat2)
    elif n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 1)
        return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
    else:
        mat1 = pad_dim(mat1, m_padded_length, 0)
        return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
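
# Why zero-padding k needs no output slicing (a small worked case): for
# A (2, 30) @ B (30, 2), appending two zero columns to A and two zero rows
# to B gives A' (2, 32) @ B' (32, 2); every extra product is zero, so
# A' @ B' == A @ B exactly and the (2, 2) output shape is unchanged.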


@register_decomposition([aten.bmm])
def bmm_decomp(mat1, mat2):
    if (
        config.shape_padding
        and check_device(mat1, mat2)
        and should_pad_bench(mat1, mat2, torch.ops.aten.bmm)
    ):
        m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
        k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
        n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
        if k_padded_length != 0 or n_padded_length != 0 or m_padded_length != 0:
            return pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)
    return NotImplemented  # go directly to lowering


def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # bmm_decomp re-enters pad_bmm once per dimension that needs padding; the
    # batch dimension (dim 0) is never padded.
    if k_padded_length != 0:
        mat1 = pad_dim(mat1, k_padded_length, 2)
        mat2 = pad_dim(mat2, k_padded_length, 1)
        return torch.ops.aten.bmm(mat1, mat2)
    elif n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 2)
        return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
    else:
        mat1 = pad_dim(mat1, m_padded_length, 1)
        return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()


@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output,
    input,
    weight,
    bias_sizes,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    if not output_mask[2] or grad_output.device.type != "cuda":
        return NotImplemented
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)
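
# Shape sketch for the grad_bias reduction above: for a 2D convolution,
# grad_output has shape (N, C, H, W), so summing over dims [0, 2, 3]
# (batch plus all spatial dims) leaves a (C,)-shaped bias gradient.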


@register_decomposition([aten.log2])
def log2(x):
    return torch.log(x) * (1.0 / math.log(2.0))


@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
    ten_pow_decimals = 10.0**decimals
    return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
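
# Worked example for round_dec: with decimals=2, x = 1.2345 is scaled to
# 123.45, rounded to 123.0, then scaled back to roughly 1.23 (up to float
# error). Halfway cases follow aten.round's round-half-to-even behavior.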


@register_decomposition([aten.all.default])
def all(input):
    return torch.logical_not(torch.any(torch.logical_not(input)))


@register_decomposition([aten.all.dim])
def all_dim(input, dim, keepdim=False):
    return torch.logical_not(torch.any(torch.logical_not(input), dim, keepdim))


# NB: this decomposition is not stride accurate; do not put it in the main
# library.
@register_decomposition(aten.copy)
def copy(self, src, non_blocking=False):
    intermediate = src.to(self, non_blocking)
    if self.size() != intermediate.size():
        return aten.expand_copy.default(intermediate, self.size())
    else:
        return intermediate


@register_decomposition([aten.baddbmm])
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
    # Computes beta * self + alpha * (batch1 @ batch2), skipping the scalar
    # multiplies when the coefficients are statically 1.
    result = torch.bmm(batch1, batch2)
    if not isinstance(alpha, numbers.Number) or alpha != 1:
        result = result * alpha
    if not isinstance(beta, numbers.Number) or beta != 1:
        self = self * beta
    return self + result


@register_decomposition([aten.conj_physical])
def conj_physical(self):
    assert not self.is_complex(), "TODO: implement this"
    return self


@register_decomposition([aten.lift, aten.detach_])
def lift(self):
    return self


@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < self
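
# Sketch of the bernoulli decomposition above: each element of `self` is a
# probability, and a uniform sample in [0, 1) compared against it is True
# with exactly that probability, e.g. self = tensor([0.0, 1.0]) always
# yields tensor([False, True]).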
  309. """
  310. Some decomps result in differences from eager related to randomness.
  311. We put these decomps in a separate table `extra_random_decomps` to allow
  312. turning them on and off via `config.fallback_random`.
  313. """
  314. extra_random_decomps = get_decompositions(
  315. [
  316. aten.native_dropout,
  317. aten.cauchy,
  318. aten.cauchy_,
  319. aten.exponential,
  320. aten.exponential_,
  321. aten.geometric,
  322. aten.geometric_,
  323. aten.log_normal,
  324. aten.log_normal_,
  325. aten.uniform_,
  326. ]
  327. )
  328. register_extra_random_decomp = functools.partial(
  329. decomp.register_decomposition, registry=extra_random_decomps
  330. )


@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)


@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < p


@functools.lru_cache(None)
def fast_random_decomps():
    return {**decompositions, **extra_random_decomps}


def select_decomp_table():
    """decomps can change based on config"""
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()
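
# Usage sketch (illustrative, not part of this module): the selected table
# can be passed to a tracer such as torch.fx's make_fx so that traced graphs
# contain the decomposed ops instead of the originals.
#
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def f(bias, batch):
#       return torch.baddbmm(bias, batch, batch)
#
#   gm = make_fx(f, decomposition_table=select_decomp_table())(
#       torch.randn(2, 3, 3), torch.randn(2, 3, 3)
#   )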