generate_bytecode.py 1.0 KB

1234567891011121314151617181920212223242526272829
  1. from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph
  2. from typing import List
  3. def format_bytecode(table):
  4. # given a nested tuple, convert it to nested list
  5. def listify(content):
  6. if not isinstance(content, tuple):
  7. return content
  8. return [listify(i) for i in content]
  9. formatted_table = {}
  10. for entry in table:
  11. identifier = entry[0]
  12. content = entry[1]
  13. content = listify(content)
  14. formatted_table[identifier] = content
  15. return formatted_table
  16. def generate_upgraders_bytecode() -> List:
  17. yaml_content = []
  18. upgraders_graph_map = _generate_upgraders_graph()
  19. for upgrader_name, upgrader_graph in upgraders_graph_map.items():
  20. bytecode_table = _compile_graph_to_code_table(upgrader_name, upgrader_graph)
  21. entry = {upgrader_name: format_bytecode(bytecode_table)}
  22. yaml_content.append(entry)
  23. return yaml_content
  24. if __name__ == "__main__":
  25. raise RuntimeError("This file is not meant to be run directly")