utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. import contextlib
  2. import functools
  3. import hashlib
  4. import os
  5. import re
  6. import sys
  7. import textwrap
  8. from argparse import Namespace
  9. from dataclasses import fields, is_dataclass
  10. from enum import auto, Enum
  11. from typing import (
  12. Any,
  13. Callable,
  14. Dict,
  15. Generic,
  16. Iterable,
  17. Iterator,
  18. List,
  19. NoReturn,
  20. Optional,
  21. Sequence,
  22. Set,
  23. Tuple,
  24. TypeVar,
  25. Union,
  26. )
  27. from typing_extensions import Literal # Python 3.8+
  28. from torchgen.code_template import CodeTemplate
  29. # Safely load fast C Yaml loader/dumper if they are available
  30. try:
  31. from yaml import CSafeLoader as Loader
  32. except ImportError:
  33. from yaml import SafeLoader as Loader # type: ignore[misc]
  34. try:
  35. from yaml import CSafeDumper as Dumper
  36. except ImportError:
  37. from yaml import SafeDumper as Dumper # type: ignore[misc]
  38. YamlDumper = Dumper
  # A custom loader for YAML that errors on duplicate keys.
  # This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
  class YamlLoader(Loader):
      """YAML loader that rejects any mapping containing the same key twice."""

      def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
          """Construct the mapping for ``node``, failing on a repeated key.

          Keys seen so far are kept in a set, so the check stays linear in the
          number of keys. Unhashable keys are not tracked here; they are left to
          ``super().construct_mapping()`` (PyYAML's base constructor rejects
          unhashable keys with its own ConstructorError — confirm if the base
          Loader changes).

          Raises:
              AssertionError: if a key occurs more than once in the mapping.
          """
          seen: Set[Any] = set()
          for key_node, _value_node in node.value:
              key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
              try:
                  is_duplicate = key in seen
              except TypeError:
                  # Unhashable key: cannot live in a set; defer to the base class.
                  continue
              if is_duplicate:
                  # Raised explicitly (not via `assert`) so the check survives
                  # `python -O`. YAML marks are 0-based, hence the +1, and we
                  # report the duplicate key's own line, not the mapping's.
                  raise AssertionError(
                      f"Found a duplicate key in the yaml. key={key}, "
                      f"line={key_node.start_mark.line + 1}"
                  )
              seen.add(key)
          return super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
  # Many codegen functions share logic between a definition and a declaration
  # (e.g. the function signature is identical), so a single function takes a
  # Target saying which flavour of code to produce.
  class Target(Enum):
      """Which kind of generated C++ code a codegen function should emit.

      This is an OPEN enum: more members may be added in the future, so
      callers should state exactly which targets they accept with
      ``Literal[Target.XXX]`` or ``Literal[Target.XXX, Target.YYY]``.
      """

      # Code placed at top-level namespace (not including at).
      DEFINITION = auto()
      DECLARATION = auto()
      # Code placed inside TORCH_LIBRARY(...) { ... }.
      REGISTRATION = auto()
      # Code placed inside an anonymous namespace { ... }.
      ANONYMOUS_DEFINITION = auto()
      # Code placed inside a named namespace, e.g. namespace cpu { ... }.
      NAMESPACED_DEFINITION = auto()
      NAMESPACED_DECLARATION = auto()
  # Whole-identifier pattern template: once formatted with a name via
  # IDENT_REGEX.format(name), it matches "foo" in "foo, bar" but not in
  # "foobar". Used to look for a parameter inside a derivative formula.
  IDENT_REGEX = r"(^|\W){}($|\W)"
  # TODO: Use a real parser here; this will get bamboozled
  def split_name_params(schema: str) -> Tuple[str, List[str]]:
      """Split ``"name.overload(p1, p2)"`` into ``("name", ["p1", "p2"])``.

      Any ``.overload`` suffix is dropped. Parameters are split naively on
      ``", "``, so separators nested inside a parameter are not respected, and
      an empty parameter list yields ``[""]``.

      Raises:
          RuntimeError: if ``schema`` does not start with ``name(...)``.
      """
      match = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
      if not match:
          raise RuntimeError(f"Unsupported function schema: {schema}")
      base_name = match.group(1)
      raw_params = match.group(3)
      return base_name, raw_params.split(", ")
  T = TypeVar("T")
  S = TypeVar("S")
  # These two functions purposely return generators in analogy to map()
  # so that you don't mix up when you need to list() them


  def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
      """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping ``None`` results.

      Only ``None`` is dropped; other falsy results (0, "", False) are kept.
      """
      for item in xs:
          mapped = func(item)
          if mapped is None:
              continue
          yield mapped


  def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
      """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
      for item in xs:
          yield from func(item)
  # Conveniently add error context to exceptions raised. Lets us
  # easily say that an error occurred while processing a specific
  # context.
  @contextlib.contextmanager
  def context(msg_fn: Callable[[], str]) -> Iterator[None]:
      """Attach extra context to any ``Exception`` escaping the ``with`` body.

      ``msg_fn`` is called only when an exception actually occurs. Its result
      is indented by one space and appended on a new line to the exception's
      first argument (or becomes the only argument if there were none). The
      same exception object is then re-raised.
      """
      try:
          yield
      except Exception as exc:
          # TODO: this does the wrong thing with KeyError
          # (KeyError.__str__ repr()s its argument, so the newline shows as "\n").
          note = textwrap.indent(msg_fn(), " ")
          if exc.args:
              exc.args = (f"{exc.args[0]}\n{note}",) + exc.args[1:]
          else:
              exc.args = (note,)
          raise
  # A little trick from https://github.com/python/mypy/issues/6366
  # for getting mypy to do exhaustiveness checking
  # TODO: put this somewhere else, maybe
  def assert_never(x: NoReturn) -> NoReturn:
      """Fail on a value the type checker believes cannot exist.

      Placed in the final ``else`` of an ``isinstance`` chain, mypy reports an
      error if ``x`` could still have some unhandled type there. At runtime it
      always raises ``AssertionError`` naming the value's type.
      """
      type_name = type(x).__name__
      raise AssertionError(f"Unhandled type: {type_name}")
  @functools.lru_cache(maxsize=None)
  def _read_template(template_fn: str) -> CodeTemplate:
      """Load the template at path ``template_fn`` via ``CodeTemplate.from_file``.

      The unbounded memoization means each distinct path is loaded at most
      once per process; later calls return the cached CodeTemplate.
      """
      parsed = CodeTemplate.from_file(template_fn)
      return parsed
  # String hash that's stable across different executions, unlike builtin hash
  def string_stable_hash(s: str) -> int:
      """Return a deterministic integer hash of ``s``.

      The SHA-1 digest of the string's bytes is read as a little-endian
      integer, so, unlike the builtin ``hash``, the value does not vary with
      ``PYTHONHASHSEED``.

      Strings representable in Latin-1 are hashed from their Latin-1 bytes
      (so existing hash values are unchanged). Any other string falls back to
      its UTF-8 bytes instead of raising UnicodeEncodeError.
      """
      try:
          data = s.encode("latin1")
      except UnicodeEncodeError:
          data = s.encode("utf-8")
      digest = hashlib.sha1(data).digest()
      return int.from_bytes(digest, byteorder="little")
  # A small abstraction for writing out generated files and keeping track
  # of what files have been written (so you can write out a list of output
  # files)
  class FileManager:
      """Render templates into an install directory and track every output path.

      Attributes:
          install_dir: directory that generated files go under; it is prefixed
              (joined with "/") onto every filename given to the write methods.
          template_dir: directory that template file names are resolved against.
          dry_run: if True, output paths are still recorded in ``filenames``
              but nothing is rendered or written.
          filenames: every output path registered so far (install_dir-prefixed).
      """

      install_dir: str
      template_dir: str
      dry_run: bool
      filenames: Set[str]

      def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
          self.install_dir = install_dir
          self.template_dir = template_dir
          self.filenames = set()
          self.dry_run = dry_run

      def _write_if_changed(self, filename: str, contents: str) -> None:
          """Write ``contents`` to ``filename`` only if the file's text differs.

          Missing parent directories are created. Skipping identical rewrites
          presumably keeps mtimes stable so unchanged outputs don't trigger
          rebuilds.
          """
          old_contents: Optional[str]
          try:
              with open(filename, "r") as f:
                  old_contents = f.read()
          except (OSError, UnicodeDecodeError):
              # Missing, unreadable or undecodable file: treat it as changed and
              # overwrite it below rather than crashing.
              old_contents = None
          if contents != old_contents:
              # Create output directory if it doesn't exist. A bare filename has
              # an empty dirname and os.makedirs("") raises, so skip that case.
              dirname = os.path.dirname(filename)
              if dirname:
                  os.makedirs(dirname, exist_ok=True)
              with open(filename, "w") as f:
                  f.write(contents)

      # Read from template file and replace pattern with callable (type could be dict or str).
      def substitute_with_template(
          self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
      ) -> str:
          """Return the text produced for ``env_callable``.

          A ``str`` result is returned verbatim. A dict result is substituted
          into the template ``template_fn`` (resolved against ``template_dir``);
          when it has no ``generated_comment`` entry, one naming the template is
          inserted into that dict first.
          """
          template_path = os.path.join(self.template_dir, template_fn)
          env = env_callable()
          if isinstance(env, dict):
              # TODO: Update the comment reference to the correct location
              if "generated_comment" not in env:
                  # Presumably split so this source file itself isn't tagged as generated.
                  comment = "@" + "generated by torchgen/gen.py"
                  comment += " from {}".format(os.path.basename(template_path))
                  env["generated_comment"] = comment
              template = _read_template(template_path)
              return template.substitute(env)
          elif isinstance(env, str):
              return env
          else:
              assert_never(env)

      def write_with_template(
          self,
          filename: str,
          template_fn: str,
          env_callable: Callable[[], Union[str, Dict[str, Any]]],
      ) -> None:
          """Register ``install_dir/filename`` and, unless dry-running, render and write it.

          Raises:
              AssertionError: if the same output path was already registered.
          """
          filename = "{}/{}".format(self.install_dir, filename)
          # Must be an f-string so the message actually names the offending path.
          assert filename not in self.filenames, f"duplicate file write {filename}"
          self.filenames.add(filename)
          if not self.dry_run:
              substitute_out = self.substitute_with_template(
                  template_fn=template_fn,
                  env_callable=env_callable,
              )
              self._write_if_changed(filename=filename, contents=substitute_out)

      def write(
          self,
          filename: str,
          env_callable: Callable[[], Union[str, Dict[str, Any]]],
      ) -> None:
          """Render ``filename`` from the template of the same name."""
          self.write_with_template(filename, filename, env_callable)

      def write_sharded(
          self,
          filename: str,
          items: Iterable[T],
          *,
          key_fn: Callable[[T], str],
          env_callable: Callable[[T], Dict[str, List[str]]],
          num_shards: int,
          base_env: Optional[Dict[str, Any]] = None,
          sharded_keys: Set[str],
      ) -> None:
          """Spread ``items`` across ``num_shards`` copies of template ``filename``.

          Each item goes to shard ``string_stable_hash(key_fn(item)) % num_shards``.
          The lists in ``env_callable(item)`` are appended to that shard's
          environment and to an extra "Everything" environment. For ``Foo.cpp``
          this writes ``Foo_0.cpp`` .. ``Foo_<num_shards-1>.cpp`` plus
          ``FooEverything.cpp``; the latter is not kept in ``filenames``.

          ``base_env`` is merged into every environment. Sharded keys present in
          it must be lists and are copied per environment; absent ones start
          empty.

          Raises:
              AssertionError: if a sharded key in ``base_env`` is not a list, or
                  ``env_callable`` returns a key not listed in ``sharded_keys``.
          """
          everything: Dict[str, Any] = {"shard_id": "Everything"}
          shards: List[Dict[str, Any]] = [
              {"shard_id": f"_{i}"} for i in range(num_shards)
          ]
          all_shards = [everything] + shards
          if base_env is not None:
              for shard in all_shards:
                  shard.update(base_env)
          for key in sharded_keys:
              for shard in all_shards:
                  if key in shard:
                      assert isinstance(
                          shard[key], list
                      ), "sharded keys in base_env must be a list"
                      # Copy so the in-place `+=` in merge_env doesn't mutate the
                      # list shared with base_env and the other shards.
                      shard[key] = shard[key].copy()
                  else:
                      shard[key] = []

          def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
              for k, v in from_.items():
                  assert k in sharded_keys, f"undeclared sharded key {k}"
                  into[k] += v

          if self.dry_run:
              # Dry runs don't write any templates, so incomplete environments are fine
              items = ()
          for item in items:
              key = key_fn(item)
              sid = string_stable_hash(key) % num_shards
              env = env_callable(item)
              merge_env(shards[sid], env)
              merge_env(everything, env)

          dot_pos = filename.rfind(".")
          if dot_pos == -1:
              dot_pos = len(filename)
          base_filename = filename[:dot_pos]
          extension = filename[dot_pos:]

          for shard in all_shards:
              shard_id = shard["shard_id"]
              # Bind the current shard as a default argument so the callable can
              # never observe a later loop value (late-binding closure pitfall).
              self.write_with_template(
                  f"{base_filename}{shard_id}{extension}",
                  filename,
                  lambda shard=shard: shard,
              )
          # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
          self.filenames.discard(
              f"{self.install_dir}/{base_filename}Everything{extension}"
          )

      def write_outputs(self, variable_name: str, filename: str) -> None:
          """Write a file containing the list of all outputs which are
          generated by this script.

          The content is ``set(<variable_name> ...)`` with every registered
          path, sorted and double-quoted, one per line.
          """
          content = "set({}\n {})".format(
              variable_name,
              "\n ".join('"' + name + '"' for name in sorted(self.filenames)),
          )
          self._write_if_changed(filename, content)

      def template_dir_for_comments(self) -> str:
          """
          This needs to be deterministic. The template dir is an absolute path
          that varies across builds. So, just use the path relative to this file,
          which will point to the codegen source but will be stable.
          """
          return os.path.relpath(self.template_dir, os.path.dirname(__file__))
  # Helper function to generate file manager
  def make_file_manager(
      options: Namespace, install_dir: Optional[str] = None
  ) -> FileManager:
      """Build a FileManager from parsed command-line ``options``.

      Templates are read from ``<options.source_path>/templates``. Output goes
      to ``install_dir`` when it is given and non-empty, otherwise to
      ``options.install_dir``. ``options.dry_run`` is passed through.
      """
      templates = os.path.join(options.source_path, "templates")
      target_dir = install_dir or options.install_dir
      return FileManager(
          install_dir=target_dir,
          template_dir=templates,
          dry_run=options.dry_run,
      )
  265. # Helper function to create a pretty representation for dataclasses
  266. def dataclass_repr(
  267. obj: Any,
  268. indent: int = 0,
  269. width: int = 80,
  270. ) -> str:
  271. # built-in pprint module support dataclasses from python 3.10
  272. if sys.version_info >= (3, 10):
  273. from pprint import pformat
  274. return pformat(obj, indent, width)
  275. return _pformat(obj, indent=indent, width=width)
  276. def _pformat(
  277. obj: Any,
  278. indent: int,
  279. width: int,
  280. curr_indent: int = 0,
  281. ) -> str:
  282. assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
  283. class_name = obj.__class__.__name__
  284. # update current indentation level with class name
  285. curr_indent += len(class_name) + 1
  286. fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
  287. fields_str = []
  288. for name, attr in fields_list:
  289. # update the current indent level with the field name
  290. # dict, list, set and tuple also add indent as done in pprint
  291. _curr_indent = curr_indent + len(name) + 1
  292. if is_dataclass(attr):
  293. str_repr = _pformat(attr, indent, width, _curr_indent)
  294. elif isinstance(attr, dict):
  295. str_repr = _format_dict(attr, indent, width, _curr_indent)
  296. elif isinstance(attr, (list, set, tuple)):
  297. str_repr = _format_list(attr, indent, width, _curr_indent)
  298. else:
  299. str_repr = repr(attr)
  300. fields_str.append(f"{name}={str_repr}")
  301. indent_str = curr_indent * " "
  302. body = f",\n{indent_str}".join(fields_str)
  303. return f"{class_name}({body})"
  304. def _format_dict(
  305. attr: Dict[Any, Any],
  306. indent: int,
  307. width: int,
  308. curr_indent: int,
  309. ) -> str:
  310. curr_indent += indent + 3
  311. dict_repr = []
  312. for k, v in attr.items():
  313. k_repr = repr(k)
  314. v_str = (
  315. _pformat(v, indent, width, curr_indent + len(k_repr))
  316. if is_dataclass(v)
  317. else repr(v)
  318. )
  319. dict_repr.append(f"{k_repr}: {v_str}")
  320. return _format(dict_repr, indent, width, curr_indent, "{", "}")
  321. def _format_list(
  322. attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
  323. indent: int,
  324. width: int,
  325. curr_indent: int,
  326. ) -> str:
  327. curr_indent += indent + 1
  328. list_repr = [
  329. _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
  330. for l in attr
  331. ]
  332. start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
  333. return _format(list_repr, indent, width, curr_indent, start, end)
  334. def _format(
  335. fields_str: List[str],
  336. indent: int,
  337. width: int,
  338. curr_indent: int,
  339. start: str,
  340. end: str,
  341. ) -> str:
  342. delimiter, curr_indent_str = "", ""
  343. # if it exceed the max width then we place one element per line
  344. if len(repr(fields_str)) >= width:
  345. delimiter = "\n"
  346. curr_indent_str = " " * curr_indent
  347. indent_str = " " * indent
  348. body = f", {delimiter}{curr_indent_str}".join(fields_str)
  349. return f"{start}{indent_str}{body}{end}"
  class NamespaceHelper:
      """Produce the opening and closing C++ text for a nested namespace.

      For ``namespace_str="torch::lazy"``::

          prologue:
              namespace torch {
              namespace lazy {
          epilogue:
              } // namespace lazy
              } // namespace torch
      """

      def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
          # namespace_str is a "::"-joined path such as torch::lazy
          levels = namespace_str.split("::")
          assert (
              len(levels) <= max_level
          ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
          opening_lines = [f"namespace {n} {{" for n in levels]
          closing_lines = [f"}} // namespace {n}" for n in levels[::-1]]
          self.cpp_namespace_ = namespace_str
          self.prologue_ = "\n".join(opening_lines)
          self.epilogue_ = "\n".join(closing_lines)
          self.namespaces_ = levels
          self.entity_name_ = entity_name

      @staticmethod
      def from_namespaced_entity(
          namespaced_entity: str, max_level: int = 2
      ) -> "NamespaceHelper":
          """Build a helper from a qualified name such as ``"torch::lazy::add"``.

          The last ``::``-separated component becomes the entity name and the
          rest forms the namespace (empty when there is no ``::``).
          """
          *namespace_parts, entity = namespaced_entity.split("::")
          return NamespaceHelper(
              namespace_str="::".join(namespace_parts),
              entity_name=entity,
              max_level=max_level,
          )

      @property
      def prologue(self) -> str:
          """Opening ``namespace X {`` lines, outermost first."""
          return self.prologue_

      @property
      def epilogue(self) -> str:
          """Closing ``} // namespace X`` lines, innermost first."""
          return self.epilogue_

      @property
      def entity_name(self) -> str:
          """Trailing class/function name given at construction ("" if none)."""
          return self.entity_name_

      def get_cpp_namespace(self, default: str = "") -> str:
          """
          Return the "::"-joined namespace string (no leading "::"), or
          ``default`` when that string is empty.
          """
          return self.cpp_namespace_ or default
  class OrderedSet(Generic[T]):
      """A set that iterates in first-insertion order.

      Elements are stored as the keys of a dict whose values are all None.
      Comparing two OrderedSets ignores order. Comparing against any other
      object compares a plain ``set`` of the elements with it.
      """

      storage: Dict[T, Literal[None]]

      def __init__(self, iterable: Optional[Iterable[T]] = None):
          self.storage = {} if iterable is None else dict.fromkeys(iterable)

      def __contains__(self, item: T) -> bool:
          return item in self.storage

      def __iter__(self) -> Iterator[T]:
          return iter(self.storage)

      def update(self, items: "OrderedSet[T]") -> None:
          """Add every element of ``items`` in place, keeping existing positions."""
          self.storage.update(items.storage)

      def add(self, item: T) -> None:
          self.storage[item] = None

      def copy(self) -> "OrderedSet[T]":
          """Return a shallow copy with the same element order."""
          duplicate: OrderedSet[T] = OrderedSet()
          duplicate.storage = dict(self.storage)
          return duplicate

      @staticmethod
      def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
          """Return a new set with the elements of all ``args``, first-seen order."""
          merged = args[0].copy()
          for other in args[1:]:
              merged.update(other)
          return merged

      def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
          return OrderedSet.union(self, other)

      def __ior__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
          self.update(other)
          return self

      def __eq__(self, other: object) -> bool:
          if not isinstance(other, OrderedSet):
              return set(self.storage) == other
          return self.storage == other.storage