# meta_utils.py

import contextlib
import warnings
import weakref
from typing import ContextManager, Optional

import torch
from torch._guards import Source
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils.weak import WeakIdRef


def safe_is_leaf(t):
    try:
        return t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False


def safe_grad(t):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        return t.grad


def assert_eq(a, b):
    assert a == b, f"{a} != {b}"


def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False):
    def go(m1, m2):
        assert_eq(m1.dtype, m2.dtype)
        if not skip_symbolic:
            assert_eq(m1.shape, m2.shape)
        assert_eq(m1.requires_grad, m2.requires_grad)
        assert_eq(m1.is_leaf, m2.is_leaf)
        assert_eq(m1.grad_fn is None, m2.grad_fn is None)
        assert_eq(m1.is_sparse, m2.is_sparse)
        assert_eq(m1.is_inference(), m2.is_inference())
        assert_eq(m1.is_conj(), m2.is_conj())
        assert_eq(m1.is_neg(), m2.is_neg())
        assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None)
        if safe_grad(m1) is not None:
            go(safe_grad(m1), safe_grad(m2))
        if m1.is_sparse:
            assert_eq(m1.dense_dim(), m2.dense_dim())
            assert_eq(m1.sparse_dim(), m2.sparse_dim())
            assert_eq(m1.is_coalesced(), m2.is_coalesced())
        else:
            if not skip_symbolic:
                assert_eq(m1.stride(), m2.stride())
                assert_eq(m1.storage_offset(), m2.storage_offset())
            assert_eq(m1._is_view(), m2._is_view())
            if m1._is_view():
                go(m1._base, m2._base)
        # TODO: test if is resizable (no direct query for this atm)
        # TODO: audit AutogradMeta to see if it matches
        # TODO: test forward AD

    return go(m1, m2)
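

# Hypothetical usage sketch (not part of the original file): compare a CPU
# tensor against a meta-device twin built with matching metadata. The helper
# name _example_assert_metadata_eq is illustrative only.
def _example_assert_metadata_eq():
    x = torch.randn(2, 3, requires_grad=True)
    y = torch.empty_strided(
        x.size(), x.stride(), dtype=x.dtype, device="meta", requires_grad=True
    )
    # Raises AssertionError on the first attribute that differs.
    assert_metadata_eq(assert_eq, x, y)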


# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure. The operation model is you allocate
# one of these, and then call it repeatedly on all the tensors you want to
# convert. It's important to use the same object for tensors you want to
# share storage because this is how we correlate shared storages to the same
# meta storages. This class will hold weak references to cached tensors
# and tensor storages. (See the usage sketch at the bottom of this file.)
class MetaConverter:
    def __init__(self):
        self.storage_memo = {}
        self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        self.maybe_storages_to_delete = []
        self.check_expired_frequency = 128
        self.check_expired_count = 0
        self.hit = 0
        self.miss = 0
        self.del_hook = None
        self.arg_cnt = 0
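
    # A conversion run counts as successful only if at least one tensor was
    # converted (hit > 0) and nothing fell back to NotImplemented (miss == 0);
    # see how hit/miss are updated in __call__ below.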
    def successful(self):
        return self.hit > 0 and self.miss == 0

    def check_for_expired_weak_storages(self):
        new_li = []
        stor_to_delete = []
        for obj in self.maybe_storages_to_delete:
            if not obj.expired():
                new_li.append(obj)
            else:
                stor_to_delete.append(obj)
        for obj in stor_to_delete:
            self.storage_memo.pop(obj, None)
        self.maybe_storages_to_delete = new_li

        # If for some reason we have acquired many storages which have not
        # expired, even though a tensor with their storage has expired
        # (aliasing or otherwise), check for expired storages less often so
        # as to bound the amount of work we do checking for expired storages.
        self.check_expired_frequency = max(
            self.check_expired_frequency, len(self.maybe_storages_to_delete)
        )

    def get_tensor_memo(self, t):
        return self.tensor_memo.get(WeakIdRef(t), None)
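
    # NB: WeakIdRef keys the memos by object identity while the tensor is
    # alive, so two distinct tensors with equal values still map to distinct
    # meta tensors.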
    def set_tensor_memo(self, t, v):
        # hold a weak ref to self, otherwise it will be kept alive
        # by the del_ten closure
        self_weak_ref = weakref.ref(self)
        if t.is_sparse or t.is_mkldnn:
            weak_st = None
        else:
            weak_st = StorageWeakRef(t._typed_storage())
        tensor_ref_key = WeakIdRef(t)

        def del_ten():
            # tensor outlives the converter
            self_ref = self_weak_ref()
            if self_ref is None:
                return
            # on shutdown, tensor_ref_key may not be in memo
            self_ref.tensor_memo.pop(tensor_ref_key, None)
            if weak_st and weak_st.expired():
                self_ref.storage_memo.pop(weak_st, None)
            elif weak_st is not None:
                # [expired-storages]
                # NB: even though the tensor has died,
                # the deallocation of its storage can take longer,
                # even when the storage has no other uses/views.
                # In this case, the StorageWeakRef object will be kept alive
                # longer than it needs to be, however the storage itself
                # will be deallocated. We retain the possibly dead storages
                # and periodically check if any of them are expired and
                # can be freed.
                self_ref.maybe_storages_to_delete.append(weak_st)

        weakref.finalize(t, del_ten)
        self.tensor_memo[tensor_ref_key] = v

    # NB: doesn't actually return a storage, because meta storage is
    # not supported
    def meta_storage(self, s, callback):
        # NB: TypedStorage is freshly allocated and cannot be used as a hash
        # key index.
        # Use a weak ref to s in order to not leak memory
        swr = StorageWeakRef(s)
        if swr not in self.storage_memo:
            self.storage_memo[swr] = callback(
                lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
            ).untyped_storage()
        return self.storage_memo[swr]

    # This function assumes that it's possible to do the conversion
    # NB: name here is used in a conventional way by Dynamo; it corresponds
    # precisely to the Source.name() of the tensor we're fakeifying and
    # corresponds to a valid Python expression. When we construct sub-names
    # as part of this process, we will maintain this invariant! (Even though
    # other users of this may not need this property to be upheld.)
    def meta_tensor(
        self, t, shape_env=None, callback=lambda t: t(), source: Optional[Source] = None
    ):
        if source is None:
            from torch._dynamo.source import ConstantSource

            # TODO: make a dedicated UnknownSource for this?
            source = ConstantSource(f"__unknown_tensor{len(self.tensor_memo)}")
        # This indicates you set no_dispatch() before calling into this
        # function. This is an error: we may be creating fake tensors and
        # will perform operations on them which need fake tensor mode to
        # be active. You will segfault if you are in a no_dispatch() block.
        assert not torch._C._dispatch_tls_local_exclude_set().has(
            torch._C.DispatchKey.Python
        )
        arg_cnt = self.arg_cnt
        self.arg_cnt += 1

        # When we make as_strided calls, we end up generating a guard
        # that the new as_strided tensor is in bounds for the old storage
        # for the base (since as_strided calls can "bust" out of their
        # bounding box.) This guard is unnecessary: if a user is able
        # to provide us a tensor with the view base setup this way, we
        # don't need to produce a guard, because the fact that they
        # were able to produce the view base means it's in bounds.
        #
        # Now, ordinarily, this guard would be harmless. However, the
        # generated guard refers to variables bound on the base variable.
        # At the moment, Dynamo doesn't actually guard on x._base, because
        # according to Voz this results in a lot of spurious invalidations,
        # and also if the user doesn't directly make use of _base, it's
        # pointless anyway (because programs should be parametric over
        # whether or not the input tensor is a view or not--unless you're
        # mutating the input, but that's a whole 'nother ballgame). So
        # for expediency, we suppress these guards so we don't have to
        # deal with this (yet, anyway.)
        #
        # NB: An old version of this code suppressed guards for ALL operations
        # happening during meta conversion, not just as_strided calls.
        # This is too aggressive: we do duck sizing and 0/1 simplification
        # as we allocate variables, and we do need to register guards for
        # these cases.
        maybe_suppress = contextlib.nullcontext
        if shape_env is not None:
            maybe_suppress = shape_env.suppress_guards

        make_symbolic = shape_env is not None

        def sym_sizes_strides_storage_offset(t):
            if make_symbolic:
                return shape_env.create_symbolic_sizes_strides_storage_offset(t, source)
            return (t.size(), t.stride(), t.storage_offset())

        # see [expired-storages]
        self.check_expired_count += 1
        if self.check_expired_count >= self.check_expired_frequency:
            self.check_for_expired_weak_storages()
            self.check_expired_count = 0
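
        # Only do the conversion if we haven't already converted this exact
        # tensor object; otherwise fall through to the memoized result below.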
        if self.get_tensor_memo(t) is None:
            with torch.inference_mode(t.is_inference()):
                if t.is_sparse:
                    assert shape_env is None, "symbolic on sparse NYI"
                    is_leaf = safe_is_leaf(t)
                    r = callback(
                        lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                            t.sparse_dim(),
                            t.dense_dim(),
                            t.shape,
                            dtype=t.dtype,
                            layout=torch.sparse_coo,
                            device="meta",
                        )
                    )
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    # Note [is_coalesced is dispatched]
                    # Strangely enough, is_coalesced() is a dispatched operator,
                    # which means that it will get caught by fake tensor mode.
                    # Ordinarily this would error, but there's some logic in
                    # fake tensor to ensure this doesn't happen.
                    r._coalesced_(t.is_coalesced())
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        with torch.enable_grad():
                            r = r.clone()
                            r._coalesced_(t.is_coalesced())
                elif t.is_mkldnn:
                    is_leaf = safe_is_leaf(t)
                    sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
                        t
                    )
                    r = callback(
                        lambda: torch.empty_strided(
                            sizes, strides, dtype=t.dtype, device="meta"
                        )
                    )
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        with torch.enable_grad():
                            r = r.clone()
                elif t._is_view():
                    # Construct views in two steps: recursively meta-fy their
                    # base, and then create view(s) off that. NB: doing it
                    # directly from storage is WRONG because this won't cause
                    # version counters to get shared.
                    assert t._is_view()

                    from torch._dynamo.source import AttrSource

                    base = self.meta_tensor(
                        t._base, shape_env, callback, source=AttrSource(source, "_base")
                    )

                    def is_c_of_r(complex_dtype, real_dtype):
                        return (
                            utils.is_complex_dtype(complex_dtype)
                            and utils.corresponding_real_dtype(complex_dtype)
                            == real_dtype
                        )

                    # In some situations, MetaConverter may be called in a
                    # context where autograd is disabled. For the _is_view
                    # assert to pass, we have to set up the autograd view
                    # metadata anyway. Do this by reenabling the
                    # ADInplaceOrView key. This is kind of a hack.
                    old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView
                    )
                    torch._C._dispatch_tls_set_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView, False
                    )
                    try:
                        if base.dtype == t.dtype:
                            pass
                        elif is_c_of_r(base.dtype, t.dtype):
                            base = torch.view_as_real(base)
                        elif is_c_of_r(t.dtype, base.dtype):
                            base = torch.view_as_complex(base)
                        else:
                            # This is not guaranteed to succeed. If it fails, it
                            # means there is another dtype-converting view function
                            # that hasn't been handled here
                            base = base.view(t.dtype)

                        # This is very tricky. Naively, you might expect this
                        # to hold:
                        #
                        #   if t.requires_grad and not safe_is_leaf(t):
                        #       assert t._base.requires_grad
                        #
                        # But it's not true! As you can see in the following
                        # program:
                        #
                        #   x = torch.zeros(4)
                        #   y = x.view(1, 4)
                        #   y.requires_grad = True
                        #   z = y.view(1, 1, 4)
                        #   assert z._base is x
                        #
                        # So we may have to do *two* views out of the base to
                        # recreate this situation.
                        (
                            sizes,
                            strides,
                            storage_offset,
                        ) = sym_sizes_strides_storage_offset(t)
                        if safe_is_leaf(t):
                            # Leaf views that track view metadata are created by
                            # creating a view inside a no_grad block
                            with torch.no_grad(), maybe_suppress():
                                r = base.as_strided(sizes, strides, storage_offset)
                            # As it's a leaf, we can directly assign requires_grad
                            r.requires_grad = t.requires_grad
                        else:
                            if t._base.requires_grad == t.requires_grad:
                                # Easy case, just run the view op
                                with torch.enable_grad(), maybe_suppress():
                                    r = base.as_strided(sizes, strides, storage_offset)
                            else:
                                # Obscure case. Create a leaf view and give it the
                                # correct requires_grad, then do the final view.
                                # NB: Can't have a non-leaf without requiring grad!
                                assert t.requires_grad
                                with torch.no_grad():
                                    mid = base.view(base.shape)
                                mid.requires_grad = t.requires_grad
                                with torch.enable_grad(), maybe_suppress():
                                    r = mid.as_strided(sizes, strides, storage_offset)
                    finally:
                        torch._C._dispatch_tls_set_dispatch_key_excluded(
                            torch._C.DispatchKey.ADInplaceOrView, old_exclude
                        )
                else:
                    is_leaf = safe_is_leaf(t)
                    sizes, strides, storage_offset = sym_sizes_strides_storage_offset(t)
                    r = callback(
                        lambda: torch.empty_strided(
                            sizes, strides, dtype=t.dtype, device="meta"
                        )
                    )
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = t.requires_grad
                        if not is_leaf:
                            # Fake up some autograd history.
                            with torch.enable_grad():
                                # preserve_format is the default, but we want to
                                # emphasize how important it is to preserve
                                # format here
                                r = r.clone(memory_format=torch.preserve_format)

                    s = t.untyped_storage()
                    swr = StorageWeakRef(s)
                    if (
                        swr not in self.storage_memo
                        and r.stride() == strides
                        and r.storage_offset() == storage_offset
                    ):
                        # You're normal and happy, install the fresh storage into the memo
                        self.storage_memo[swr] = r.untyped_storage()
                    else:
                        # You're in crazy town; somehow you gave us a tensor
                        # that wasn't a view, but had nonzero storage offset,
                        # nontrivial strides (such that clone() couldn't
                        # preserve them), or already aliases with another
                        # tensor's storage. The most typical way to end
                        # up here is with set_. So use set_ to bludgeon this
                        # in.
                        r_s = self.meta_storage(s, callback=callback)
                        # NB: In principle, this should always work, but there
                        # is some subtle difference in the autograd metadata
                        # that means we will backprop the set_ call, even if
                        # r is declared as an input to grad.
                        # See https://github.com/pytorch/pytorch/issues/87956
                        # for the reproducer.
                        # NB: The in_kernel_invocation_manager here is necessary
                        # for fake tensor. If we run the set_ call with fake
                        # tensor on, r will improperly report that it is NOT a
                        # meta tensor but a cpu tensor, and then the set_ call
                        # will fail due to device mismatch. no_dispatch() is
                        # not enough, because the fake tensor will still claim
                        # to be a CPU tensor and you'll end up in the CPU
                        # kernel. Arguably this is a hack; a cleaner way to
                        # solve this is to have a FakeStorage concept which
                        # would report its CPU device--no problem now! But
                        # this is difficult to do because we don't have storage
                        # subclasses. Relevant test is
                        # DynamicShapesFunctionTests::test_add_dynamic_shapes in
                        # test/dynamo/test_dynamic_shapes.py
                        maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
                        from torch._subclasses.fake_tensor import (
                            FakeTensor,
                            in_kernel_invocation_manager,
                        )

                        if isinstance(r, FakeTensor):
                            maybe_fake_mgr = in_kernel_invocation_manager(r.fake_mode)
                        with maybe_fake_mgr, torch.no_grad():
                            r.set_(r_s, storage_offset, sizes, strides)

                if safe_grad(t) is not None:
                    from torch._dynamo.source import AttrSource

                    r.grad = self.meta_tensor(
                        safe_grad(t),
                        shape_env,
                        callback,
                        source=AttrSource(source, "grad"),
                    )
                torch._C._set_conj(r, t.is_conj())
                torch._C._set_neg(r, t.is_neg())
            # This can be skipped if necessary for performance reasons
            assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
            self.set_tensor_memo(t, r)

        return self.get_tensor_memo(t)

    def __call__(
        self,
        t,
        shape_env=None,
        *,
        callback=lambda t: t(),
        ignore_subclass=False,
        source=None,
    ):
        # TODO: zero tensors? We appear to have eliminated them by
        # excluding complex for now
        from torch._subclasses.fake_tensor import FakeTensor

        if (
            type(t) is torch.Tensor
            or type(t) is torch.nn.Parameter
            or (ignore_subclass and isinstance(t, torch.Tensor))
            or isinstance(t, FakeTensor)
        ):
            if any(
                [
                    t.is_sparse_csr,
                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                    t.is_quantized,
                    t.is_nested,
                    t._is_view() and t._base is not None and t._base.is_sparse,
                    torch._is_functional_tensor(t),
                    # these are supported in meta conversion but the fallbacks
                    # don't work
                    t.is_neg(),
                    t.is_conj(),
                    t.device.type in ("lazy",),
                    # We need a way to test if a tensor is batched but there
                    # is no official API to do it
                    # torch._C._is_batched(t),
                ]
            ):
                # TODO: sparse should support meta
                # NB: technically to('meta') does work, but our logging
                # instrumentation will see the meta conversions and the
                # tests all break, so we just exclude this. In any case,
                # the to() conversion isn't really right anyhow.
                self.miss += 1
                return NotImplemented
            else:
                self.hit += 1
                # When ignoring subclasses, we treat the input tensor "as if" it
                # were a normal tensor and create a non-subclassed fake tensor
                # that, modulo type and attributes, resembles the original tensor.
                # This can be helpful if you're planning to simulate the subclassness
                # by hand, e.g., as is done in Dynamo
                ctx = contextlib.nullcontext()
                if ignore_subclass:
                    ctx = torch._C.DisableTorchFunctionSubclass()
                with ctx:
                    r = self.meta_tensor(
                        t, shape_env=shape_env, callback=callback, source=source
                    )
                # TODO: this is suspicious, now that we have callback argument
                if type(t) is torch.nn.Parameter:
                    r = torch.nn.Parameter(r, requires_grad=r.requires_grad)
                return r
        elif torch.overrides.is_tensor_like(t):
            # Blindly converting tensor subclasses to meta can cause
            # unpredictable problems; e.g., FX tests will trace meta
            # tensors into their trace / some subclasses don't correctly
            # support meta. Trying to YOLO this is more trouble than it's
            # worth.
            self.miss += 1
            return NotImplemented
        else:
            # non-Tensor types don't count as hit or miss
            return t
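

# NB: imported at the bottom of the file, presumably to avoid a circular
# import between this module and torch._prims_common.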
import torch._prims_common as utils
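

# Hypothetical usage sketch (not part of the original file): run this module
# directly to see a single MetaConverter instance correlating a tensor and a
# view of it onto meta tensors that share the same meta base.
if __name__ == "__main__":
    converter = MetaConverter()
    base = torch.randn(4)
    view = base.view(2, 2)
    m_base = converter(base)
    m_view = converter(view)
    assert m_base.device.type == "meta" and m_view.device.type == "meta"
    # Because both calls went through the same converter, the view
    # relationship is reconstructed: the meta view hangs off the memoized
    # meta base rather than a fresh storage.
    assert m_view._is_view() and m_view._base is m_base
    assert converter.successful()  # two hits, zero misses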