schema_check_mode.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import torch
  2. from torch.utils._pytree import tree_flatten, tree_map
  3. from torch.fx.operator_schemas import normalize_function
  4. from torch.testing._internal.jit_utils import clone_inputs
  5. from torch.utils._python_dispatch import TorchDispatchMode
  6. from itertools import combinations
  7. from collections import namedtuple
  8. from copy import deepcopy
  9. # Named Tuples used within SchemaCheckMode
  10. Mutation = namedtuple('Mutation', ['op_name', 'arg_name'])
  11. Aliasing = namedtuple('Aliasing', ['op_name', 'arg_name', 'output_number'])
# Simplified naming for C++ classes exposed on torch._C.
#  - SchemaArgument: built below from a SchemaArgType and an index to refer
#    to one input or output of an op schema.
#  - SchemaArgType: provides the `input` / `output` selectors used below.
#  - SchemaInfo: built below from `func._schema`; queried through
#    `add_argument_values`, `may_contain_alias` and `is_mutable`.
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo
# This TorchDispatchMode subclass is used to verify op schemas.
# It currently:
#  - Records the called ops
#  - Checks for mutations on all inputs
#  - Checks for aliasing on all inputs
  21. class SchemaCheckMode(TorchDispatchMode):
  22. def __init__(self):
  23. # Information recorded for testing purposes. For example:
  24. # - incorrect schemas
  25. # - overly conservative schemas
  26. self.ops = []
  27. self.mutated = []
  28. self.aliasing = []
  29. def reset_cache(self):
  30. self.ops.clear()
  31. self.mutated.clear()
  32. self.aliasing.clear()
  33. def display_ops(self):
  34. print(*self.ops, sep=",")
  35. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  36. def has_mutated(before, after, md):
  37. are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
  38. if are_tensors and before.layout != torch.sparse_csr and after.layout != torch.sparse_csr:
  39. return not (
  40. before.size() == after.size() and
  41. torch.allclose(before, after, equal_nan=True) and
  42. md[0] == after.stride() and
  43. md[1] == after._typed_storage()._cdata
  44. )
  45. return False
  46. def has_aliased(lhs, rhs):
  47. try:
  48. return torch._C._overlaps(lhs, rhs)
  49. except Exception as exception:
  50. if str(exception).startswith("Cannot inspect value of type "):
  51. return False
  52. else:
  53. raise exception
  54. def standardize_name(name):
  55. return name if name != "self" else "input"
  56. def unwrap(e):
  57. if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
  58. try:
  59. return e.elem
  60. except AttributeError as t:
  61. return e
  62. return e
  63. def parse_metadata(e):
  64. if isinstance(e, torch.Tensor):
  65. if not type(e) == torch.Tensor:
  66. try:
  67. current = e.elem
  68. return (deepcopy(current.stride()), current._typed_storage()._cdata)
  69. except AttributeError as t:
  70. return None
  71. # Sparse CSR tensors do not have strides or storage
  72. elif (e.layout != torch.sparse_csr):
  73. return (deepcopy(e.stride()), e._typed_storage()._cdata)
  74. return None
  75. self.ops.append(func._schema.name)
  76. # Clone and process arguments and outputs
  77. pre_arguments = normalize_function(
  78. func,
  79. args,
  80. kwargs,
  81. normalize_to_only_use_kwargs=True
  82. ).kwargs
  83. c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
  84. cloned_arguments = {name : tree_map(unwrap, c_p_args.get(name)) for name in c_p_args}
  85. cloned_metadata = {name : tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0]) for name in pre_arguments}
  86. out = func(*args, **kwargs)
  87. arguments = {name : tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments}
  88. tuple_out = out if isinstance(out, tuple) else (out, )
  89. tuple_out = tree_map(unwrap, tuple_out)
  90. schema_info = SchemaInfo(func._schema)
  91. schema_info.add_argument_values(pre_arguments)
  92. # Process arguments with outputs
  93. for i in range(len(func._schema.arguments)):
  94. arg = func._schema.arguments[i]
  95. name = standardize_name(arg.name)
  96. if arguments.get(name) is not None:
  97. before = cloned_arguments.get(name)
  98. md = cloned_metadata.get(name)
  99. after = arguments.get(name)
  100. for j in range(len(tuple_out)):
  101. # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
  102. unsafe_ops = ('aten::_unsafe_view', 'aten::unsafe_split')
  103. if has_aliased(tuple_out[j], after) and func._schema.name not in unsafe_ops:
  104. if not schema_info.may_contain_alias(
  105. SchemaArgument(SchemaArgType.output, j),
  106. SchemaArgument(SchemaArgType.input, i)):
  107. raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing')
  108. else:
  109. self.aliasing.append(Aliasing(func._schema.name, name, f"output_{j}"))
  110. if any(has_mutated(a, b, c) for a, b, c in zip(tree_flatten(before)[0], tree_flatten(after)[0], md)):
  111. if not schema_info.is_mutable(SchemaArgument(SchemaArgType.input, i)):
  112. raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
  113. else:
  114. self.mutated.append(Mutation(func._schema.name, name))
  115. # Aliasing between outputs
  116. for i, j in combinations(range(len(func._schema.returns)), 2):
  117. if has_aliased(tuple_out[i], tuple_out[j]):
  118. if not schema_info.may_contain_alias(
  119. SchemaArgument(SchemaArgType.output, i),
  120. SchemaArgument(SchemaArgType.output, j)):
  121. raise RuntimeError(f'Outputs {i} and {j} alias unexpectedly')
  122. return out