- import contextlib
- import warnings
- import weakref
- from typing import ContextManager, Optional
- import torch
- from torch._guards import Source
- from torch.multiprocessing.reductions import StorageWeakRef
- from torch.utils.weak import WeakIdRef
- def safe_is_leaf(t):
- try:
- return t.is_leaf
- except RuntimeError:
- # inference mode can trigger this
- return False
- def safe_grad(t):
- with warnings.catch_warnings():
- warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
- return t.grad
- def assert_eq(a, b):
- assert a == b, f"{a} != {b}"
- def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False):
- def go(m1, m2):
- assert_eq(m1.dtype, m2.dtype)
- if not skip_symbolic:
- assert_eq(m1.shape, m2.shape)
- assert_eq(m1.requires_grad, m2.requires_grad)
- assert_eq(m1.is_leaf, m2.is_leaf)
- assert_eq(m1.grad_fn is None, m2.grad_fn is None)
- assert_eq(m1.is_sparse, m2.is_sparse)
- assert_eq(m1.is_inference(), m2.is_inference())
- assert_eq(m1.is_conj(), m2.is_conj())
- assert_eq(m1.is_neg(), m2.is_neg())
- assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None)
- if safe_grad(m1) is not None:
- go(safe_grad(m1), safe_grad(m2))
- if m1.is_sparse:
- assert_eq(m1.dense_dim(), m2.dense_dim())
- assert_eq(m1.sparse_dim(), m2.sparse_dim())
- assert_eq(m1.is_coalesced(), m2.is_coalesced())
- else:
- if not skip_symbolic:
- assert_eq(m1.stride(), m2.stride())
- assert_eq(m1.storage_offset(), m2.storage_offset())
- assert_eq(m1._is_view(), m2._is_view())
- if m1._is_view():
- go(m1._base, m2._base)
- # TODO: test if is resizable (no direct query for this atm)
- # TODO: audit AutogradMeta to see if it matches
- # TODO: test forward AD
- return go(m1, m2)
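- # Illustrative usage (a sketch, not part of the original module): this helper
- # is used later in meta_tensor to compare a real tensor against its freshly
- # built meta copy while skipping sizes/strides/storage_offset, e.g.:
- #
- #   t = torch.randn(3, 2)
- #   m = torch.empty(3, 2, device="meta")
- #   assert_metadata_eq(assert_eq, t, m, skip_symbolic=True)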
- # This is a class for converting multiple tensors into meta tensors which
- # share the same view/storage structure. The operation model is that you allocate
- # one of these, and then call it repeatedly on all the tensors you want to
- # convert. It's important to use the same object for tensors you want to
- # share storage because this is how we correlate shared storages to the same
- # meta storages. This class will hold weak references to cached tensors
- # and tensor storages.
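- # Rough usage sketch (hypothetical driver code, not from this module): reuse a
- # single converter so that real tensors which alias each other produce meta
- # tensors that alias each other too:
- #
- #   converter = MetaConverter()
- #   base = torch.randn(4)
- #   view = base[:2]
- #   meta_base = converter(base)
- #   meta_view = converter(view)
- #   assert meta_view._base is meta_base  # view structure is reproduced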
- class MetaConverter:
- def __init__(self):
- self.storage_memo = {}
- self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
- self.maybe_storages_to_delete = []
- self.check_expired_frequency = 128
- self.check_expired_count = 0
- self.hit = 0
- self.miss = 0
- self.del_hook = None
- self.arg_cnt = 0
- def successful(self):
- return self.hit > 0 and self.miss == 0
- def check_for_expired_weak_storages(self):
- new_li = []
- stor_to_delete = []
- for obj in self.maybe_storages_to_delete:
- if not obj.expired():
- new_li.append(obj)
- else:
- stor_to_delete.append(obj)
- for obj in stor_to_delete:
- self.storage_memo.pop(obj, None)
- self.maybe_storages_to_delete = new_li
- # If for some reason we have acquired many storages which have not expired,
- # even though a tensor with their storage has expired (aliasing or otherwise),
- # check for expired storages less often so as to bound the amount of work we
- # do checking for expired storages
- self.check_expired_frequency = max(
- self.check_expired_frequency, len(self.maybe_storages_to_delete)
- )
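- # e.g. if 500 unexpired candidate storages have piled up, the check frequency
- # rises to 500, so the periodic scan stays roughly amortized-constant per
- # conversion instead of growing with the backlog.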
- def get_tensor_memo(self, t):
- return self.tensor_memo.get(WeakIdRef(t), None)
- def set_tensor_memo(self, t, v):
- # hold a weak ref to self, otherwise it will be kept alive
- # by the del_ten closure
- self_weak_ref = weakref.ref(self)
- if t.is_sparse or t.is_mkldnn:
- weak_st = None
- else:
- weak_st = StorageWeakRef(t._typed_storage())
- tensor_ref_key = WeakIdRef(t)
- def del_ten():
- # tensor outlives the converter
- self_ref = self_weak_ref()
- if self_ref is None:
- return
- # on shutdown, tensor_ref_key may not be in memo
- self_ref.tensor_memo.pop(tensor_ref_key, None)
- if weak_st and weak_st.expired():
- self_ref.storage_memo.pop(weak_st, None)
- elif weak_st is not None:
- # [expired-storages]
- # NB: even though the tensor has died,
- # the deallocation of its storage can take longer,
- # even when the storage has no other uses/views.
- # In this case, the StorageWeakRef object will be kept alive
- # longer than it needs to be; however, the storage itself
- # will be deallocated. We retain the possibly dead storages
- # and periodically check if any of them are expired and
- # can be freed.
- self_ref.maybe_storages_to_delete.append(weak_st)
- weakref.finalize(t, del_ten)
- self.tensor_memo[tensor_ref_key] = v
- # NB: doesn't actually return a storage, because meta storage is
- # not supported
- def meta_storage(self, s, callback):
- # NB: TypedStorage is freshly allocated and cannot be used as a hash
- # key.
- # Use a weak ref to s in order to not leak memory
- swr = StorageWeakRef(s)
- if swr not in self.storage_memo:
- self.storage_memo[swr] = callback(
- lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
- ).untyped_storage()
- return self.storage_memo[swr]
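- # Memoization contract, illustrated (sketch, not from the original module):
- # repeated calls with the same real storage hand back the same meta storage,
- # which is what lets converted tensors alias correctly:
- #
- #   s = torch.randn(4).untyped_storage()
- #   m1 = self.meta_storage(s, callback=lambda t: t())
- #   m2 = self.meta_storage(s, callback=lambda t: t())
- #   assert m1 is m2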
- # This function assumes that it's possible to do the conversion
- # NB: name here is used in a conventional way by Dynamo; it corresponds
- # precisely to the Source.name() of the tensor we're fakeifying and
- # corresponds to a valid Python expression. When we construct sub-names
- # as part of this process, we will maintain this invariant! (Even though
- # other users of this may not need this property to be upheld.)
- def meta_tensor(
- self, t, shape_env=None, callback=lambda t: t(), source: Optional[Source] = None
- ):
- if source is None:
- from torch._dynamo.source import ConstantSource
- # TODO: make a dedicated UnknownSource for this?
- source = ConstantSource(f"__unknown_tensor{len(self.tensor_memo)}")
- # This indicates you set no_dispatch() before calling into this
- # function. This is an error: we may be creating fake tensors and
- # will perform operations on them which need fake tensor mode to
- # be active. You will segfault if you are in a no_dispatch() block.
- assert not torch._C._dispatch_tls_local_exclude_set().has(
- torch._C.DispatchKey.Python
- )
- arg_cnt = self.arg_cnt
- self.arg_cnt += 1
- # When we make as_strided calls, we end up generating a guard
- # that the new as_strided tensor is in bounds for the old storage
- # for the base (since as_strided calls can "bust" out of their
- # bounding box.) This guard is unnecessary: if a user is able
- # to provide us a tensor with the view base setup this way, we
- # don't need to produce a guard, because the fact that they
- # were able to produce the view base means it's in bounds.
- #
- # Now, ordinarily, this guard would be harmless. However, the
- # generated guard refers to variables bound on the base variable.
- # At the moment, Dynamo doesn't actually guard on x._base, because
- # according to Voz this results in a lot of spurious invalidations,
- # and also if the user doesn't directly make use of _base, it's
- # pointless anyway (because programs should be parametric over
- # whether or not the input tensor is a view or not--unless you're
- # mutating the input, but that's a whole 'nother ballgame). So
- # for expediency, we suppress these guards so we don't have to
- # deal with this (yet, anyway.)
- #
- # NB: An old version of this code suppressed guards for ALL operations
- # happening during meta conversion, not just as_strided calls.
- # This is too aggressive: we do duck sizing and 0/1 simplification
- # as we allocate variables, and we do need to register guards for
- # these cases.
- maybe_suppress = contextlib.nullcontext
- if shape_env is not None:
- maybe_suppress = shape_env.suppress_guards
- make_symbolic = shape_env is not None
- def sym_sizes_strides_storage_offset(t):
- if make_symbolic:
- return shape_env.create_symbolic_sizes_strides_storage_offset(t, source)
- return (t.size(), t.stride(), t.storage_offset())
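- # e.g. a contiguous (2, 3) tensor yields ((2, 3), (3, 1), 0) when shape_env
- # is None, and SymInt-backed sizes/strides/offset (with guards recorded
- # against `source`) when a shape_env is provided.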
- # see expired-storages
- self.check_expired_count += 1
- if self.check_expired_count >= self.check_expired_frequency:
- self.check_for_expired_weak_storages()
- self.check_expired_count = 0
- if self.get_tensor_memo(t) is None:
- with torch.inference_mode(t.is_inference()):
- if t.is_sparse:
- assert shape_env is None, "symbolic on sparse NYI"
- is_leaf = safe_is_leaf(t)
- r = callback(
- lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
- t.sparse_dim(),
- t.dense_dim(),
- t.shape,
- dtype=t.dtype,
- layout=torch.sparse_coo,
- device="meta",
- )
- )
- assert safe_is_leaf(r), "the callback you passed in doesn't detach"
- # Note [is_coalesced is dispatched]
- # Strangely enough, is_coalesced() is a dispatched operator,
- # which means that it will get caught by fake tensor mode.
- # Ordinarily this would error, but there's some logic in
- # fake tensor to ensure this doesn't happen.
- r._coalesced_(t.is_coalesced())
- if t.requires_grad:
- r.requires_grad = True
- if t.requires_grad and not is_leaf:
- with torch.enable_grad():
- r = r.clone()
- r._coalesced_(t.is_coalesced())
- elif t.is_mkldnn:
- is_leaf = safe_is_leaf(t)
- sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
- t
- )
- r = callback(
- lambda: torch.empty_strided(
- sizes, strides, dtype=t.dtype, device="meta"
- )
- )
- assert safe_is_leaf(r), "the callback you passed in doesn't detach"
- if t.requires_grad:
- r.requires_grad = True
- if t.requires_grad and not is_leaf:
- with torch.enable_grad():
- r = r.clone()
- elif t._is_view():
- # Construct views in two steps: recursively meta-fy their
- # base, and then create view(s) off that. NB: doing it
- # directly from storage is WRONG because this won't cause
- # version counters to get shared.
- assert t._is_view()
- from torch._dynamo.source import AttrSource
- base = self.meta_tensor(
- t._base, shape_env, callback, source=AttrSource(source, "_base")
- )
- def is_c_of_r(complex_dtype, real_dtype):
- return (
- utils.is_complex_dtype(complex_dtype)
- and utils.corresponding_real_dtype(complex_dtype)
- == real_dtype
- )
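- # e.g. is_c_of_r(torch.complex64, torch.float32) is True: view_as_real()
- # turns a complex64 base into a float32 view, and view_as_complex() goes the
- # other way, so both directions of dtype-changing views are handled below.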
- # In some situations, MetaConverter may be called in a
- # context where autograd is disabled. For the _is_view
- # assert to pass, we have to setup the autograd view
- # metadata anyway. Do this by reenabling the
- # ADInplaceOrView key. This is kind of a hack.
- old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
- torch._C.DispatchKey.ADInplaceOrView
- )
- torch._C._dispatch_tls_set_dispatch_key_excluded(
- torch._C.DispatchKey.ADInplaceOrView, False
- )
- try:
- if base.dtype == t.dtype:
- pass
- elif is_c_of_r(base.dtype, t.dtype):
- base = torch.view_as_real(base)
- elif is_c_of_r(t.dtype, base.dtype):
- base = torch.view_as_complex(base)
- else:
- # This is not guaranteed to succeed. If it fails, it
- # means there is another dtype-converting view function
- # that hasn't been handled here
- base = base.view(t.dtype)
- # This is very tricky. Naively, you might expect this
- # to hold:
- #
- # if t.requires_grad and not safe_is_leaf(t)
- # assert t._base.requires_grad
- #
- # But it's not true! As you can see in the following
- # program:
- #
- # x = torch.zeros(4)
- # y = x.view(1, 4)
- # y.requires_grad = True
- # z = y.view(1, 1, 4)
- # assert z._base is x
- #
- # So we may have to do *two* views out of the base to
- # recreate this situation.
- (
- sizes,
- strides,
- storage_offset,
- ) = sym_sizes_strides_storage_offset(t)
- if safe_is_leaf(t):
- # Leaf views that track view metadata are created by
- # creating a view inside a no_grad block
- with torch.no_grad(), maybe_suppress():
- r = base.as_strided(sizes, strides, storage_offset)
- # As it's a leaf, we can directly assign requires_grad
- r.requires_grad = t.requires_grad
- else:
- if t._base.requires_grad == t.requires_grad:
- # Easy case, just run the view op
- with torch.enable_grad(), maybe_suppress():
- r = base.as_strided(sizes, strides, storage_offset)
- else:
- # Obscure case. Create a leaf view and give it the
- # correct requires_grad, then do the final view.
- # NB: Can't have a non-leaf without requiring grad!
- assert t.requires_grad
- with torch.no_grad():
- mid = base.view(base.shape)
- mid.requires_grad = t.requires_grad
- with torch.enable_grad(), maybe_suppress():
- r = mid.as_strided(sizes, strides, storage_offset)
- finally:
- torch._C._dispatch_tls_set_dispatch_key_excluded(
- torch._C.DispatchKey.ADInplaceOrView, old_exclude
- )
- else:
- is_leaf = safe_is_leaf(t)
- sizes, strides, storage_offset = sym_sizes_strides_storage_offset(t)
- r = callback(
- lambda: torch.empty_strided(
- sizes, strides, dtype=t.dtype, device="meta"
- )
- )
- assert safe_is_leaf(r), "the callback you passed in doesn't detach"
- if t.requires_grad:
- r.requires_grad = t.requires_grad
- if not is_leaf:
- # Fake up some autograd history.
- with torch.enable_grad():
- # preserve_format is the default, but we want to
- # emphasize how important it is to preserve
- # format here
- r = r.clone(memory_format=torch.preserve_format)
- s = t.untyped_storage()
- swr = StorageWeakRef(s)
- if (
- swr not in self.storage_memo
- and r.stride() == strides
- and r.storage_offset() == storage_offset
- ):
- # You're normal and happy, install the fresh storage into the memo
- self.storage_memo[swr] = r.untyped_storage()
- else:
- # You're in crazy town; somehow you gave us a tensor
- # that wasn't a view, but had nonzero storage offset,
- # nontrivial strides (such that clone() couldn't
- # preserve them), or already aliases with another
- # tensor's storage. The most typical way to end
- # up here is with set_. So use set_ to bludgeon this
- # in.
- r_s = self.meta_storage(s, callback=callback)
- # NB: In principle, this should always work, but there
- # is some subtle difference in the autograd metadata
- # that means we will backprop the set_ call, even if
- # r is declared as an input to grad.
- # See https://github.com/pytorch/pytorch/issues/87956
- # for the reproducer.
- # NB: The in_kernel_invocation_manager here is necessary
- # for fake tensor. If we run the set_ call with fake
- # tensor on, r will improperly report that it is NOT a
- # meta tensor but a cpu tensor, and then the set_ call
- # will fail due to device mismatch. no_dispatch() is
- # not enough, because the fake tensor will still claim
- # to be a CPU tensor and you'll end up in the CPU
- # kernel. Arguably this is a hack; a cleaner way to
- # solve this is to have a FakeStorage concept which
- # would report its CPU device--no problem now! But
- # this is difficult to do because we don't have storage
- # subclasses. Relevant test is
- # DynamicShapesFunctionTests::test_add_dynamic_shapes in
- # test/dynamo/test_dynamic_shapes.py
- maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
- from torch._subclasses.fake_tensor import (
- FakeTensor,
- in_kernel_invocation_manager,
- )
- if isinstance(r, FakeTensor):
- maybe_fake_mgr = in_kernel_invocation_manager(r.fake_mode)
- with maybe_fake_mgr, torch.no_grad():
- r.set_(r_s, storage_offset, sizes, strides)
- if safe_grad(t) is not None:
- from torch._dynamo.source import AttrSource
- r.grad = self.meta_tensor(
- safe_grad(t),
- shape_env,
- callback,
- source=AttrSource(source, "grad"),
- )
- torch._C._set_conj(r, t.is_conj())
- torch._C._set_neg(r, t.is_neg())
- # This can be skipped if necessary for performance reasons
- assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
- self.set_tensor_memo(t, r)
- return self.get_tensor_memo(t)
- def __call__(
- self,
- t,
- shape_env=None,
- *,
- callback=lambda t: t(),
- ignore_subclass=False,
- source=None,
- ):
- # TODO: zero tensors? We appear to have eliminated them by
- # excluding complex for now
- from torch._subclasses.fake_tensor import FakeTensor
- if (
- type(t) is torch.Tensor
- or type(t) is torch.nn.Parameter
- or (ignore_subclass and isinstance(t, torch.Tensor))
- or isinstance(t, FakeTensor)
- ):
- if any(
- [
- t.is_sparse_csr,
- t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
- t.is_quantized,
- t.is_nested,
- t._is_view() and t._base is not None and t._base.is_sparse,
- torch._is_functional_tensor(t),
- # these are supported in meta conversion but the fallbacks
- # don't work
- t.is_neg(),
- t.is_conj(),
- t.device.type in ("lazy",),
- # We need a way to test if a tensor is batched but there
- # is no official API to do it
- # torch._C._is_batched(t),
- ]
- ):
- # TODO: sparse should support meta
- # NB: technically to('meta') does work but our logging
- # instrumentation will see the meta conversions and the
- # tests all break so we just exclude this. In any case
- # the to conversion isn't really right anyhow.
- self.miss += 1
- return NotImplemented
- else:
- self.hit += 1
- # When ignoring subclasses, we treat the input tensor "as if" it
- # were a normal tensor and create a non-subclassed fake tensor
- # that, modulo type and attributes, resembles the original tensor.
- # This can be helpful if you're planning to simulate the subclassness
- # by hand, e.g., as is done in Dynamo
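- # For instance (hypothetical subclass tensor): converter(subclass_t,
- # ignore_subclass=True) yields a plain meta/fake tensor rather than an
- # instance of the subclass, and the caller re-applies the subclass
- # wrapping itself.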
- ctx = contextlib.nullcontext()
- if ignore_subclass:
- ctx = torch._C.DisableTorchFunctionSubclass()
- with ctx:
- r = self.meta_tensor(
- t, shape_env=shape_env, callback=callback, source=source
- )
- # TODO: this is suspicious, now that we have callback argument
- if type(t) is torch.nn.Parameter:
- r = torch.nn.Parameter(r, requires_grad=r.requires_grad)
- return r
- elif torch.overrides.is_tensor_like(t):
- # Blindly converting tensor subclasses to meta can cause
- # unpredictable problems; e.g., FX tests will trace meta
- # tensors into their trace / some subclasses don't correctly
- # support meta. Trying to YOLO this is more trouble than it's
- # worth.
- self.miss += 1
- return NotImplemented
- else:
- # non-Tensor types don't count as hit or miss
- return t
- # NB: this import is kept at the bottom of the module, presumably to avoid
- # an import cycle with torch._prims_common.
- import torch._prims_common as utils
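- # Rough sketch of the miss path (hypothetical caller code): unsupported inputs
- # such as quantized or nested tensors come back as NotImplemented and mark the
- # conversion as unsuccessful, so callers are expected to check for it:
- #
- #   converter = MetaConverter()
- #   r = converter(torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]))
- #   assert r is NotImplemented
- #   assert not converter.successful()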