functionalization.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. from typing import List, Optional
  2. from torchgen.api import dispatcher
  3. from torchgen.api.types import (
  4. BaseCType,
  5. Binding,
  6. boolT,
  7. ConstRefCType,
  8. CType,
  9. longT,
  10. NamedCType,
  11. tensorT,
  12. )
  13. from torchgen.model import (
  14. Argument,
  15. BaseTy,
  16. BaseType,
  17. FunctionSchema,
  18. NativeFunctionsViewGroup,
  19. )
  # This file describes the translation of JIT schema to the APIs used
  # when creating view lambdas that are used by the functionalization pass.
  # There are two types of lambdas: forward lambdas and reverse lambdas.
  # These APIs mostly follow the dispatcher API, with a few quirks:
  # - The lambda capture has to convert reference types to value types.
  # - While the forward lambda just directly calls into the at::_ops API
  #   (following the dispatcher convention), the logic here for the reverse lambda
  #   is responsible for generating both the call-site and the declarations
  #   (which are implemented manually in the at::functionalization::impl namespace).
  #
  # The lambdas generated for each view op in the functionalization pass are of the form:
  #
  #   [capture_arguments](outer_arguments) -> returns_type {
  #       return name(inner_arguments);
  #   }
  # Define some specific lambda input arguments.

  # `base`: the lambda argument that stands in for the view op's original
  # tensor argument (see inner_arguments). Passed as `const Tensor&`.
  base_binding = Binding(
      name="base",
      nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
      argument=Argument(
          name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
      ),
      default=None,
  )
  # `mutated_view`: the extra tensor argument taken only by the reverse lambda
  # (see outer_arguments / inner_arguments). Passed as `const Tensor&`.
  # NOTE(review): the schema `argument` below is still named "base"; this looks
  # like a copy-paste placeholder. Confirm nothing reads `Binding.argument`
  # for this binding before renaming it.
  mutated_view_binding = Binding(
      name="mutated_view",
      nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
      argument=Argument(
          name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
      ),
      default=None,
  )
  # `mutated_view_idx`: integer index (C++ type `longT`) used to pick one output
  # of a view op that returns several tensors (see inner_call_index).
  # NOTE(review): the schema `argument` below is named "base" and typed as a
  # Tensor even though the C++ type is an integer; this looks like a copy-paste
  # placeholder. Confirm nothing reads `Binding.argument` for this binding
  # before changing it.
  mutated_view_idx_binding = Binding(
      name="mutated_view_idx",
      nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
      argument=Argument(
          name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
      ),
      default=None,
  )
  # `reapply_views`: boolean flag that is captured by every lambda (see
  # capture_arguments) and forwarded to the reverse (inverse) function
  # (see inner_arguments).
  reapply_views_binding = Binding(
      name="reapply_views",
      nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
      argument=Argument(
          name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
      ),
      default=None,
  )
  # The lambda capture itself doesn't have a name.
  # The name returned here corresponds to the name of the inner function called by the lambda.
  def name(
      g: NativeFunctionsViewGroup,
      *,
      is_reverse: bool,
      include_namespace: bool,
      reapply_views: Optional[bool] = None,
  ) -> str:
      """Return the C++ name of the function called inside a view lambda.

      Args:
          g: View group; ``g.view`` and ``g.view_copy`` supply the operator names.
          is_reverse: True for the reverse (inverse) lambda, False for the forward one.
          include_namespace: Whether to return the namespace-qualified name.
              Must be True for the forward lambda.
          reapply_views: Forward lambda only: call ``g.view`` when truthy,
              otherwise ``g.view_copy``. May be left as None only when
              ``is_reverse`` is True.

      Returns:
          ``<view_copy>_inverse`` (optionally prefixed with
          ``at::functionalization::FunctionalInverses::``) for the reverse
          lambda, or ``at::_ops::<op>::call`` for the forward lambda.

      Raises:
          AssertionError: If ``reapply_views`` is None for a forward lambda,
              if ``include_namespace`` is False for a forward lambda, or if
              ``g.view_copy`` is None.
      """
      if reapply_views is None:
          # reapply_views is only important for the fwd lambda,
          # since we always plumb the runtime "reapply_views" argument into the reverse function.
          assert is_reverse, "reapply_views must be specified for the forward lambda"

      if is_reverse:
          # For the reverse: the name of the inverse function always involves "view_copy",
          # and we plumb the "reapply_views" flag into that function.
          # (We could avoid doing that, but that would require writing out
          # twice as many view inverse functions.)
          assert g.view_copy is not None, "view group has no view_copy operator"
          api_name = g.view_copy.func.name.unambiguous_name()
          # In the reverse case, we codegen both the call-sites (which need the
          # full namespace) and the declarations (which don't).
          if include_namespace:
              return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
          return f"{api_name}_inverse"

      # In the forward case, we just directly call into the at::_ops API
      # (so we always need the namespace).
      assert include_namespace, "the forward lambda name is always namespace-qualified"
      assert g.view_copy is not None, "view group has no view_copy operator"
      api_name = (
          g.view.func.name.unambiguous_name()
          if reapply_views
          else g.view_copy.func.name.unambiguous_name()
      )
      return f"at::_ops::{api_name}::call"
  def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
      """Return the bindings captured by the view lambda.

      The result is ``reapply_views_binding`` followed by one binding for
      every argument of ``func`` except the first (``self``).

      Args:
          func: Schema of the view op; its first flat argument must be a Tensor.
          is_reverse: Accepted for symmetry with the other helpers; not read
              by this function.

      Raises:
          AssertionError: If the first argument of ``func`` is not a Tensor.
      """
      # Capture arguments include all arguments except `self`.
      # Importantly, they don't include any C++ reference types (or else we'll
      # get a dangling reference in the capture), so any reference types
      # (IntArrayRef) need to be converted to value types (vector<int64_t>).
      args = func.arguments.flat_all
      assert args[0].type == BaseType(BaseTy.Tensor)
      non_self_value_bindings = [
          dispatcher.argument(a, remove_non_owning_ref_types=True) for a in args[1:]
      ]
      return [reapply_views_binding] + non_self_value_bindings
  def returns_type(func: FunctionSchema) -> CType:
      """Return the C++ return type of the view lambda (always a single tensor).

      Args:
          func: Schema of the view op.

      Raises:
          AssertionError: If ``func`` has no returns, or any return type is
              not tensor-like.
      """
      # Assertion: all view ops return tensor-like outputs.
      assert len(func.returns) >= 1
      for ret in func.returns:
          assert ret.type.is_tensor_like()
      # However, the return type of the lambda is always an individual tensor.
      # For multi-tensor outputs, each tensor needs to be tracked individually.
      return BaseCType(tensorT)
  def outer_arguments(*, is_reverse: bool) -> List[Binding]:
      """Return the bindings that form the lambda's own parameter list.

      Args:
          is_reverse: True for the reverse lambda, which additionally takes
              ``mutated_view`` between ``base`` and ``mutated_view_idx``.
      """
      if is_reverse:
          return [base_binding, mutated_view_binding, mutated_view_idx_binding]
      return [base_binding, mutated_view_idx_binding]
  def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
      """Return the index binding needed to select one output of ``func``, if any.

      Args:
          func: Schema of the view op.

      Returns:
          ``mutated_view_idx_binding`` when ``func`` has more than one return,
          or exactly one return whose type is list-like; otherwise None.
      """
      # For view ops that return multiple tensors (like `split`), we generate a
      # separate lambda for each output. When we replay a view op that returns
      # multiple tensors, we need to index into the output appropriately.
      returns = func.returns
      if len(returns) > 1 or (len(returns) == 1 and returns[0].type.is_list_like()):
          return mutated_view_idx_binding
      return None
  def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
      """Return the bindings passed to the function called inside the lambda.

      Args:
          func: Schema of the view op; its first flat argument must be a Tensor.
          is_reverse: False for the forward lambda, True for the reverse lambda.

      Returns:
          Forward: ``[base]`` followed by the op's non-self arguments.
          Reverse: ``[base, mutated_view, reapply_views]``, then the index
          binding when ``inner_call_index(func)`` returns one, then the op's
          non-self arguments.

      Raises:
          AssertionError: If the first argument of ``func`` is not a Tensor.
      """
      args = func.arguments.flat_all
      assert args[0].type == BaseType(BaseTy.Tensor)
      # The forward lambda calls the at::_ops API, while the reverse lambda
      # calls the view inverse API. Both of these follow the dispatcher API.
      non_self_bindings = [dispatcher.argument(a) for a in args[1:]]

      if not is_reverse:
          # The forward lambda swaps out the original tensor argument with the
          # lambda arg "base".
          return [base_binding] + non_self_bindings

      # The reverse lambda does the same, but with an additional "mutated_view" arg.
      leading = [base_binding, mutated_view_binding, reapply_views_binding]
      # Calling convention: for view ops that return multiple tensor outputs,
      # the corresponding view_inverse function takes an additional index argument.
      index_binding = inner_call_index(func)
      if index_binding is not None:
          leading.append(index_binding)
      return leading + non_self_bindings