tvm.py

import functools
import importlib
import logging
import os
import tempfile

import torch

from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend

log = logging.getLogger(__name__)


@register_backend
@fake_tensor_unsupported
def tvm(gm, example_inputs, *, scheduler=None, trials=20000):
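    # TorchDynamo backend: lowers the captured FX graph to TVM Relay, optionally
    # tunes it with the selected scheduler ("default", "auto_scheduler", or
    # "meta_schedule"), and returns a callable that runs the compiled module
    # through TVM's graph executor.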
    import tvm  # type: ignore[import]
    from tvm import relay  # type: ignore[import]
    from tvm.contrib import graph_executor  # type: ignore[import]

    # Trace the graph to TorchScript and import it into Relay. Inputs are named
    # inp_0, inp_1, ... so they can be bound by name in exec_tvm below.
    jit_mod = torch.jit.trace(gm, example_inputs)
    device = device_from_inputs(example_inputs)
    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
    if device.type == "cuda":
        dev = tvm.cuda(device.index)
        target = tvm.target.cuda()
    else:
        dev = tvm.cpu(0)
        target = tvm.target.Target(llvm_target())

    if scheduler is None:
        scheduler = os.environ.get("TVM_SCHEDULER", None)

    if scheduler == "auto_scheduler":
        from tvm import auto_scheduler

        # Tuning records are written to (and then replayed from) a scratch file;
        # delete=False keeps it around for ApplyHistoryBest below.
        log_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json").name

        tasks, task_weights = auto_scheduler.extract_tasks(
            mod["main"], params, target
        )
        if len(tasks) != 0:
            for task in tasks:
                print(task.compute_dag)
            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
            assert trials > 0
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=trials,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
                early_stopping=2000,
            )
            try:
                tuner.tune(tune_option)
            except Exception:
                if os.path.exists(log_file):
                    os.unlink(log_file)
                raise
        else:
            print("No tasks")

        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(
                opt_level=3, config={"relay.backend.use_auto_scheduler": True}
            ):
                lib = relay.build(mod, target=target, params=params)
    elif scheduler == "meta_schedule":
        from tvm import meta_schedule as ms

        with tempfile.TemporaryDirectory() as work_dir:
            if device.type != "cuda":
                # meta_schedule needs the number of cores to be specified;
                # use the maximum physical core count here.
                target = tvm.target.Target(
                    f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
                )
            # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
            # once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
            database = ms.relay_integration.tune_relay(
                mod=mod,
                target=target,
                work_dir=work_dir,
                max_trials_global=20000,
                num_trials_per_iter=64,
                params=params,
                strategy="evolutionary",
            )
            lib = ms.relay_integration.compile_relay(
                database=database,
                mod=mod,
                target=target,
                params=params,
            )
    elif scheduler == "default" or not scheduler:
        # no autotuning
        with tvm.transform.PassContext(opt_level=10):
            lib = relay.build(mod, target=target, params=params)
    else:
        raise NotImplementedError(
            "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
            "There are three available options: default, auto_scheduler and meta_schedule."
        )
    # lib["default"](dev) instantiates the compiled factory module on the target
    # device; GraphModule wraps it with set_input/run/get_output helpers.
    m = graph_executor.GraphModule(lib["default"](dev))

    def to_torch_tensor(nd_tensor):
        """A helper function to transfer an NDArray to a torch.Tensor."""
        if nd_tensor.dtype == "bool":
            # DLPack does not support boolean tensors, so they can't be handled by
            # torch.utils.dlpack.from_dlpack. Work around this by going through
            # numpy, although that brings additional data-copy overhead.
            return torch.from_numpy(nd_tensor.numpy())
        return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

    def to_tvm_tensor(torch_tensor):
        """A helper function to transfer a torch.Tensor to an NDArray."""
        if torch_tensor.dtype == torch.bool:
            # Same reason as above: fall back to numpy conversion, which
            # may introduce data-copy overhead.
            return tvm.nd.array(torch_tensor.cpu().numpy())
        return tvm.nd.from_dlpack(torch_tensor)

    def exec_tvm(*i_args):
        args = [a.contiguous() for a in i_args]
        for idx, arg in enumerate(args, 0):
            if arg.dim() != 0:
                if arg.requires_grad:
                    arg = arg.detach()
                m.set_input(
                    f"inp_{idx}",
                    to_tvm_tensor(arg),
                )
        m.run()
        return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]

    return exec_tvm


tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
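# Either partial above can be passed directly wherever a backend callable is
# accepted, e.g. torch.compile(model, backend=tvm_meta_schedule).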


def has_tvm():
    try:
        importlib.import_module("tvm")
        return True
    except ImportError:
        return False


@functools.lru_cache(None)
def llvm_target():
    # Pick an LLVM target based on host CPU features; this reads /proc/cpuinfo,
    # so it assumes a Linux host.
    if "avx512" in open("/proc/cpuinfo").read():
        return "llvm -mcpu=skylake-avx512"
    return "llvm -mcpu=core-avx2"
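

if __name__ == "__main__":
    # Minimal usage sketch, not part of the backend itself: assumes TVM is
    # installed and that this module is run with `python -m` so the relative
    # imports above resolve. "tvm" is the backend name registered by the
    # @register_backend decorator.
    example = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU())
    compiled = torch.compile(example, backend="tvm")
    print(compiled(torch.randn(8, 16)).shape)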