selector.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. from dataclasses import dataclass
  2. from typing import Dict, List, Optional, Set, Tuple
  3. import yaml
  4. from torchgen.model import NativeFunction
  5. from torchgen.selective_build.operator import (
  6. merge_debug_info,
  7. merge_operator_dicts,
  8. SelectiveBuildOperator,
  9. strip_operator_overload_name,
  10. )
  11. # A SelectiveBuilder holds information extracted from the selective build
  12. # YAML specification.
  13. #
  14. # It includes information about the build's selectivity, the debug_info
  15. # associated with this selective build (opaque string), and the set of
  16. # operators that should be included in the build.
  17. #
  @dataclass(frozen=True)
  class SelectiveBuilder:
      """Immutable record of what a selective build includes.

      Holds the selected operators, kernel dtype fragments, custom classes and
      build features, plus the two "include everything" switches. Instances
      are normally created through the ``from_*`` factories below, which
      validate the input before constructing the object.
      """

      # If true, then the build is not selective, and includes all
      # operators.
      include_all_operators: bool

      # Debug Information at the selective/custom build level.
      _debug_info: Optional[Tuple[str, ...]]

      # A dictionary of operator -> operator metadata.
      operators: Dict[str, SelectiveBuildOperator]

      # A dictionary of selected kernel tags and dtypes. Typically a
      # PyTorch Operator Kernel (function) may have many code paths
      # that are specialized for many many Tensor dtypes, so it's not
      # one per kernel function, but there could be many per kernel
      # function. The tag isn't a kernel function name, but some fragment
      # of the kernel function implementation itself.
      kernel_metadata: Dict[str, List[str]]

      # A set of all the custom torch bind classes used by the selected models
      # Stored as a set internally to remove duplicates proactively, but written
      # as a list to yamls
      custom_classes: Set[str]

      # A set of all the build features used by the selected models
      # Stored as a set internally to remove duplicates proactively, but written
      # as a list to yamls
      build_features: Set[str]

      # If true, then fragments for all dtypes for all kernel functions
      # are included as well as all custom classes. This is typically set when any one of the
      # operator lists is generated from a mechanism other than
      # tracing based selective build.
      include_all_non_op_selectives: bool

      @staticmethod
      def get_nop_selector() -> "SelectiveBuilder":
          """Return a selector with ``include_all_operators`` set to True."""
          return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

      @staticmethod
      def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder":
          """Build a selector from an already-parsed YAML mapping.

          Args:
              data: Mapping whose keys are a subset of
                  ``include_all_non_op_selectives``, ``include_all_operators``,
                  ``debug_info``, ``operators``, ``kernel_metadata``,
                  ``custom_classes`` and ``build_features``. Missing keys fall
                  back to ``False`` / ``None`` / empty containers.

          Returns:
              The corresponding ``SelectiveBuilder``.

          Raises:
              Exception: if ``data`` contains any other top-level key.
              AssertionError: if a value has an unexpected container/bool type.
          """
          valid_top_level_keys = {
              "include_all_non_op_selectives",
              "include_all_operators",
              "debug_info",
              "operators",
              "kernel_metadata",
              "custom_classes",
              "build_features",
          }
          unexpected_keys = set(data.keys()) - valid_top_level_keys
          if unexpected_keys:
              # Stringify and sort the offending keys: set iteration order is
              # arbitrary, and a non-string key must not turn this report into
              # an unrelated TypeError from str.join.
              raise Exception(
                  "Got unexpected top level keys: {}".format(
                      ",".join(sorted(str(key) for key in unexpected_keys)),
                  )
              )

          include_all_operators = data.get("include_all_operators", False)
          assert isinstance(include_all_operators, bool)

          debug_info = None
          if "debug_info" in data:
              di_list = data["debug_info"]
              assert isinstance(di_list, list)
              debug_info = tuple(str(item) for item in di_list)

          operators_dict = data.get("operators", {})
          assert isinstance(operators_dict, dict)
          operators = {
              op_name: SelectiveBuildOperator.from_yaml_dict(op_name, op_info)
              for op_name, op_info in operators_dict.items()
          }

          kernel_metadata_dict = data.get("kernel_metadata", {})
          assert isinstance(kernel_metadata_dict, dict)
          kernel_metadata = {
              str(tag): [str(dtype) for dtype in dtypes]
              for tag, dtypes in kernel_metadata_dict.items()
          }

          # Sets drop duplicate entries; to_dict() writes them back as lists.
          custom_classes: Set[str] = set(
              data.get("custom_classes", [])  # type: ignore[arg-type]
          )
          build_features: Set[str] = set(
              data.get("build_features", [])  # type: ignore[arg-type]
          )

          include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
          assert isinstance(include_all_non_op_selectives, bool)

          return SelectiveBuilder(
              include_all_operators,
              debug_info,
              operators,
              kernel_metadata,
              custom_classes,
              build_features,
              include_all_non_op_selectives,
          )

      @staticmethod
      def from_yaml_str(config_contents: str) -> "SelectiveBuilder":
          """Parse ``config_contents`` with ``yaml.safe_load`` and build a selector."""
          contents = yaml.safe_load(config_contents)
          return SelectiveBuilder.from_yaml_dict(contents)

      @staticmethod
      def from_yaml_path(config_path: str) -> "SelectiveBuilder":
          """Read the YAML file at ``config_path`` and build a selector from it."""
          with open(config_path, "r") as f:
              contents = yaml.safe_load(f)
              return SelectiveBuilder.from_yaml_dict(contents)

      @staticmethod
      def from_legacy_op_registration_allow_list(
          allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool
      ) -> "SelectiveBuilder":
          """Build a selector from a plain set of operator names.

          Every name in ``allow_list`` becomes an operator entry carrying the
          given ``is_root_operator`` / ``is_used_for_training`` flags and
          ``include_all_overloads=True``. The resulting selector also has
          ``include_all_non_op_selectives`` set to True.
          """
          operators = {
              op: {
                  "name": op,
                  "is_root_operator": is_root_operator,
                  "is_used_for_training": is_used_for_training,
                  "include_all_overloads": True,
              }
              for op in allow_list
          }
          return SelectiveBuilder.from_yaml_dict(
              {
                  "operators": operators,
                  "include_all_non_op_selectives": True,
              }
          )

      def is_operator_selected(self, name: str) -> bool:
          """Return whether the operator ``name`` is part of the build.

          True when all operators are included, when ``name`` itself has an
          entry, or when the name with its overload stripped has an entry
          whose ``include_all_overloads`` flag is set.
          """
          if self.include_all_operators:
              return True

          if name in self.operators:
              return True
          name = strip_operator_overload_name(name)
          return name in self.operators and self.operators[name].include_all_overloads

      def is_native_function_selected(self, func: NativeFunction) -> bool:
          """Return ``is_operator_selected`` for the operator name of ``func``."""
          op_name = op_name_from_native_function(func)
          return self.is_operator_selected(op_name)

      def is_operator_selected_for_training(self, name: str) -> bool:
          """Return whether ``name`` is selected and marked as used for training.

          An unselected operator is never a training operator; when all
          operators are included every operator is. Otherwise the answer comes
          from the exact entry for ``name`` or, failing that, from the
          overload-stripped entry provided it includes all overloads.
          """
          if not self.is_operator_selected(name):
              return False
          if self.include_all_operators:
              return True

          # A missing entry simply contributes "not used for training".
          op = self.operators.get(name)
          if op is not None and op.is_used_for_training:
              return True

          base_op = self.operators.get(strip_operator_overload_name(name))
          return (
              base_op is not None
              and base_op.include_all_overloads
              and base_op.is_used_for_training
          )

      def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
          """Return ``is_operator_selected_for_training`` for ``func``'s operator name."""
          op_name = op_name_from_native_function(func)
          return self.is_operator_selected_for_training(op_name)

      def is_root_operator(self, name: str) -> bool:
          """Return whether ``name`` is selected and marked as a root operator.

          The exact entry for ``name`` wins when present; otherwise the
          overload-stripped entry is consulted and only counts if it includes
          all overloads.
          """
          if not self.is_operator_selected(name):
              return False
          if self.include_all_operators:
              return True

          if name in self.operators:
              op: SelectiveBuildOperator = self.operators[name]
              return op.is_root_operator
          name = strip_operator_overload_name(name)
          if name not in self.operators:
              return False
          base_op: SelectiveBuildOperator = self.operators[name]
          return base_op.include_all_overloads and base_op.is_root_operator

      def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
          """Return whether ``dtype`` is selected for the kernel fragment ``kernel_tag``.

          Always True when either ``include_all_operators`` or
          ``include_all_non_op_selectives`` is set.
          """
          if self.include_all_operators or self.include_all_non_op_selectives:
              return True

          return (
              kernel_tag in self.kernel_metadata
              and dtype in self.kernel_metadata[kernel_tag]
          )

      def to_dict(self) -> Dict[str, object]:
          """Return a plain-dict form of this selector.

          Debug info, dtype lists, custom classes and build features are
          emitted as sorted lists; ``debug_info`` is omitted when unset.
          """
          ret: Dict[str, object] = {
              "include_all_non_op_selectives": self.include_all_non_op_selectives,
              "include_all_operators": self.include_all_operators,
          }
          ret["operators"] = {
              op_name: op.to_dict() for op_name, op in self.operators.items()
          }

          if self._debug_info is not None:
              ret["debug_info"] = sorted(self._debug_info)

          ret["kernel_metadata"] = {
              tag: sorted(dtypes) for tag, dtypes in self.kernel_metadata.items()
          }

          ret["custom_classes"] = sorted(self.custom_classes)
          ret["build_features"] = sorted(self.build_features)

          return ret
  def merge_kernel_metadata(
      lhs: Dict[str, List[str]],
      rhs: Dict[str, List[str]],
  ) -> Dict[str, List[str]]:
      """Union two kernel-tag -> dtype-list mappings.

      Args:
          lhs: First mapping of kernel tag to dtype names.
          rhs: Second mapping of kernel tag to dtype names.

      Returns:
          A new mapping containing every tag from either input. Each value is
          the de-duplicated union of that tag's dtypes, returned as a sorted
          list so the result does not depend on set iteration order. Neither
          input is modified.
      """
      merged: Dict[str, Set[str]] = {}
      for source in (lhs, rhs):
          for tag_name, dtypes in source.items():
              merged.setdefault(tag_name, set()).update(dtypes)

      # Sorting gives a canonical value: two merges with equal content always
      # produce lists that compare equal.
      return {tag_name: sorted(dtypes) for tag_name, dtypes in merged.items()}
  def combine_selective_builders(
      lhs: SelectiveBuilder, rhs: SelectiveBuilder
  ) -> SelectiveBuilder:
      """Merge two selectors into a new one.

      The boolean "include all" switches are OR-ed, the custom class and
      build feature sets are unioned, and debug info, operators and kernel
      metadata are combined by their respective merge helpers.
      """
      return SelectiveBuilder(
          include_all_operators=(
              lhs.include_all_operators or rhs.include_all_operators
          ),
          _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
          operators=merge_operator_dicts(lhs.operators, rhs.operators),
          kernel_metadata=merge_kernel_metadata(
              lhs.kernel_metadata, rhs.kernel_metadata
          ),
          include_all_non_op_selectives=(
              lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
          ),
          custom_classes=lhs.custom_classes.union(rhs.custom_classes),
          build_features=lhs.build_features.union(rhs.build_features),
      )
  def op_name_from_native_function(f: NativeFunction) -> str:
      """Return the ``<namespace>::<name>`` operator name for ``f``."""
      # This was originally read from the 'operator_name_with_overload' field in the
      # declaration dict, which was the part before the first '(' in 'schema_string'.
      namespace = f.namespace
      operator_name = f.func.name
      return "{}::{}".format(namespace, operator_name)