import threading

import torch._C._lazy
from torch.utils._pytree import tree_flatten, tree_unflatten

from .closure import add_step_closure, run_step_closures


def mark_step(device: str = "", wait=False):
    """Triggers a mark step, which amounts to

    - collecting a group of 'live' lazy tensors to index into the compilation cache
      (lowering/compiling their IR graphs if not cached)
    - kicking off execution of the compiled function
    - (optionally, wait=True) waiting for cpu-side execution to complete
      (does not sync the accelerator)
    """
    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
    torch._C._lazy._mark_step(device, [], wait=wait)

    run_step_closures()


def wait_device_ops(devices=None):
    """Waits for all the async operations on the given devices to complete.

    Args:
      devices (string..., optional): The devices whose async ops need to be
        waited for. If empty, all the local devices will be waited for.
    """
    if devices is None:
        devices = []
    torch._C._lazy._wait_device_ops(devices=devices)


def sync_multi(tensors, devices):
    """
    Sync the list of lazy tensors so their IR gets lowered for the active backend
    and the compiled computation graph gets cached.
    """
    torch._C._lazy._sync_multi(tensors, devices)


def get_tensor_id(tensor):
    """Return a unique id of the lazy tensor maintained by LTC."""
    return torch._C._lazy._get_tensor_id(tensor)


def to_cpu(tensors, devices=None):
    """Sync the given (possibly nested) lazy tensors and return CPU copies,
    preserving the original container structure."""
    devices = devices or ["lazy"]

    flattened, spec = tree_flatten(tensors)
    sync_multi(flattened, devices)
    return tree_unflatten([t.to("cpu") for t in flattened], spec)


def save(tensors, *args, **kwargs):
    """Materialize the lazy tensors on CPU and save them via torch.save."""
    torch.save(to_cpu(tensors), *args, **kwargs)
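

# A minimal end-to-end usage sketch (illustrative only; assumes this module is
# importable as torch._lazy and that the TorchScript lazy backend is available
# and registered via torch._lazy.ts_backend.init()):
#
#     import torch
#     import torch._lazy
#     import torch._lazy.ts_backend
#
#     torch._lazy.ts_backend.init()           # register the "lazy" device backend
#     x = torch.randn(2, 2, device="lazy")    # ops on lazy tensors only build IR
#     y = x @ x + 1
#     torch._lazy.mark_step()                 # lower/compile the graph and launch it
#     print(torch._lazy.to_cpu(y))            # sync and materialize the result on CPU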