# fake_utils.py
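"""Cross-reference utilities for fake tensors: re-run each dispatched op under
FakeTensorMode and cross-check aliasing behavior and tensor metadata against
real execution."""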

import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

aten = torch._ops.ops.aten


def outputs_alias_inputs(outputs, inputs):
    """Return True if any output tensor shares storage with any input tensor."""
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )


def outputs_are_inputs(outputs, inputs):
    """Return True if any output tensor is the same Python object as an input."""
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))


def output_alias_each_other(outputs):
    """Return True if any two output tensors share the same storage."""
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False
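

# Illustrative expectations for the helpers above (a hypothetical sketch, kept
# as comments so importing this module has no side effects):
#
#   x = torch.randn(3, 3)
#   v = x.t()                        # a view shares storage with its base
#   outputs_alias_inputs(v, (x,))    # True: output aliases an input's storage
#   outputs_are_inputs(x, (x,))      # True: output is the same object as input
#   output_alias_each_other((v, x))  # True: two outputs share one storage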


class CrossRefFakeMode(TorchDispatchMode):
    """Dispatch mode that runs every op both for real and under FakeTensorMode,
    then cross-checks aliasing, identity, and tensor metadata between the two."""

    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        fake_r = None

        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags  # type: ignore[attr-defined]
            and torch.Tag.inplace_view not in func.tags  # type: ignore[attr-defined]
            and torch.Tag.data_dependent_output not in func.tags  # type: ignore[attr-defined]
        ):
            # Run the op on fake tensors first; ops without a fake
            # implementation simply skip the cross-check.
            try:
                with FakeTensorMode() as fake_mode:
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor, fake_mode.from_tensor, (args, kwargs)
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

        # Run the op for real.
        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat, _ = tree_flatten(r)
            f_flat, _ = tree_flatten(fake_r)
            assert len(r_flat) == len(
                f_flat
            ), f"Mismatch {len(r_flat)} != {len(f_flat)} on {func}"
            if self.check_aliasing:
                # Real and fake execution must agree on whether outputs alias
                # inputs...
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"Mismatch on {func}: {r_aliasing} != {f_aliasing}"

                # ...on whether outputs *are* inputs...
                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"Mismatch on {func}: {r_identity_eq} != {f_identity_eq}"

                # ...and on whether outputs alias each other.
                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert (
                    r_output_alias_each_other == f_output_alias_each_other
                ), f"Mismatch on {func}: {r_output_alias_each_other} != {f_output_alias_each_other}"
            for r_out, fake_out in zip(tree_flatten(r)[0], tree_flatten(fake_r)[0]):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"Mismatched number of tensor outputs on {func}"
                if r_is_ten:
                    assert (
                        r_out.requires_grad == fake_out.requires_grad
                    ), f"Mismatch on {func}: requires_grad {r_out.requires_grad} != {fake_out.requires_grad}"
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"Mismatch on {func}: {r_offset} != {f_offset}"

                    # Compare shape, dtype, device (and optionally strides).
                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out, fake_out, check_strides=self.check_strides
                        )
                    except Exception as e:
                        raise RuntimeError(f"Mismatch on {func}: {e}") from e
        return r
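

if __name__ == "__main__":
    # Minimal usage sketch (an assumed example, not part of the original
    # module): inside the mode, every dispatched op also runs on fake tensors,
    # and the cross-checks above fire on any metadata/aliasing divergence.
    x = torch.randn(4, 4)
    with CrossRefFakeMode(check_strides=True, check_aliasing=True):
        y = x.relu().view(16)
    print(y.shape)  # expected: torch.Size([16])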