instantiator.py

#!/usr/bin/python3
import importlib
import logging
import os
import sys
import tempfile
from typing import Optional

import torch
from torch.distributed.nn.jit.templates.remote_module_template import (
    get_remote_module_template,
)


logger = logging.getLogger(__name__)

_FILE_PREFIX = "_remote_module_"
_TEMP_DIR = tempfile.TemporaryDirectory()
INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name
logger.info(f"Created a temporary directory at {INSTANTIATED_TEMPLATE_DIR_PATH}")
sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)


def get_arg_return_types_from_interface(module_interface):
    assert getattr(
        module_interface, "__torch_script_interface__", False
    ), "Expect a TorchScript class interface decorated by @torch.jit.interface."
    qualified_name = torch._jit_internal._qualified_name(module_interface)
    cu = torch.jit._state._python_cu
    module_interface_c = cu.get_interface(qualified_name)
    assert (
        "forward" in module_interface_c.getMethodNames()
    ), "Expect forward in interface methods, while it has {}".format(
        module_interface_c.getMethodNames()
    )
    method_schema = module_interface_c.getMethod("forward")

    arg_str_list = []
    arg_type_str_list = []
    assert method_schema is not None
    for argument in method_schema.arguments:
        arg_str_list.append(argument.name)

        if argument.has_default_value():
            default_value_str = " = {}".format(argument.default_value)
        else:
            default_value_str = ""
        arg_type_str = "{name}: {type}{default_value}".format(
            name=argument.name, type=argument.type, default_value=default_value_str
        )
        arg_type_str_list.append(arg_type_str)

    arg_str_list = arg_str_list[1:]  # Remove "self".
    args_str = ", ".join(arg_str_list)

    arg_type_str_list = arg_type_str_list[1:]  # Remove "self".
    arg_types_str = ", ".join(arg_type_str_list)

    assert len(method_schema.returns) == 1
    argument = method_schema.returns[0]
    return_type_str = str(argument.type)

    return args_str, arg_types_str, return_type_str
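

# Illustration only (a sketch, assuming a hypothetical interface): for an
# interface whose forward is declared as
#
#     def forward(self, x: Tensor, k: int = 1) -> Tensor: ...
#
# get_arg_return_types_from_interface would return roughly
#
#     args_str        == "x, k"
#     arg_types_str   == "x: Tensor, k: int = 1"
#     return_type_str == "Tensor"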


def _write(out_path, text):
    old_text: Optional[str]
    try:
        with open(out_path, "r") as f:
            old_text = f.read()
    except IOError:
        old_text = None
    if old_text != text:
        with open(out_path, "w") as f:
            logger.info("Writing {}".format(out_path))
            f.write(text)
    else:
        logger.info("Skipped writing {}".format(out_path))


def _do_instantiate_remote_module_template(
    generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
):
    generated_code_text = get_remote_module_template(
        enable_moving_cpu_tensors_to_cuda
    ).format(**str_dict)
    out_path = os.path.join(
        INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
    )
    _write(out_path, generated_code_text)

    # From the importlib docs:
    # > If you are dynamically importing a module that was created since
    # > the interpreter began execution (e.g., created a Python source file),
    # > you may need to call invalidate_caches() in order for the new module
    # > to be noticed by the import system.
    importlib.invalidate_caches()
    generated_module = importlib.import_module(f"{generated_module_name}")
    return generated_module


def instantiate_scriptable_remote_module_template(
    module_interface_cls, enable_moving_cpu_tensors_to_cuda=True
):
    if not getattr(module_interface_cls, "__torch_script_interface__", False):
        raise ValueError(
            f"module_interface_cls {module_interface_cls} must be a type object decorated by "
            "@torch.jit.interface"
        )

    # Generate the template instance name.
    module_interface_cls_name = torch._jit_internal._qualified_name(
        module_interface_cls
    ).replace(".", "_")
    generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}"

    # Generate the type annotation strings.
    assign_module_interface_cls_str = (
        f"from {module_interface_cls.__module__} import "
        f"{module_interface_cls.__name__} as module_interface_cls"
    )
    args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface(
        module_interface_cls
    )
    kwargs_str = ""
    arrow_and_return_type_str = f" -> {return_type_str}"
    arrow_and_future_return_type_str = f" -> Future[{return_type_str}]"

    str_dict = dict(
        assign_module_interface_cls=assign_module_interface_cls_str,
        arg_types=arg_types_str,
        arrow_and_return_type=arrow_and_return_type_str,
        arrow_and_future_return_type=arrow_and_future_return_type_str,
        args=args_str,
        kwargs=kwargs_str,
        jit_script_decorator="@torch.jit.script",
    )
    return _do_instantiate_remote_module_template(
        generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
    )


def instantiate_non_scriptable_remote_module_template():
    generated_module_name = f"{_FILE_PREFIX}non_scriptable"
    str_dict = dict(
        assign_module_interface_cls="module_interface_cls = None",
        args="*args",
        kwargs="**kwargs",
        arg_types="*args, **kwargs",
        arrow_and_return_type="",
        arrow_and_future_return_type="",
        jit_script_decorator="",
    )

    # For a non-scriptable template, always enable moving CPU tensors to a CUDA device,
    # because the generated code is not scripted and thus faces no TorchScript syntax
    # limitation on the extra handling this requires.
    return _do_instantiate_remote_module_template(generated_module_name, str_dict, True)
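

# A minimal usage sketch, assuming a torch installation where the
# remote_module_template is available. MyModuleInterface is a hypothetical
# interface defined here purely for illustration.
if __name__ == "__main__":

    @torch.jit.interface
    class MyModuleInterface:
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    # Generate, write, and import a scripted template specialized for the interface.
    scripted_module = instantiate_scriptable_remote_module_template(MyModuleInterface)
    # Generate the generic, non-scriptable variant.
    non_scripted_module = instantiate_non_scriptable_remote_module_template()
    logger.info(
        "Instantiated %s and %s", scripted_module.__name__, non_scripted_module.__name__
    )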