# gen_vmap_plumbing.py
  1. import textwrap
  2. from dataclasses import dataclass
  3. from typing import List, Optional, Sequence, Tuple
  4. from torchgen.api.translate import translate
  5. from torchgen.api.types import DispatcherSignature
  6. from torchgen.context import method_with_native_function
  7. from torchgen.model import (
  8. Argument,
  9. BaseTy,
  10. BaseType,
  11. FunctionSchema,
  12. ListType,
  13. NativeFunction,
  14. OptionalType,
  15. Return,
  16. SchemaKind,
  17. Type,
  18. )
  19. from torchgen.utils import mapMaybe
  20. def is_tensor(typ: Type) -> bool:
  21. return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
  22. def is_optional_tensor(typ: Type) -> bool:
  23. return isinstance(typ, OptionalType) and is_tensor(typ.elem)
  24. def is_tensor_list(typ: Type) -> bool:
  25. return isinstance(typ, ListType) and is_tensor(typ.elem)
  26. def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
  27. result = f"""\
  28. Tensor {name}_value;
  29. optional<int64_t> {name}_bdim;
  30. std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}, {cur_level_var});"""
  31. return textwrap.dedent(result).split("\n")
  32. def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
  33. result = f"""\
  34. optional<Tensor> {name}_value;
  35. optional<int64_t> {name}_bdim;
  36. if ({name}) {{
  37. std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
  38. }}"""
  39. return textwrap.dedent(result).split("\n")
  40. def gen_unwraps(
  41. flat_arguments: Sequence[Argument], cur_level_var: str
  42. ) -> Tuple[str, List[str]]:
  43. arg_names = [a.name for a in flat_arguments]
  44. arg_types = [a.type for a in flat_arguments]
  45. tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
  46. optional_tensors = [
  47. name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
  48. ]
  49. unwraps = []
  50. for tensor in tensors:
  51. unwraps += unwrap_tensor(tensor, cur_level_var)
  52. for opt_tensor in optional_tensors:
  53. unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
  54. unwrap_code = "\n".join(unwraps)
  55. unwrapped_arg_list = []
  56. for arg in arg_names:
  57. if arg in tensors or arg in optional_tensors:
  58. unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
  59. else:
  60. unwrapped_arg_list.append(arg)
  61. return unwrap_code, unwrapped_arg_list
  62. def gen_case_where_all_bdims_are_none(
  63. outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
  64. ) -> str:
  65. conditions = []
  66. flat_args = schema.arguments.flat_all
  67. for arg in flat_args:
  68. if not arg.type.is_tensor_like():
  69. continue
  70. conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})")
  71. sig = DispatcherSignature.from_schema(schema)
  72. translated_args = ", ".join(
  73. e.expr for e in translate(outer_sig.arguments(), sig.arguments())
  74. )
  75. return f"""\
  76. if ({' && '.join(conditions)}) {{
  77. return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
  78. }}"""
  79. def gen_returns(
  80. returns: Tuple[Return, ...], cur_level_var: str, results_var: str
  81. ) -> str:
  82. idx = 0
  83. wrapped_returns = []
  84. for ret in returns:
  85. if is_tensor(ret.type):
  86. wrapped_returns.append(
  87. f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
  88. )
  89. idx += 2
  90. elif is_tensor_list(ret.type):
  91. wrapped_returns.append(
  92. f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})"
  93. )
  94. idx += 2
  95. else:
  96. wrapped_returns.append(f"std::get<{idx}>({results_var})")
  97. idx += 1
  98. if len(wrapped_returns) == 1:
  99. result = f"return {wrapped_returns[0]};"
  100. else:
  101. result = f'return std::make_tuple({", ".join(wrapped_returns)});'
  102. return result
  103. def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
  104. return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
  105. def is_mutated_arg(argument: Argument) -> bool:
  106. return argument.annotation is not None and argument.annotation.is_write
  107. def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
  108. # Assumptions:
  109. # - only one argument is being modified in-place
  110. # - the argument that is being modified in-place is the first argument
  111. # - all returns are either Tensor, tuple of Tensor, or TensorList
  112. schema = native_function.func
  113. sig = DispatcherSignature.from_schema(schema)
  114. returns = schema.returns
  115. # Check assumptions. If these are invalid we return None
  116. # and punt the work to handle them to the future.
  117. assert schema.kind() == SchemaKind.inplace
  118. if not is_mutated_arg(schema.arguments.flat_all[0]):
  119. return None
  120. if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
  121. return None
  122. # Only support cases where all returns are Tensors or vector<Tensor>
  123. if len(returns) == 0:
  124. return None
  125. if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
  126. return None
  127. if not accepts_at_least_one_tensor_input(schema):
  128. return None
  129. cur_level_var = "cur_level"
  130. unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
  131. bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
  132. return f"""\
  133. template <typename batch_rule_t, batch_rule_t batch_rule>
  134. {sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  135. c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  136. auto maybe_layer = maybeCurrentDynamicLayer();
  137. vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  138. int64_t {cur_level_var} = maybe_layer->layerId();
  139. {textwrap.indent(bdims_all_none_case, " ")}
  140. {textwrap.indent(unwraps, " ")}
  141. batch_rule({', '.join(unwrapped_arg_list)});
  142. return {schema.arguments.flat_all[0].name};
  143. }}"""
  144. def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
  145. schema = native_function.func
  146. sig = DispatcherSignature.from_schema(schema)
  147. cur_level_var = "cur_level"
  148. unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
  149. bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
  150. return f"""\
  151. template <typename batch_rule_t, batch_rule_t batch_rule>
  152. {sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  153. c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  154. auto maybe_layer = maybeCurrentDynamicLayer();
  155. vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  156. int64_t {cur_level_var} = maybe_layer->layerId();
  157. {textwrap.indent(bdims_all_none_case, " ")}
  158. {textwrap.indent(unwraps, " ")}
  159. batch_rule({', '.join(unwrapped_arg_list)});
  160. }}"""
  161. def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
  162. schema = native_function.func
  163. sig = DispatcherSignature.from_schema(schema)
  164. returns = schema.returns
  165. # Only support cases where all returns are Tensors or vector<Tensor>
  166. if not accepts_at_least_one_tensor_input(schema):
  167. return None
  168. if len(returns) == 0:
  169. return gen_vmap_plumbing_no_returns(native_function)
  170. if not all(ret.type.is_tensor_like() for ret in returns):
  171. return None
  172. # in-place views need special handling
  173. if "inplace_view" in native_function.tags:
  174. return None
  175. if schema.kind() == SchemaKind.inplace:
  176. return gen_vmap_inplace_plumbing(native_function)
  177. # Don't support these (mutable, out, scratch)
  178. if schema.kind() != SchemaKind.functional:
  179. return None
  180. results_var = "results"
  181. cur_level_var = "cur_level"
  182. unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
  183. bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
  184. wrapped_returns = gen_returns(returns, cur_level_var, results_var)
  185. return f"""\
  186. template <typename batch_rule_t, batch_rule_t batch_rule>
  187. {sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  188. c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  189. auto maybe_layer = maybeCurrentDynamicLayer();
  190. vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  191. int64_t {cur_level_var} = maybe_layer->layerId();
  192. {textwrap.indent(bdims_all_none_case, " ")}
  193. {textwrap.indent(unwraps, " ")}
  194. auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
  195. {wrapped_returns}
  196. }}"""
  197. @dataclass(frozen=True)
  198. class ComputeBatchRulePlumbing:
  199. @method_with_native_function
  200. def __call__(self, f: NativeFunction) -> Optional[str]:
  201. opname = str(f.func.name)
  202. result = gen_vmap_plumbing(f)
  203. return result
  204. def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
  205. body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
  206. return f"""
  207. #pragma once
  208. #include <ATen/Operators.h>
  209. #include <ATen/functorch/PlumbingHelper.h>
  210. namespace at {{ namespace functorch {{
  211. {body}
  212. }}}} // namespace at::functorch
  213. """