gen_pyi.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import os
  2. import pathlib
  3. from typing import Any, Dict, List, Set, Tuple, Union
  4. def materialize_lines(lines: List[str], indentation: int) -> str:
  5. output = ""
  6. new_line_with_indent = "\n" + " " * indentation
  7. for i, line in enumerate(lines):
  8. if i != 0:
  9. output += new_line_with_indent
  10. output += line.replace('\n', new_line_with_indent)
  11. return output
  12. def gen_from_template(dir: str, template_name: str, output_name: str, replacements: List[Tuple[str, Any, int]]):
  13. template_path = os.path.join(dir, template_name)
  14. output_path = os.path.join(dir, output_name)
  15. with open(template_path, "r") as f:
  16. content = f.read()
  17. for placeholder, lines, indentation in replacements:
  18. with open(output_path, "w") as f:
  19. content = content.replace(placeholder, materialize_lines(lines, indentation))
  20. f.write(content)
  21. def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]:
  22. """
  23. When given a path to a directory, returns the paths to the relevant files within it.
  24. This function does NOT recursive traverse to subdirectories.
  25. """
  26. paths: Set[str] = set()
  27. for dir_path in dir_paths:
  28. all_files = os.listdir(dir_path)
  29. python_files = {fname for fname in all_files if ".py" == fname[-3:]}
  30. filter_files = {fname for fname in python_files if fname not in files_to_exclude}
  31. paths.update({os.path.join(dir_path, fname) for fname in filter_files})
  32. return paths
  33. def extract_method_name(line: str) -> str:
  34. """
  35. Extracts method name from decorator in the form of "@functional_datapipe({method_name})"
  36. """
  37. if "(\"" in line:
  38. start_token, end_token = "(\"", "\")"
  39. elif "(\'" in line:
  40. start_token, end_token = "(\'", "\')"
  41. else:
  42. raise RuntimeError(f"Unable to find appropriate method name within line:\n{line}")
  43. start, end = line.find(start_token) + len(start_token), line.find(end_token)
  44. return line[start:end]
  45. def extract_class_name(line: str) -> str:
  46. """
  47. Extracts class name from class definition in the form of "class {CLASS_NAME}({Type}):"
  48. """
  49. start_token = "class "
  50. end_token = "("
  51. start, end = line.find(start_token) + len(start_token), line.find(end_token)
  52. return line[start:end]
def parse_datapipe_file(file_path: str) -> Tuple[Dict[str, str], Dict[str, str], Set[str]]:
    """
    Given a path to file, scans it line by line for "@functional_datapipe" decorated classes.

    Returns a tuple of three items:
        1. a dictionary of method names to processed ``__init__``/``__new__`` signatures,
        2. a dictionary of method names to the name of the decorated class,
        3. the set of method names whose class defines ``__new__`` (these need a
           special output type).
    """
    method_to_signature, method_to_class_name, special_output_type = {}, {}, set()
    with open(file_path) as f:
        # Number of currently unclosed parentheses while a signature is being captured;
        # zero means no signature capture is in progress.
        open_paren_count = 0
        method_name, class_name, signature = "", "", ""
        # True while inside a triple-double-quoted block.
        skip = False
        for line in f.readlines():
            # An odd number of triple quotes on a line opens or closes a block.
            if line.count("\"\"\"") % 2 == 1:
                skip = not skip
            if skip or "\"\"\"" in line:  # Skipping comment/example blocks
                continue
            # Decorator line: remember the functional name for the class that follows.
            if "@functional_datapipe" in line:
                method_name = extract_method_name(line)
                continue
            # First "class " line seen after a decorator gives the class name.
            if method_name and "class " in line:
                class_name = extract_class_name(line)
                continue
            # Start of the constructor signature: begin capturing after the first "(".
            if method_name and ("def __init__(" in line or "def __new__(" in line):
                if "def __new__(" in line:
                    special_output_type.add(method_name)
                open_paren_count += 1
                start = line.find("(") + len("(")
                line = line[start:]
            if open_paren_count > 0:
                open_paren_count += line.count('(')
                open_paren_count -= line.count(')')
                if open_paren_count == 0:
                    # The signature's closing parenthesis is on this line: keep the text
                    # before the last ")", record the result, and reset the state.
                    end = line.rfind(')')
                    signature += line[:end]
                    method_to_signature[method_name] = process_signature(signature)
                    method_to_class_name[method_name] = class_name
                    method_name, class_name, signature = "", "", ""
                elif open_paren_count < 0:
                    raise RuntimeError("open parenthesis count < 0. This shouldn't be possible.")
                else:
                    # Signature continues on the next line; drop the newline and
                    # surrounding spaces before accumulating.
                    signature += line.strip('\n').strip(' ')
    return method_to_signature, method_to_class_name, special_output_type
  93. def parse_datapipe_files(file_paths: Set[str]) -> Tuple[Dict[str, str], Dict[str, str], Set[str]]:
  94. methods_and_signatures, methods_and_class_names, methods_with_special_output_types = {}, {}, set()
  95. for path in file_paths:
  96. method_to_signature, method_to_class_name, methods_needing_special_output_types = parse_datapipe_file(path)
  97. methods_and_signatures.update(method_to_signature)
  98. methods_and_class_names.update(method_to_class_name)
  99. methods_with_special_output_types.update(methods_needing_special_output_types)
  100. return methods_and_signatures, methods_and_class_names, methods_with_special_output_types
  101. def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]:
  102. """
  103. Given a line of text, split it on comma unless the comma is within a bracket '[]'.
  104. """
  105. bracket_count = 0
  106. curr_token = ""
  107. res = []
  108. for char in line:
  109. if char == "[":
  110. bracket_count += 1
  111. elif char == "]":
  112. bracket_count -= 1
  113. elif char == delimiter and bracket_count == 0:
  114. res.append(curr_token)
  115. curr_token = ""
  116. continue
  117. curr_token += char
  118. res.append(curr_token)
  119. return res
  120. def process_signature(line: str) -> str:
  121. """
  122. Given a raw function signature, clean it up by removing the self-referential datapipe argument,
  123. default arguments of input functions, newlines, and spaces.
  124. """
  125. tokens: List[str] = split_outside_bracket(line)
  126. for i, token in enumerate(tokens):
  127. tokens[i] = token.strip(' ')
  128. if token == "cls":
  129. tokens[i] = "self"
  130. elif i > 0 and ("self" == tokens[i - 1]) and (tokens[i][0] != "*"):
  131. # Remove the datapipe after 'self' or 'cls' unless it has '*'
  132. tokens[i] = ""
  133. elif "Callable =" in token: # Remove default argument if it is a function
  134. head, default_arg = token.rsplit("=", 2)
  135. tokens[i] = head.strip(' ') + "= ..."
  136. tokens = [t for t in tokens if t != ""]
  137. line = ', '.join(tokens)
  138. return line
  139. def get_method_definitions(file_path: Union[str, List[str]],
  140. files_to_exclude: Set[str],
  141. deprecated_files: Set[str],
  142. default_output_type: str,
  143. method_to_special_output_type: Dict[str, str],
  144. root: str = "") -> List[str]:
  145. """
  146. .pyi generation for functional DataPipes Process
  147. # 1. Find files that we want to process (exclude the ones who don't)
  148. # 2. Parse method name and signature
  149. # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
  150. """
  151. if root == "":
  152. root = str(pathlib.Path(__file__).parent.resolve())
  153. file_path = [file_path] if isinstance(file_path, str) else file_path
  154. file_path = [os.path.join(root, path) for path in file_path]
  155. file_paths = find_file_paths(file_path,
  156. files_to_exclude=files_to_exclude.union(deprecated_files))
  157. methods_and_signatures, methods_and_class_names, methods_w_special_output_types = \
  158. parse_datapipe_files(file_paths)
  159. for fn_name in method_to_special_output_type:
  160. if fn_name not in methods_w_special_output_types:
  161. methods_w_special_output_types.add(fn_name)
  162. method_definitions = []
  163. for method_name, arguments in methods_and_signatures.items():
  164. class_name = methods_and_class_names[method_name]
  165. if method_name in methods_w_special_output_types:
  166. output_type = method_to_special_output_type[method_name]
  167. else:
  168. output_type = default_output_type
  169. method_definitions.append(f"# Functional form of '{class_name}'\n"
  170. f"def {method_name}({arguments}) -> {output_type}: ...")
  171. method_definitions.sort(key=lambda s: s.split('\n')[1]) # sorting based on method_name
  172. return method_definitions
  173. # Defined outside of main() so they can be imported by TorchData
  174. iterDP_file_path: str = "iter"
  175. iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
  176. iterDP_deprecated_files: Set[str] = set()
  177. iterDP_method_to_special_output_type: Dict[str, str] = {"demux": "List[IterDataPipe]", "fork": "List[IterDataPipe]"}
  178. mapDP_file_path: str = "map"
  179. mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
  180. mapDP_deprecated_files: Set[str] = set()
  181. mapDP_method_to_special_output_type: Dict[str, str] = {"shuffle": "IterDataPipe"}
  182. def main() -> None:
  183. """
  184. # Inject file into template datapipe.pyi.in
  185. TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
  186. interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
  187. """
  188. iter_method_definitions = get_method_definitions(iterDP_file_path, iterDP_files_to_exclude, iterDP_deprecated_files,
  189. "IterDataPipe", iterDP_method_to_special_output_type)
  190. map_method_definitions = get_method_definitions(mapDP_file_path, mapDP_files_to_exclude, mapDP_deprecated_files,
  191. "MapDataPipe", mapDP_method_to_special_output_type)
  192. path = pathlib.Path(__file__).parent.resolve()
  193. replacements = [('${IterDataPipeMethods}', iter_method_definitions, 4),
  194. ('${MapDataPipeMethods}', map_method_definitions, 4)]
  195. gen_from_template(dir=str(path),
  196. template_name="datapipe.pyi.in",
  197. output_name="datapipe.pyi",
  198. replacements=replacements)
# Run the generator only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()