gen_mobile_upgraders.py 12 KB


  1. #!/usr/bin/env python3
  2. import os
  3. from enum import Enum
  4. from pathlib import Path
  5. from typing import Any, Dict, List
  6. import torch
  7. from torch.jit.generate_bytecode import generate_upgraders_bytecode
  8. from torchgen.code_template import CodeTemplate
  9. from torchgen.operator_versions.gen_mobile_upgraders_constant import (
  10. MOBILE_UPGRADERS_HEADER_DESCRIPTION,
  11. )
  12. class ByteCode(Enum):
  13. instructions = 1
  14. constants = 2
  15. types = 3
  16. operators = 4
  17. register_size = 5
  18. EXCLUDED_OP_SET = [
  19. "aten::full.names",
  20. "aten::full.out",
  21. "aten::full",
  22. ]
  23. EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"]
# One ``Instruction{OpCode::<name>, X, N},`` entry. The leading newline
# separates consecutive entries once they are concatenated; the newline in
# front of the first entry is stripped by construct_instruction.
ONE_INSTRUCTION = CodeTemplate(
    """
Instruction{OpCode::${operator_name}, ${X}, ${N}},"""
)
# Wraps the concatenated instruction entries in a std::vector<Instruction>.
INSTRUCTION_LIST = CodeTemplate(
    """std::vector<Instruction>({
${instruction_list}
}), // instructions list"""
)
# One ``c10::IValue(<constant>),`` entry; ``constant`` is already rendered as
# C++ source text (an empty string yields ``c10::IValue()``).
ONE_CONSTANT = CodeTemplate(
    """
c10::IValue(${constant}),"""
)
# Wraps the concatenated constant entries in a std::vector<c10::IValue>.
CONSTANT_LIST = CodeTemplate(
    """std::vector<c10::IValue>({
${constant_list}
}), // constants list"""
)
# Emitted instead of CONSTANT_LIST when an upgrader has no constants.
CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list"""
# One type entry; the generated C++ calls c10::parseType on the type string.
# Unlike the other entry templates it has no leading newline, so consecutive
# types end up on the same line.
ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")
# Wraps the concatenated type entries in a std::vector<c10::TypePtr>.
TYPE_LIST = CodeTemplate(
    """std::vector<c10::TypePtr>({
${type_list}
}), // types list"""
)
# Emitted instead of TYPE_LIST when an upgrader has no types.
TYPE_LIST_EMPTY = """std::vector<c10::TypePtr>(), // types list"""
# One OperatorString entry: quoted operator name, quoted overload name and the
# argument count. (The constant's name is misspelled; it is kept unchanged
# because construct_operators refers to it.)
ONE_OPERATOTR_STRING = CodeTemplate(
    """
OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),"""
)
# Wraps operator entries in a std::vector<OperatorString>. The template itself
# starts with a newline, which write_cpp strips before embedding it.
OPERATOR_STRING_LIST = CodeTemplate(
    """
std::vector<OperatorString>({
${operator_string_list}
}), // operators list"""
)
# A mobile::Function::registerFunc(...) call built from the upgrader name, the
# rendered instruction/constant/type lists and the register size.
ONE_UPGRADER_FUNCTION = CodeTemplate(
    """
mobile::Function::registerFunc(
"${upgrader_name}",
${instruction_list},
${constant_list},
${type_list},
${register_size}
)"""
)
# Pairs a registered bytecode function with its operator list as one
# ByteCodeFunctionWithOperator element of the generated upgrader list.
ONE_UPGRADER_SRC = CodeTemplate(
    """
ByteCodeFunctionWithOperator({
${bytecode_function},
${operator_string_list}
}),"""
)
# One Upgrader({min_version, max_version, "name", bytecode_index}) entry. It
# carries no separator of its own, so any "," between several entries of one
# operator must be added where the entries are joined.
ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate(
    """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})"""
)  # noqa: E501
# One {operator name, std::vector<Upgrader>} pair of the operator version map.
ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate(
    """
{std::string("${operator_name}"),
std::vector<Upgrader>({
${upgrader_list_in_version_map}
})},"""
)
# Defines getOperatorVersionMapForMobile(), which returns a function-local
# static map from operator name to its list of upgraders.
OPERATOR_VERSION_MAP = CodeTemplate(
    """
const std::unordered_map<std::string, std::vector<Upgrader>>
getOperatorVersionMapForMobile() {
static std::unordered_map<std::string, std::vector<Upgrader>>
operatorVersionMapForMobile({
${operator_list_in_version_map}
});
return operatorVersionMapForMobile;
}
"""
)
# Full contents of the generated upgrader_mobile.cpp: the shared header
# description, the operator version map (${operator_version_map}) and
# getUpgraderBytecodeList(). The latter builds the ByteCodeFunctionWithOperator
# list from ${upgrader_bytecode}, calls append_operator for every operator of
# every entry, and keeps the result in a function-local static.
UPGRADER_CPP_SRC = CodeTemplate(
    MOBILE_UPGRADERS_HEADER_DESCRIPTION
    + """
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10
namespace torch {
namespace jit {
// clang-format off
// From operator_versions_map
${operator_version_map}
const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
auto generate_upgrader_bytecode_list = []() {
std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({
${upgrader_bytecode}
});
for (const auto& upgrader_function : upgrader_function_list) {
for (const auto& op : upgrader_function.operators) {
upgrader_function.function.append_operator(
op.name,
op.overload_name,
op.num_specified_args);
}
}
return upgrader_function_list;
};
static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList =
generate_upgrader_bytecode_list();
return upgraderBytecodeList;
}
// clang-format on
} // namespace jit
} // namespace torch
"""
)
# Name of the generated file; write_cpp places it inside the directory it is given.
UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp"
# NOTE(review): not referenced by the code in this file. Unlike
# ONE_UPGRADER_IN_VERSION_MAP, ${operator_name} is not quoted here, so a caller
# would have to pass an already-quoted name — confirm before using.
UPGRADER_ELEMENT = CodeTemplate(
    """\
Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}),
"""
)
# NOTE(review): not referenced by the code in this file. The ";" after the
# vector would not be valid inside a braced initializer if this were emitted
# as a map entry — confirm before using.
PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
    """\
{
std::string(${operator_name}),
std::vector<Upgrader>({${upgrader_list}});
}
"""
)
  150. def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
  151. instruction_list_part = []
  152. for instruction in instruction_list_from_yaml:
  153. instruction_list_part.append(
  154. ONE_INSTRUCTION.substitute(
  155. operator_name=instruction[0],
  156. X=instruction[1],
  157. N=instruction[2],
  158. )
  159. )
  160. return INSTRUCTION_LIST.substitute(
  161. instruction_list="".join(instruction_list_part).lstrip("\n")
  162. )
  163. def construct_constants(constants_list_from_yaml: List[Any]) -> str:
  164. constants_list_part = []
  165. for constant_from_yaml in constants_list_from_yaml:
  166. convert_constant = None
  167. if isinstance(constant_from_yaml, str):
  168. # Add quotes if it's string
  169. convert_constant = f'"{constant_from_yaml}"'
  170. elif isinstance(constant_from_yaml, bool):
  171. convert_constant = "true" if constant_from_yaml else "false"
  172. elif constant_from_yaml is None:
  173. convert_constant = ""
  174. elif isinstance(constant_from_yaml, int):
  175. convert_constant = str(constant_from_yaml)
  176. else:
  177. raise ValueError(
  178. f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. "
  179. "Please add change in construct_constants function in gen_mobile_upgraders.py."
  180. )
  181. constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant))
  182. if len(constants_list_part) == 0:
  183. return CONSTANTS_LIST_EMPTY
  184. return CONSTANT_LIST.substitute(
  185. constant_list="".join(constants_list_part).lstrip("\n")
  186. )
  187. def construct_operators(operator_list_from_yaml: List[Any]) -> str:
  188. operator_list_part = []
  189. for operator in operator_list_from_yaml:
  190. operator_list_part.append(
  191. ONE_OPERATOTR_STRING.substitute(
  192. operator_name=operator[0],
  193. overload_name=operator[1],
  194. num_of_args=operator[2],
  195. )
  196. )
  197. return OPERATOR_STRING_LIST.substitute(
  198. operator_string_list="".join(operator_list_part).lstrip("\n")
  199. )
  200. def construct_types(types_tr_list_from_yaml: List[Any]) -> str:
  201. types_tr_list_part = []
  202. for types_tr in types_tr_list_from_yaml:
  203. types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr))
  204. if len(types_tr_list_part) == 0:
  205. return TYPE_LIST_EMPTY
  206. return TYPE_LIST.substitute(type_list="".join(types_tr_list_part).lstrip("\n"))
  207. def construct_register_size(register_size_from_yaml: int) -> str:
  208. if not isinstance(register_size_from_yaml, int):
  209. raise ValueError(
  210. f"Input register size is {register_size_from_yaml} and"
  211. "it's type is {type(register_size_from_yaml)}. An int type is expected."
  212. )
  213. return str(register_size_from_yaml)
  214. def construct_version_maps(
  215. upgrader_bytecode_function_to_index_map: Dict[str, Any]
  216. ) -> str:
  217. version_map = torch._C._get_operator_version_map()
  218. sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0]) # type: ignore[no-any-return]
  219. sorted_version_map = {name: lst for name, lst in sorted_version_map_}
  220. operator_list_in_version_map_part = []
  221. for op_name in sorted_version_map:
  222. upgraders_in_version_map_part = []
  223. # TODO: remove the skip after these two operators schemas are fixed
  224. if op_name in EXCLUDED_OP_SET:
  225. continue
  226. upgrader_ranges = torch._C._get_upgrader_ranges(op_name)
  227. upgrader_entries = sorted_version_map[op_name]
  228. assert len(upgrader_ranges) == len(upgrader_entries)
  229. for idx, upgrader_entry in enumerate(upgrader_entries):
  230. upgrader_name = upgrader_entry.upgrader_name
  231. bytecode_function_index = upgrader_bytecode_function_to_index_map[
  232. upgrader_name
  233. ]
  234. upgraders_in_version_map_part.append(
  235. ONE_UPGRADER_IN_VERSION_MAP.substitute(
  236. upgrader_min_version=upgrader_ranges[idx].min_version,
  237. upgrader_max_version=upgrader_ranges[idx].max_version,
  238. upgrader_name=upgrader_name,
  239. bytecode_func_index=bytecode_function_index,
  240. )
  241. )
  242. operator_list_in_version_map_part.append(
  243. ONE_OPERATOR_IN_VERSION_MAP.substitute(
  244. operator_name=op_name,
  245. upgrader_list_in_version_map="".join(upgraders_in_version_map_part),
  246. )
  247. )
  248. return OPERATOR_VERSION_MAP.substitute(
  249. operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip(
  250. "\n"
  251. )
  252. )
  253. def get_upgrader_bytecode_function_to_index_map(
  254. upgrader_dict: List[Dict[str, Any]]
  255. ) -> Dict[str, Any]:
  256. upgrader_bytecode_function_to_index_map = {}
  257. index = 0
  258. for upgrader_bytecode in upgrader_dict:
  259. for upgrader_name, bytecode in upgrader_bytecode.items():
  260. if upgrader_name in EXCLUE_UPGRADER_SET:
  261. continue
  262. upgrader_bytecode_function_to_index_map[upgrader_name] = index
  263. index += 1
  264. return upgrader_bytecode_function_to_index_map
  265. def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
  266. body_parts = []
  267. upgrader_bytecode_function_to_index_map = (
  268. get_upgrader_bytecode_function_to_index_map(upgrader_dict)
  269. )
  270. version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map)
  271. all_upgrader_src_string = []
  272. for upgrader_bytecode in upgrader_dict:
  273. for upgrader_name, bytecode in upgrader_bytecode.items():
  274. # TODO: remove the skip after these two operators schemas are fixed
  275. if upgrader_name in EXCLUE_UPGRADER_SET:
  276. continue
  277. instruction_list_str = ""
  278. constant_list_str = ""
  279. type_list_str = ""
  280. register_size_str = ""
  281. operator_list_str = ""
  282. for table_name, contents in bytecode.items():
  283. element = ByteCode[table_name]
  284. body_string = ""
  285. if element is ByteCode.instructions:
  286. instruction_list_str = construct_instruction(contents)
  287. elif element is ByteCode.constants:
  288. constant_list_str = construct_constants(contents)
  289. elif element is ByteCode.operators:
  290. operator_list_str = construct_operators(contents)
  291. elif element is ByteCode.types:
  292. type_list_str = construct_types(contents)
  293. elif element is ByteCode.register_size:
  294. register_size_str = construct_register_size(contents)
  295. one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute(
  296. upgrader_name=upgrader_name,
  297. instruction_list=instruction_list_str,
  298. constant_list=constant_list_str,
  299. type_list=type_list_str,
  300. register_size=register_size_str,
  301. )
  302. one_upgrader_src_string = ONE_UPGRADER_SRC.substitute(
  303. bytecode_function=one_upgrader_function_string.lstrip("\n"),
  304. operator_string_list=operator_list_str.lstrip("\n"),
  305. )
  306. all_upgrader_src_string.append(one_upgrader_src_string)
  307. upgrader_file_content = UPGRADER_CPP_SRC.substitute(
  308. operator_version_map=version_map_src,
  309. upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"),
  310. )
  311. body_parts.append(upgrader_file_content)
  312. print("writing file to : ", cpp_path + "/" + UPGRADER_MOBILE_FILE_NAME)
  313. with open(os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME), "wb") as out_file:
  314. final_output = "".join(body_parts)
  315. out_file.write(upgrader_file_content.encode("utf-8"))
  316. def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  317. sorted_upgrader_list = sorted(
  318. upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
  319. )
  320. return sorted_upgrader_list
  321. def main() -> None:
  322. upgrader_list = generate_upgraders_bytecode()
  323. sorted_upgrader_list = sort_upgrader(upgrader_list)
  324. for up in sorted_upgrader_list:
  325. print("after sort upgrader : ", next(iter(up)))
  326. pytorch_dir = Path(__file__).resolve().parents[2]
  327. upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "mobile"
  328. write_cpp(str(upgrader_path), sorted_upgrader_list)
  329. if __name__ == "__main__":
  330. main()