# torchxla.py

  1. import logging
  2. from ..backends.common import aot_autograd
  3. from ..backends.registry import register_experimental_backend as register_backend
  4. log = logging.getLogger(__name__)
  5. @register_backend
  6. def torchxla_trivial(gm, fake_tensor_inputs):
  7. return gm
  8. @register_backend
  9. def torchxla_trace_once(model, fake_tensor_inputs):
  10. import torch_xla.core.dynamo_bridge as bridge # type: ignore[import]
  11. compiled_graph = None
  12. def fwd(*args):
  13. nonlocal model
  14. nonlocal compiled_graph
  15. if compiled_graph is None:
  16. compiled_graph = bridge.extract_compiled_graph(model, args)
  17. del model
  18. return compiled_graph(*args)
  19. return fwd
  20. aot_torchxla_trivial = aot_autograd(
  21. fw_compiler=torchxla_trivial,
  22. )
  23. register_backend(name="aot_torchxla_trivial", compiler_fn=aot_torchxla_trivial)
  24. aot_torchxla_trace_once = aot_autograd(
  25. fw_compiler=torchxla_trace_once,
  26. )
  27. register_backend(name="aot_torchxla_trace_once", compiler_fn=aot_torchxla_trace_once)