_ops_refs.py

# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial

import torch

from .binary import (
    _apply_native_binary,
    NATIVE_BINARY_FNS,
    NATIVE_INPLACE_BINARY_FNS,
)
from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask
from .passthrough import (
    _apply_pass_through_fn,
    PASSTHROUGH_FNS,
)
from .reductions import (
    _apply_reduction,
    NATIVE_REDUCE_FNS,
    TORCH_REDUCE_FNS,
    TENSOR_REDUCE_FNS,
)
from .unary import (
    _apply_native_unary,
    NATIVE_UNARY_FNS,
    NATIVE_INPLACE_UNARY_FNS,
)

__all__ = []  # type: ignore[var-annotated]
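
# Two registries drive MaskedTensor's operator support in this file:
#   * _MASKEDTENSOR_FUNCTION_TABLE maps torch-level functions (torch.*, torch.Tensor.*)
#     to handlers, registered with @register_function_func.
#   * _MASKEDTENSOR_DISPATCH_TABLE maps aten ops to handlers, registered with
#     @register_dispatch_func.
# These tables are presumably consulted from MaskedTensor's __torch_function__ and
# __torch_dispatch__ overrides (defined in .core); only the registration side lives here.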
def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None):
    if len_args is not None and len_args != len(args):
        raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}")
    if len_kwargs is not None and len_kwargs != len(kwargs):
        raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}")
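
# The autograd.Function subclasses below implement contiguity and layout conversions
# (contiguous, to_dense, to_sparse, to_sparse_csr) and where, so that gradients flow
# back through the conversion while the mask is carried alongside the data.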
class _MaskedContiguous(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
        if input.is_contiguous():
            return input
        data = input.get_data()
        mask = input.get_mask()
        return MaskedTensor(data.contiguous(), mask.contiguous())

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
class _MaskedToDense(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
        if input.layout == torch.strided:
            return input
        ctx.layout = input.layout
        data = input.get_data()
        mask = input.get_mask()
        return MaskedTensor(data.to_dense(), mask.to_dense())

    @staticmethod
    def backward(ctx, grad_output):
        layout = ctx.layout
        if layout == torch.sparse_coo:
            return grad_output.to_sparse_coo()
        elif layout == torch.sparse_csr:
            return grad_output.to_sparse_csr()
        elif layout == torch.strided:
            return grad_output.to_dense()
        raise ValueError(f"to_dense: Unsupported input layout: {layout}")
class _MaskedToSparse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
        # Following the convention from sparse tensors that to_sparse always means that we convert to sparse_coo
        if input.layout == torch.sparse_coo:
            return input
        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_coo().coalesce()
        sparse_data = data.sparse_mask(sparse_mask)
        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()
class _MaskedToSparseCsr(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
        if input._masked_data.ndim != 2:
            raise ValueError(
                f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
            )
        if input.layout == torch.sparse_csr:
            return input
        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_csr()
        sparse_data = data.sparse_mask(sparse_mask)
        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()
class _MaskedWhere(torch.autograd.Function):
    @staticmethod
    def forward(ctx, cond, self, other):
        ctx.mark_non_differentiable(cond)
        ctx.save_for_backward(cond)
        return torch.ops.aten.where(cond, self, other)

    @staticmethod
    def backward(ctx, grad_output):
        (cond,) = ctx.saved_tensors

        def masked_out_like(mt):
            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())

        return (
            None,
            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
        )
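
# Torch-function-level registry. Reductions are pre-populated here so that e.g.
# torch.sum / Tensor.sum on a MaskedTensor route to _apply_reduction.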
_MASKEDTENSOR_FUNCTION_TABLE = {}

_function_fn_apply_map = {
    (tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction,
}

for fn_map_list, apply_fn in _function_fn_apply_map.items():
    for fn_map in fn_map_list:
        for fn in fn_map:
            _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)
def register_function_func(ops):
    """
    Used for registering a new __torch_function__ function to MaskedTensor
    Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)

    The code to register a new function looks like:

    @register_function_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """
    def wrapper(func):
        for op in ops:
            _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)
    return wrapper
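
# The handlers below are the in-tree registrations. Each registered callable receives
# the original torch function as its first argument (bound via functools.partial), so a
# handler can re-invoke func itself when it only needs to adjust the arguments.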
@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_function_reductions(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_function_func([torch.Tensor.where, torch.where])
def _function_where(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0)
    return _MaskedWhere.apply(*args)


@register_function_func([torch.Tensor.contiguous])
def _function_contiguous(func, *args, **kwargs):
    return _MaskedContiguous.apply(args[0])


@register_function_func([torch.Tensor.to_dense])
def _function_to_dense(func, *args, **kwargs):
    return _MaskedToDense.apply(args[0])


@register_function_func([torch.Tensor.to_sparse])
def _function_to_sparse(func, *args, **kwargs):
    return _MaskedToSparse.apply(args[0])


@register_function_func([torch.Tensor.to_sparse_csr])
def _function_to_sparse_csr(func, *args, **kwargs):
    return _MaskedToSparseCsr.apply(args[0])
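
# Aten-level (__torch_dispatch__) registry. Handlers here receive aten ops such as
# torch.ops.aten._softmax rather than the public torch functions.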
_MASKEDTENSOR_DISPATCH_TABLE = {}


def register_dispatch_func(aten_ops):
    """
    Used for registering a new __torch_dispatch__ function to MaskedTensor
    Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)

    The code to register a new function looks like:

    @register_dispatch_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """
    def wrapper(func):
        for aten_op in aten_ops:
            _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)
    return wrapper
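
# Bulk registrations: reductions, pass-through ops, and elementwise unary/binary ops
# are delegated wholesale to the helpers imported at the top of the file.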
@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_reduction(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_dispatch_func(PASSTHROUGH_FNS)
def _general_passthrough(func, *args, **kwargs):
    return _apply_pass_through_fn(func, *args, **kwargs)


@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
def _general_unary(func, *args, **kwargs):
    return _apply_native_unary(func, *args, **kwargs)


@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
def _general_binary(func, *args, **kwargs):
    return _apply_native_binary(func, *args, **kwargs)
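
# Tensor metadata queries: strides are reported as None, while layout and contiguity
# queries are answered from the underlying data tensor (and rejected for sparse data,
# which has no notion of contiguity or strides).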
@register_dispatch_func([torch.ops.aten.stride])
def stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.aten.sym_stride])
def sym_stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.prim.layout])
def layout(func, *args, **kwargs):
    return _get_data(args[0]).layout


@register_dispatch_func([torch.ops.aten.is_contiguous])
def is_contiguous(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_strides_like_format")
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_non_overlapping_and_dense")
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.contiguous])
def contiguous(func, *args, **kwargs):
    if _get_data(args[0]).is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have contiguous")
    return _MaskedContiguous.apply(args[0])
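
# new_empty_strided is only honored when the requested size and stride match the
# existing data tensor exactly; the result reuses the input's mask.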
@register_dispatch_func([torch.ops.aten.new_empty_strided])
def new_empty_strided(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    if tuple(args[1]) != tuple(data.size()):
        raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()")
    if tuple(args[2]) != tuple(data.stride()):
        raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()")
    return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)


@register_dispatch_func([torch.ops.aten._local_scalar_dense])
def _local_scalar_dense(func, *args, **kwargs):
    if not _maybe_get_mask(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))


@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
def _apply_fn_on_data(func, *args, **kwargs):
    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))


@register_dispatch_func([torch.ops.aten._to_copy])
def _to_copy(func, *args, **kwargs):
    new_data = func(_get_data(args[0]), *args[1:], **kwargs)
    return MaskedTensor(new_data, _maybe_get_mask(args[0]))
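
# aten._masked_softmax expects a mask where True marks positions to *exclude*, whereas
# MaskedTensor's mask uses True for specified elements, hence the ~mask below. (The
# trailing argument appears to select the mask-type variant of the kernel.)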
@register_dispatch_func([torch.ops.aten._softmax])
def _softmax(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
    return MaskedTensor(result_data, mask)


@register_dispatch_func([torch.ops.aten.ones_like])
def ones_like(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
    result_data = func(_get_data(args[0]), **kwargs)
    return MaskedTensor(result_data, _maybe_get_mask(args[0]))
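
# The softmax backward mirrors the forward: both grad and output must be MaskedTensors
# with matching masks, and the masked backward kernel operates on the raw data.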
@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
    grad, output, dim, input_dtype = args
    if is_masked_tensor(grad) and is_masked_tensor(output):
        if not _masks_match(grad, output):
            raise ValueError(f"__torch_dispatch__, {func}: expected the masks of grad and output to match")
        grad_data = _get_data(grad)
        new_grad_data = torch.ops.aten._masked_softmax_backward(
            grad_data,
            _get_data(output),
            ~_maybe_get_mask(grad),
            dim % grad_data.ndim,
        )
        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
        return res
    else:
        raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors")


@register_dispatch_func([torch.ops.aten.copy_])
def copy_(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
        raise ValueError("args[0] mask and args[1] mask must match but do not")
    func(_get_data(args[0]), _get_data(args[1]))
    return args[0]
@register_dispatch_func([torch.ops.aten.where])
def where(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mx = args[1]
    my = args[2]
    if not is_masked_tensor(mx):
        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
    if not is_masked_tensor(my):
        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
    new_data = func(args[0], mx.get_data(), my.get_data())
    new_mask = func(args[0], mx.get_mask(), my.get_mask())
    return MaskedTensor(new_data, new_mask)
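
# Layout conversions at dispatch level: plain tensors are first wrapped as fully
# specified MaskedTensors (all-True mask), then the data and mask are converted
# together so the result stays consistent.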
@register_dispatch_func([torch.ops.aten.to_sparse])
def to_sparse(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise TypeError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
    if mt.is_sparse_coo():
        return mt
    new_mask = func(_maybe_get_mask(args[0])).coalesce()
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten.to_sparse_csr])
def to_sparse_csr(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    if mt.is_sparse_csr():
        return mt
    new_mask = func(_maybe_get_mask(args[0]))
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._to_dense])
def _to_dense(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    new_data = func(_get_data(args[0]))
    new_mask = func(_maybe_get_mask(args[0]))
    return MaskedTensor(new_data, new_mask)
@register_dispatch_func([torch.ops.aten._indices])
def _indices(func, *args, **kwargs):
    # Assumes data is sparse
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    data = _get_data(args[0]).indices()
    return MaskedTensor(data, torch.ones_like(data).bool())


@register_dispatch_func([torch.ops.aten._values])
def _values(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    data = _get_data(args[0]).values()
    return MaskedTensor(data, torch.ones_like(data).bool())
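
# Rebuilding a sparse COO tensor: the constructor is run twice, once with the real
# values to produce the data and once with an all-ones values tensor to produce the
# (boolean) mask at the same indices.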
@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
    new_args = list(args)
    if is_masked_tensor(args[-1]):
        new_args[-1] = args[-1].get_data()
    if is_masked_tensor(args[-2]):
        new_args[-2] = args[-2].get_data()

    new_data = func(*new_args, **kwargs)
    new_args[-1] = torch.ones_like(new_args[-1])
    new_mask = func(*new_args, **kwargs).bool()

    return MaskedTensor(new_data, new_mask)
@register_dispatch_func([torch.ops.aten.is_same_size])
def is_same_size(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    return _get_data(args[0]).is_same_size(_get_data(args[1]))