# common_jit.py

# Torch
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized

# Testing utils
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_utils import TestCase, \
    freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401

# Standard library
from itertools import chain
from typing import List, Union

from torch._C import TensorType
import io

def check_output_types(self, func, ref_outputs, args, kwargs):
    graph = getattr(func, 'last_graph', None)
    types = [o.type() for o in graph.outputs()]
    self.assertTrue(len(types) == 1)
    t = types[0]
    torch._C._jit_assert_is_instance(ref_outputs, t)

# Test names in this set are only checked for a single derivative
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
    'pdist',
    'multilabel_margin_loss',
    'max_unpool3d',
    'multi_margin_loss',
    'binary_cross_entropy',
    'binary_cross_entropy_size_average',
    'ctc_loss',
    'grid_sample',
])

def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
                            allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
    """Verifies a function performs identically to some reference implementation.

    Commonly, this is used to verify that a JIT implementation
    (output_func) matches the behavior of the eager implementation
    (reference_func). See the example sketch after this function.
    """
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_tensor(t, preserve_requires_grad):
        require_grad = preserve_requires_grad and t.requires_grad
        return t.detach().clone().requires_grad_(require_grad)

    def clone_inputs(preserve_requires_grad: bool):
        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                inputs.append(clone_tensor(arg, preserve_requires_grad))
            elif is_iterable_of_tensors(arg):
                inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
            else:
                inputs.append(arg)

        return inputs

    # Returns tensors in args that requires_grad, including tensors in TensorList args
    def get_recording_tensors(args):
        recording_tensors: List[torch.Tensor] = []

        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                recording_tensors.append(arg)
            elif is_iterable_of_tensors(arg):
                recording_tensors.extend(filter(lambda t: t.requires_grad, arg))

        return recording_tensors

    # test no gradients case
    nograd_inputs = clone_inputs(preserve_requires_grad=False)
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)

        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
            return

        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)

        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)

        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            if g2 is None and g2_test is None:
                continue
            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)

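# Example (a minimal sketch, not part of this module): a test method on a
# JitCommonTestCase subclass (defined below) could call check_against_reference
# roughly like this. The function and tensor shapes here are hypothetical, and
# check_types is disabled because a plain torch.jit.script function does not
# carry the `last_graph` attribute that check_output_types inspects.
#
#     def test_scripted_matches_eager(self):
#         def eager_fn(x):
#             return torch.nn.functional.relu(x) * 2
#
#         scripted_fn = torch.jit.script(eager_fn)
#         args = (torch.randn(3, 4, requires_grad=True),)
#         check_against_reference(self, scripted_fn, eager_fn,
#                                 lambda outputs: outputs, args, check_types=False)
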

class JitCommonTestCase(TestCase):
    def createFunctionFromGraph(self, trace):
        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
        return torch._C._create_function_from_graph("forward", graph)

    def assertExportImport(self, trace, inputs):
        m = self.createFunctionFromGraph(trace)
        self.assertExportImportModule(m, inputs)

    def assertExportImportModule(self, m, inputs):
        m_import = self.getExportImportCopy(m)
        a = self.runAndSaveRNG(m, inputs)
        b = self.runAndSaveRNG(m_import, inputs)
        self.assertEqual(a, b, "Results of original model and "
                               "exported/imported version of model differed")

    def runAndSaveRNG(self, func, inputs, kwargs=None):
        kwargs = kwargs if kwargs else {}
        with freeze_rng_state():
            results = func(*inputs, **kwargs)
        return results

    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)

        if not also_test_file:
            return imported

        with TemporaryFileName() as fname:
            torch.jit.save(imported, fname)
            return torch.jit.load(fname, map_location=map_location)

    def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
                             fusion_nodes_not_found, non_fusible_nodes_being_fused,
                             fusion_nodes_found, nodes_in_diff_graph):
        err_msg = "\nFailure in testing nodes' autodifferentiation. "
        if should_autodiff_node:
            err_msg += "One or more nodes were expected to be autodiffed, " \
                       "but were not found in specified fusible/nonfusible " \
                       "DifferentiableGraph groups. \nSpecifically:"
            # The node is intended to appear in a differentiable graph but doesn't
            diff_nodes_missing = []
            # The node is intended to appear in a differentiable graph
            # outside of a fusion group but instead is in a fusion group
            diff_nodes_in_fusion = []
            # The node is intended to appear in a fusion group but doesn't
            fusion_nodes_missing = []
            # The node is intended to appear in a fusion group but instead
            # is just in an outer differentiable graph
            fusion_nodes_in_diff = []

            for node in nodes_not_in_diff_graph:
                if node in non_fusible_nodes_being_fused:
                    diff_nodes_in_fusion.append(node)
                else:
                    diff_nodes_missing.append(node)

            for node in fusion_nodes_not_found:
                if node in nodes_in_diff_graph:
                    fusion_nodes_in_diff.append(node)
                else:
                    fusion_nodes_missing.append(node)

            if len(diff_nodes_missing) > 0:
                err_msg += f"\n {diff_nodes_missing} were not in one of the " \
                           "DifferentiableGraphs when they were expected to be. " \
                           "Did you intend for these nodes to be autodiffed? " \
                           "If not, remove them from the list of nonfusible nodes."

            if len(diff_nodes_in_fusion) > 0:
                err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \
                           "when they were expected to be just in a DifferentiableGraph. If it was " \
                           "intended for these nodes to be in FusionGroups, reclassify these nodes as " \
                           "fusible nodes. If these nodes were not intended to be fused, your " \
                           "autodifferentiation logic might be wrong."

            if len(fusion_nodes_missing) > 0:
                err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \
                           "of the DifferentiableGraphs when they were expected to be. " \
                           "They were also not found in an outer DifferentiableGraph. Did you " \
                           "intend for these nodes to be autodifferentiated? If not, you should " \
                           "remove these nodes from the test's fusible nodes. Otherwise your " \
                           "autodifferentiation logic might be wrong."

            if len(fusion_nodes_in_diff) > 0:
                err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \
                           "of the DifferentiableGraphs when they were expected to be, " \
                           "instead they were found just in an outer DifferentiableGraph. " \
                           "Did you intend for these nodes to be fused? If not, you should " \
                           "move these nodes into the test's nonfusible nodes. Otherwise your " \
                           "autodifferentiation logic might be wrong."
        else:
            err_msg += "One or more nodes were not expected to be autodiffed " \
                       "but were found in a DifferentiableGraph or in a FusionGroup " \
                       "of a DifferentiableGraph. Did you intend for these nodes to be " \
                       "autodiffed? If so, change this test to expect autodifferentiation. " \
                       "\nSpecifically:"
            if len(fusion_nodes_found) > 0:
                err_msg += f"\n {fusion_nodes_found} were not expected to be in " \
                           "one of the DifferentiableGraphs, but appeared in a FusionGroup " \
                           "of a DifferentiableGraph. "
            if len(nodes_in_diff_graph) > 0:
                err_msg += f"\n {nodes_in_diff_graph} were not expected to " \
                           "be in one of the DifferentiableGraphs but were."
        return err_msg

    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]

        # Note: currently no tests have fusible_nodes
        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]

        # For any non-fusible node, it must show up in one of the DifferentiableGraphs.
        nodes_in_diff_graph = []
        nodes_not_in_diff_graph = []
        non_fusible_nodes_being_fused = []
        for node in nonfusible_nodes:
            if any(g.findNode(node) is not None for g in diff_subgraphs):
                nodes_in_diff_graph.append(node)
            else:
                nodes_not_in_diff_graph.append(node)
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                non_fusible_nodes_being_fused.append(node)
        found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)

        # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
        fusion_nodes_found = []
        fusion_nodes_not_found = []
        for node in fusible_nodes:
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                fusion_nodes_found.append(node)
            else:
                fusion_nodes_not_found.append(node)
        found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)

        if should_autodiff_node is not None:
            err_msg = self.autoDiffErrorMessage(should_autodiff_node,
                                                nodes_not_in_diff_graph,
                                                fusion_nodes_not_found,
                                                non_fusible_nodes_being_fused,
                                                fusion_nodes_found,
                                                nodes_in_diff_graph)
            self.assertEqual(should_autodiff_node,
                             found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)

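    # Example (a minimal sketch, hypothetical op names): a test could script a
    # function, warm it up under the profiling executor so DifferentiableGraphs
    # are created, and then assert which nodes were autodiffed:
    #
    #     def test_add_is_autodiffed(self):
    #         def fn(x, y):
    #             return x + y
    #
    #         scripted = torch.jit.script(fn)
    #         x = torch.randn(4, requires_grad=True)
    #         y = torch.randn(4, requires_grad=True)
    #         with enable_profiling_mode_for_profiling_tests():
    #             for _ in range(3):
    #                 scripted(x, y).sum().backward()
    #         self.assertAutodiffNode(scripted.graph_for(x, y), True, ['aten::add'], [])
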
    def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
                           traced_graph, assert_propagation, constant_prop=True):
        # repropagate input shapes provided by tracing
        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
        for enable_test_mode in [True, False]:
            # here we are testing allowing/disallowing substituting in complete shapes as constants;
            # disallowing constants helps stress-test the partial eval and substitution pipeline
            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
            torch._C._jit_erase_non_input_shape_information(traced_graph)
            if constant_prop:
                torch._C._jit_pass_constant_propagation(traced_graph)
            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)

            # Add sizes to default tensor type to avoid checking something out of scope
            # and difficulties with tracer leaving in other parts of tensor type
            output = next(traced_graph.outputs()).type()

            def test_type(type, actual_size):
                sizes = type.symbolic_sizes()
                out_type = TensorType.get().with_sizes(sizes)
                actual_type = TensorType.get().with_sizes(actual_size)

                # always check actual shape is a subtype of the output
                self.assertTrue(actual_type.isSubtypeOf(out_type))

                # and then if assertion flag is provided, check shape analysis
                # is successful
                if assert_propagation:
                    self.assertEqual(out_type.sizes(), actual_size)

            if output.isSubtypeOf(torch._C.TensorType.get()):
                test_type(output, out_sizes)
            else:
                tuple_elements = output.elements()
                for i in range(len(tuple_elements)):
                    test_type(tuple_elements[i], out_sizes[i])

        torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)
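
    # Example (a minimal sketch, hypothetical shapes): a test could trace a function
    # and then check that symbolic shape analysis reproduces the observed output size:
    #
    #     def test_mm_shape_analysis(self):
    #         def fn(x, y):
    #             return torch.mm(x, y)
    #
    #         traced = torch.jit.trace(fn, (torch.randn(2, 3), torch.randn(3, 4)))
    #         self.checkShapeAnalysis([2, 4], traced.graph, assert_propagation=True)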