fake_tensor.py

  1. import contextlib
  2. import functools
  3. import itertools
  4. import logging
  5. import os
  6. import weakref
  7. from dataclasses import dataclass
  8. from functools import partial
  9. from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
  10. from weakref import ReferenceType
  11. import torch
  12. from torch._guards import Source
  13. from torch._ops import OpOverload
  14. from torch._prims_common import (
  15. elementwise_dtypes,
  16. ELEMENTWISE_TYPE_PROMOTION_KIND,
  17. is_float_dtype,
  18. is_integer_dtype,
  19. )
  20. from torch._subclasses.meta_utils import MetaConverter
  21. from torch.fx.operator_schemas import normalize_function
  22. from torch.multiprocessing.reductions import StorageWeakRef
  23. from torch.overrides import TorchFunctionMode
  24. from torch.utils._mode_utils import no_dispatch
  25. from torch.utils._python_dispatch import TorchDispatchMode
  26. from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
  27. from torch.utils._stats import count, count_label
  28. from torch.utils.weak import WeakIdRef
  29. log = logging.getLogger(__name__)
  30. pytree = torch.utils._pytree
  31. T = TypeVar("T")
  32. TensorWeakRef = Any
  33. aten = torch._ops.ops.aten
  34. CONSTANT_NUMEL_LIMIT = 1
  35. RECURSION_COUNT = 0
  36. # Small helper that increments recursion count, and
  37. # resets it when the object goes out of scope. Useful
  38. # if you don't want to increase indentation which is
  39. # what a context manager would do.
  40. class IncrementRecursionCount:
  41. def __init__(self):
  42. global RECURSION_COUNT
  43. RECURSION_COUNT += 1
  44. def __del__(self):
  45. global RECURSION_COUNT
  46. RECURSION_COUNT -= 1
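# Usage sketch (illustrative, not part of the original module): binding the
# helper to a local keeps RECURSION_COUNT bumped until the local is collected
# at the end of the enclosing scope, with no extra indentation.
#
#   def dispatch_step():
#       incr = IncrementRecursionCount()   # RECURSION_COUNT += 1
#       ...                                # work at the deeper level
#       # when `incr` is garbage collected, __del__ restores the count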
  47. @dataclass
  48. class UnsupportedFakeTensorException(RuntimeError):
  49. reason: str
  50. @dataclass
  51. class DynamicOutputShapeException(RuntimeError):
  52. func: OpOverload
  53. @dataclass
  54. class DataDependentOutputException(RuntimeError):
  55. func: OpOverload
  56. _device_not_kwarg_ops = (
  57. aten._resize_output_.default,
  58. aten._nested_tensor_from_tensor_list.default,
  59. aten._nested_tensor_from_tensor_list.out,
  60. aten.pin_memory.default,
  61. aten.is_pinned.default,
  62. aten.to.device,
  63. aten.to.prim_Device,
  64. aten._pin_memory.default,
  65. aten._pin_memory.out,
  66. aten._resize_output.default,
  67. aten._resize_output.out,
  68. )
  69. # this op is never actually used
  70. _non_kwarg_device_constructors = (aten._list_to_tensor,)
  71. def contains_tensor_types(type):
  72. tensor_type = torch._C.TensorType.get()
  73. return type.isSubtypeOf(tensor_type) or any(
  74. contains_tensor_types(e) for e in type.containedTypes()
  75. )
  76. _like_tensor_constructors = (
  77. aten.empty_like.default,
  78. aten.empty_like.out,
  79. aten.full_like.default,
  80. aten.full_like.out,
  81. aten.ones_like.default,
  82. aten.ones_like.out,
  83. aten.rand_like.default,
  84. aten.rand_like.out,
  85. aten.randn_like.default,
  86. aten.randn_like.out,
  87. aten.randint_like.default,
  88. aten.randint_like.out,
  89. aten.randint_like.low_dtype,
  90. aten.randint_like.low_dtype_out,
  91. aten.zeros_like.default,
  92. aten.zeros_like.out,
  93. aten.new_empty.default,
  94. aten.new_empty.out,
  95. aten.new_empty_strided.default,
  96. aten.new_empty_strided.out,
  97. aten.new_full.default,
  98. aten.new_full.out,
  99. aten.new_zeros.default,
  100. aten.new_zeros.out,
  101. aten.new_ones.default,
  102. aten.new_ones.out,
  103. )
  104. @functools.lru_cache(None)
  105. def _is_tensor_constructor(func: OpOverload):
  106. assert isinstance(func, OpOverload)
  107. schema = func._schema
  108. if any(contains_tensor_types(arg.type) for arg in schema.arguments):
  109. return False
  110. # TODO: no real reason to restrict multiple outputs
  111. return (
  112. len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
  113. )
  114. @functools.lru_cache(None)
  115. def get_schema_info(func):
  116. return torch._C._SchemaInfo(func._schema) # type: ignore[attr-defined]
  117. # many of the decompositions registered to torch/_prims do not at the moment model
  118. # aliasing or strides, so as an incremental step, just enable the decompositions in
  119. # torch/_decomp/decompositions.py.
  120. # decomps are used for aot autograd tracing so we would like to unify on their
  121. # implementation and add additional testing to them
  122. @functools.lru_cache(None)
  123. def torch_decomp_decompositions(func):
  124. from torch._decomp import decomposition_table
  125. decompositions = torch._decomp.decompositions
  126. decomp_attrs = [getattr(decompositions, attr) for attr in dir(decompositions)]
  127. return decomposition_table[func] in decomp_attrs
  128. def tree_flatten_only(ty: Type[T], pytree: PyTree):
  129. flat_vals, _ = tree_flatten(pytree)
  130. return [elem for elem in flat_vals if isinstance(elem, ty)]
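# Example sketch (assumed inputs, illustration only): the helper flattens an
# arbitrary pytree and keeps just the leaves of the requested type.
#
#   >>> t1, t2 = torch.empty(2), torch.empty(3)
#   >>> tree_flatten_only(torch.Tensor, {"a": t1, "b": [1, (t2, "x")]})
#   [tensor(...), tensor(...)]          # only the two tensors survive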
  131. # Similar to `MetaConverter`, this is a class for converting
  132. # multiple tensors into fake tensors which share the same view/storage
  133. # structure. Like `MetaConverter`, it uses `WeakIdRef` to
  134. # hold a weak reference for all memoized tensors.
  135. class FakeTensorConverter:
  136. @property
  137. def tensor_memo(self):
  138. return self.meta_converter.tensor_memo
  139. meta_converter: MetaConverter
  140. constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]]
  141. def __init__(self):
  142. self.meta_converter = MetaConverter()
  143. # map from storage to corresponding constant tensors
  144. self.constant_storage_mapping = {}
  145. def add_constant_storage_mapping(self, fake_tensor):
  146. # when you have a constant, aliased tensor:
  147. # const_tensor.add_(torch.rand([1]))
  148. # all aliases of it must become no longer const
  149. assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
  150. weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())
  151. # we need a map from a weak storage to all of its corresponding
  152. # constant tensors. python doesn't have a weak-reference equivalent
  153. # of defaultdict(list), so we emulate one with a dict of weakref lists
  154. if weak_st not in self.constant_storage_mapping:
  155. self.constant_storage_mapping[weak_st] = []
  156. self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor))
  157. def invalidate_constant_aliases(self, tensor):
  158. assert not isinstance(tensor, FakeTensor)
  159. weak_st = StorageWeakRef(tensor._typed_storage())
  160. if weak_st not in self.constant_storage_mapping:
  161. return
  162. for weak_tensor_ref in self.constant_storage_mapping[weak_st]:
  163. ten = weak_tensor_ref()
  164. if ten is not None:
  165. ten._fix_weakref()
  166. ten.constant = None
  167. del self.constant_storage_mapping[weak_st]
  168. def _get_memo(self, t):
  169. if WeakIdRef(t) in self.tensor_memo:
  170. out = self.tensor_memo[WeakIdRef(t)]
  171. out._fix_weakref()
  172. return out
  173. return None
  174. def set_tensor_memo(self, t, v):
  175. th = WeakIdRef(t)
  176. # hold a weak ref to self, otherwise it will be kept alive
  177. # by the del_ten closure
  178. self_weak_ref = weakref.ref(self)
  179. def del_ten():
  180. self_ref = self_weak_ref()
  181. if self_ref is None:
  182. return
  183. # on shutdown, th may not be in memo
  184. self_ref.tensor_memo.pop(th, None)
  185. weakref.finalize(t, del_ten)
  186. self.tensor_memo[th] = v
  187. def from_real_tensor(
  188. self,
  189. fake_mode,
  190. t,
  191. make_constant=False,
  192. shape_env=None,
  193. ignore_subclass=False,
  194. *,
  195. source=None,
  196. ):
  197. maybe_memo = self._get_memo(t)
  198. if maybe_memo is not None:
  199. return maybe_memo
  200. existing_device = t.device
  201. # not yet supported in metatensors
  202. if t.is_quantized:
  203. raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
  204. if type(t) is torch.nn.Parameter:
  205. assert not make_constant
  206. def mk_fake_tensor(make_meta_t):
  207. # NB: don't use in_kernel_invocation_manager, to
  208. # ensure FakeTensor can internally do constant computation
  209. # as necessary. Invocation manager is "more correct" as
  210. # it works for more operators in make_meta_t, but
  211. # invariant is that make_meta_t only calls factories
  212. # for which it is not strictly necessary to use the
  213. # invocation manager (I think!)
  214. with no_dispatch():
  215. return FakeTensor(
  216. fake_mode,
  217. make_meta_t(),
  218. existing_device,
  219. constant=t if make_constant else None,
  220. )
  221. out = self.meta_converter(
  222. t,
  223. shape_env=shape_env,
  224. callback=mk_fake_tensor,
  225. ignore_subclass=ignore_subclass,
  226. source=source,
  227. )
  228. if out is NotImplemented:
  229. raise UnsupportedFakeTensorException("meta converter nyi")
  230. if make_constant:
  231. self.add_constant_storage_mapping(out)
  232. # NB: meta_converter set the memo
  233. return out
  234. # If you specify the device, it MUST be a meta tensor.
  235. def from_meta_and_device(self, fake_mode, t, device):
  236. assert (
  237. t.device.type == "meta"
  238. ), f"tensor's device must be `meta`, got {t.device.type} instead"
  239. maybe_memo = self._get_memo(t)
  240. if maybe_memo is not None:
  241. return maybe_memo
  242. out = FakeTensor(fake_mode, t, device)
  243. self.set_tensor_memo(t, out)
  244. return out
  245. # You can have a real tensor that you need to convert into a fake tensor.
  246. # If you have a meta tensor already, call from_meta_and_device.
  247. #
  248. # You're allowed to pass a meta tensor to be turned into a fake
  249. # tensor; although an odd thing to do, this can occur if you're doing
  250. # cross ref testing and the inner test is already operating on meta tensors.
  251. def __call__(
  252. self,
  253. fake_mode,
  254. t,
  255. *,
  256. make_constant=False,
  257. shape_env=None,
  258. ignore_subclass=False,
  259. source=None,
  260. ):
  261. return self.from_real_tensor(
  262. fake_mode,
  263. t,
  264. make_constant,
  265. shape_env=shape_env,
  266. ignore_subclass=ignore_subclass,
  267. source=source,
  268. )
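# Usage sketch (illustrative, assuming a default FakeTensorMode): conversion
# is memoized on the identity of the real tensor, so converting the same
# tensor twice hands back the same FakeTensor and preserves aliasing.
#
#   >>> mode = FakeTensorMode()
#   >>> conv = mode.fake_tensor_converter
#   >>> x = torch.randn(4)
#   >>> conv(mode, x) is conv(mode, x)
#   True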
  269. op_implementations = []
  270. def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
  271. def impl_decorator(op_impl):
  272. global op_implementations
  273. if isinstance(run_impl_check, OpOverload):
  274. op_implementations.append((lambda func: func == run_impl_check, op_impl))
  275. else:
  276. op_implementations.append((run_impl_check, op_impl))
  277. return op_impl
  278. return impl_decorator
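# Registration sketch (hypothetical example, illustration only): the decorator
# accepts either a specific OpOverload or a predicate over OpOverloads; the
# dispatcher below walks the (check, impl) pairs in registration order, and a
# handler may return NotImplemented to fall through to the default path.
#
#   @register_op_impl(aten.clone.default)
#   def _clone_impl(fake_mode, func, *args, **kwargs):
#       ...  # return a FakeTensor, or NotImplemented to fall through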
  279. @register_op_impl(
  280. lambda func: (_is_tensor_constructor(func) or func in _like_tensor_constructors)
  281. )
  282. def constructors(fake_mode, func, *args, **kwargs):
  283. assert func not in _non_kwarg_device_constructors
  284. _, new_kwargs = normalize_function(
  285. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  286. )
  287. if func in _like_tensor_constructors:
  288. default_device = new_kwargs["input"].device
  289. # TODO: file issue
  290. args = (new_kwargs.pop("input"),)
  291. else:
  292. # cpu is default device if none is specified
  293. default_device = torch.device("cpu")
  294. args = ()
  295. out_device = new_kwargs.pop("device", None)
  296. out_device = out_device if out_device is not None else default_device
  297. new_kwargs["device"] = torch.device("meta")
  298. # _like constructors have fake tensor inputs (maybe this causes the non-like
  299. # to fail? hmmm)
  300. with in_kernel_invocation_manager(fake_mode):
  301. r = func(*args, **new_kwargs)
  302. return FakeTensor(fake_mode, r, out_device)
  303. @register_op_impl(lambda func: func in (aten.to.prim_Device, aten.to.device))
  304. def non_kwarg_to(fake_mode, func, *args, **kwargs):
  305. _, new_kwargs = normalize_function(
  306. func, args, kwargs, normalize_to_only_use_kwargs=True
  307. )
  308. input_device = new_kwargs["device"]
  309. out_device = input_device if input_device else new_kwargs["input"].device
  310. new_kwargs["device"] = torch.device("meta")
  311. inp = new_kwargs.pop("input")
  312. with in_kernel_invocation_manager(fake_mode):
  313. r = func(inp, **new_kwargs)
  314. # TODO: I think this does the wrong thing if r is inp
  315. return fake_mode.fake_tensor_converter.from_meta_and_device(
  316. fake_mode, r, out_device
  317. )
  318. # Don't default to default device handling,
  319. # since the device of `the_template` is ignored
  320. @register_op_impl(aten.resize_as_.default)
  321. def resize_as_(fake_mode, func, *args, **kwargs):
  322. with in_kernel_invocation_manager(fake_mode):
  323. return func(*args, **kwargs)
  324. @register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
  325. def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
  326. # TODO: remove me
  327. return constructors(fake_mode, func, *args, **kwargs)
  328. # index.Tensor is data-dependent only under some conditions
  329. @register_op_impl(
  330. lambda func: torch.Tag.dynamic_output_shape in func.tags # type: ignore[attr-defined]
  331. and func != aten.index.Tensor
  332. )
  333. def dyn_shape(fake_mode, func, *args, **kwargs):
  334. raise DynamicOutputShapeException(func)
  335. @register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
  336. def local_scalar_dense(fake_mode, func, arg):
  337. if fake_mode.shape_env is None:
  338. # Without symints/symfloats, cannot handle this
  339. raise DataDependentOutputException(func)
  340. if is_float_dtype(arg.dtype):
  341. return fake_mode.shape_env.create_unbacked_symfloat()
  342. elif is_integer_dtype(arg.dtype):
  343. return fake_mode.shape_env.create_unbacked_symint()
  344. else:
  345. raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
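# Sketch (illustrative): with a ShapeEnv attached to the mode, item()-style
# calls land here and produce unbacked SymInt/SymFloat placeholders instead
# of raising DataDependentOutputException.
#
#   >>> from torch.fx.experimental.symbolic_shapes import ShapeEnv
#   >>> mode = FakeTensorMode(shape_env=ShapeEnv())
#   >>> mode.from_tensor(torch.tensor([3])).item()   # an unbacked SymInt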
  346. # NB: this must be ordered after local_scalar_dense
  347. @register_op_impl(
  348. lambda func: torch.Tag.data_dependent_output in func.tags # type: ignore[attr-defined]
  349. )
  350. def data_dep(fake_mode, func, *args, **kwargs):
  351. raise DataDependentOutputException(func)
  352. # Bool Indices get Expanded as Masks
  353. # See: IndexingUtils.h:expandTensors
  354. def check_no_bool_index_tensors(func, self, indices):
  355. for index in indices:
  356. if index is not None and index.dtype in (torch.bool, torch.uint8):
  357. raise DynamicOutputShapeException(func)
  358. def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
  359. _, new_kwargs = normalize_function(
  360. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  361. )
  362. out_device = new_kwargs["input"].device
  363. with in_kernel_invocation_manager(fake_mode):
  364. out = func(*args, **kwargs)
  365. return FakeTensor(fake_mode, out, out_device)
  366. # Don't default to default device handling,
  367. # Since op can take in non-zero sized cpu
  368. # index tensors with cuda self
  369. @register_op_impl(aten.index.Tensor)
  370. def index_tensor(fake_mode, func, *args, **kwargs):
  371. # dynamic shape op if indices are bool/uint8
  372. check_no_bool_index_tensors(func, *args, **kwargs)
  373. return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
  374. # takes in multiple devices, don't default to default device handling
  375. @register_op_impl(aten.index_put.default)
  376. def index_put(fake_mode, func, *args, **kwargs):
  377. return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
  378. # same as index_put, but returns the input
  379. @register_op_impl(aten.index_put_.default)
  380. def index_put_(fake_mode, func, *args, **kwargs):
  381. with in_kernel_invocation_manager(fake_mode):
  382. out = func(*args, **kwargs)
  383. _, new_kwargs = normalize_function(
  384. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  385. )
  386. return new_kwargs["input"]
  387. @register_op_impl(lambda fn: fn in _device_not_kwarg_ops)
  388. def nyi(fake_mode, func, *args, **kwargs):
  389. assert func not in _device_not_kwarg_ops, f"NYI: {func}"
  390. @register_op_impl(
  391. lambda func: func in (aten.convolution.default, aten.convolution_backward.default)
  392. )
  393. def conv(fake_mode, func, *args, **kwargs):
  394. _, kwargs = normalize_function(
  395. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  396. )
  397. device = kwargs["input"].fake_device
  398. # need to re-enable mode so the tensors report fake device
  399. with fake_mode:
  400. # if unsqueezing of the input is done in Convolution.cpp we get a segfault
  401. k = kwargs["weight"].ndim
  402. if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
  403. mem_fmt = None
  404. else:
  405. if func is aten.convolution.default:
  406. conv_backend = torch._C._select_conv_backend(**kwargs)
  407. else:
  408. conv_backend = torch._C._select_conv_backend(
  409. kwargs["input"],
  410. kwargs["weight"],
  411. bias=None,
  412. stride=kwargs["stride"],
  413. padding=kwargs["padding"],
  414. dilation=kwargs["dilation"],
  415. transposed=kwargs["transposed"],
  416. output_padding=kwargs["output_padding"],
  417. groups=kwargs["groups"],
  418. bias_sizes=kwargs["bias_sizes"],
  419. )
  420. mem_fmt = torch._C._conv_determine_backend_memory_format(
  421. kwargs["input"], kwargs["weight"], conv_backend
  422. )
  423. def convert(t, mem_fmt):
  424. if t is None:
  425. return t
  426. if mem_fmt is not None:
  427. t = t.to(memory_format=mem_fmt)
  428. return FakeTensor(fake_mode, t, device)
  429. with in_kernel_invocation_manager(fake_mode):
  430. out = func(**kwargs)
  431. if func is aten.convolution.default:
  432. return convert(out, mem_fmt)
  433. else:
  434. return (
  435. convert(out[0], mem_fmt),
  436. convert(out[1], mem_fmt),
  437. convert(out[2], None),
  438. )
  439. FAST_OP_IMPLEMENTATIONS = {}
  440. # Unlike register_op_impl, these don't do the slow iteration for
  441. # run_impl_check, and these run BEFORE decompositions
  442. def register_fast_op_impl(func: OpOverload):
  443. def impl_decorator(op_impl):
  444. FAST_OP_IMPLEMENTATIONS[func] = op_impl
  445. return op_impl
  446. return impl_decorator
  447. # infer_size_impl in ExpandUtils
  448. def infer_size(a, b):
  449. dimsA = len(a)
  450. dimsB = len(b)
  451. ndim = max(dimsA, dimsB)
  452. expandedSizes = [0] * ndim
  453. for i in range(ndim - 1, -1, -1):
  454. offset = ndim - 1 - i
  455. dimA = dimsA - 1 - offset
  456. dimB = dimsB - 1 - offset
  457. sizeA = a[dimA] if dimA >= 0 else 1
  458. sizeB = b[dimB] if dimB >= 0 else 1
  459. if not (sizeA == sizeB or sizeA == 1 or sizeB == 1):
  460. raise RuntimeError(
  461. f"The size of tensor a ({sizeA}) "
  462. f"must match the size of tensor b ({sizeB}) "
  463. f"at non-singleton dimension {i})"
  464. )
  465. expandedSizes[i] = sizeB if sizeA == 1 else sizeA
  466. return tuple(expandedSizes)
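# Worked example (illustration): infer_size implements the usual broadcasting
# rule, aligning shapes from the right and requiring each size pair to match
# or be 1.
#
#   >>> infer_size((3, 1, 5), (4, 5))
#   (3, 4, 5)
#   >>> infer_size((2, 3), (3, 2))   # mismatch at a non-singleton dimension
#   RuntimeError: The size of tensor a (3) must match the size of tensor b (2) ...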
  467. def make_fast_binary_impl(slow_ref):
  468. def fast_binary_impl(mode, *args, **kwargs):
  469. def slow(msg):
  470. count_label(f"slow {msg}")
  471. with mode:
  472. return slow_ref(*args, **kwargs)
  473. count_label("attempt fast")
  474. # Fast path (based off of TensorIterator fast path).
  475. # Unfortunately, there is no way to easily deduplicate
  476. # this with either the TensorIterator C++ implementation
  477. # (which we don't want to SymIntify, and also the algorithm
  478. # here is slightly different from TensorIterator to allow
  479. # for broadcasting), nor the PrimTorch implementation
  480. # (which does not actually implement a fast path.)
  481. operands = args
  482. # compute_shape
  483. has_scalars = False
  484. has_tensors = False
  485. final_shape = None
  486. for op in operands:
  487. shape = op.shape if isinstance(op, torch.Tensor) else ()
  488. if len(shape) == 0:
  489. has_scalars = True
  490. else:
  491. has_tensors = True
  492. if final_shape is None:
  493. final_shape = shape
  494. # TODO: Minor optimization: track if the shapes
  495. # were equal so you can skip the equality check
  496. # below if unnecessary
  497. final_shape = infer_size(final_shape, shape)
  498. assert final_shape is not None
  499. # Do some extra safety checks to see if the output
  500. # stride is obvious
  501. for op in operands:
  502. if isinstance(op, torch.Tensor) and op.shape == final_shape:
  503. break
  504. else:
  505. return slow("both tensors nontrivially broadcast")
  506. # compute_types
  507. cpu = torch.device("cpu")
  508. common_device = cpu
  509. common_dtype = None
  510. output_dtype = None
  511. has_different_input_dtypes = False
  512. for op in operands:
  513. if not isinstance(op, torch.Tensor):
  514. # Use elementwise_dtypes for the tricky case
  515. has_different_input_dtypes = True
  516. continue
  517. if common_device == cpu and not op.device.type == "cpu":
  518. common_device = op.device
  519. # Slightly simplified here as target_dtype cannot vary
  520. if common_dtype is None:
  521. common_dtype = op.dtype
  522. elif common_dtype != op.dtype:
  523. has_different_input_dtypes = True
  524. if has_different_input_dtypes:
  525. # compute promotion
  526. # TODO: we don't need the compute type
  527. _, common_dtype = elementwise_dtypes(
  528. *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  529. )
  530. # check all tensors on same device
  531. # cpu scalars are assumed to be allowed
  532. current_cpu_scalars_on_non_cpu = 0
  533. max_cpu_scalars_on_non_cpu = 1 # hard coded atm
  534. for op in operands:
  535. if not isinstance(op, torch.Tensor):
  536. continue
  537. if common_device != cpu and op.dim() == 0 and op.device == cpu:
  538. if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
  539. return slow("error")
  540. current_cpu_scalars_on_non_cpu += 1
  541. elif op.device != common_device:
  542. return slow("error")
  543. # compute_fast_setup_type
  544. is_contiguous = True
  545. is_channels_last = True
  546. # TODO: is_non_overlapping_and_dense (not bound from Python)
  547. # no inplace, no out, everything defined
  548. for op in operands:
  549. if not isinstance(op, torch.Tensor):
  550. continue
  551. is_contiguous = is_contiguous and op.is_contiguous(
  552. memory_format=torch.contiguous_format
  553. )
  554. is_channels_last = is_channels_last and op.is_contiguous(
  555. memory_format=torch.channels_last
  556. )
  557. if is_contiguous:
  558. # do contiguous
  559. count_label("fast is_contiguous")
  560. return FakeTensor(
  561. mode,
  562. torch.empty(
  563. final_shape,
  564. dtype=common_dtype,
  565. device="meta",
  566. memory_format=torch.contiguous_format,
  567. ),
  568. device=common_device,
  569. )
  570. if is_channels_last:
  571. count_label("fast channels_last")
  572. # do channels last
  573. return FakeTensor(
  574. mode,
  575. torch.empty(
  576. final_shape,
  577. dtype=common_dtype,
  578. device="meta",
  579. memory_format=torch.channels_last,
  580. ),
  581. device=common_device,
  582. )
  583. return slow("no contiguity match")
  584. return fast_binary_impl
  585. @functools.lru_cache(None)
  586. def get_fast_op_impls():
  587. import torch._refs
  588. register_fast_op_impl(torch.ops.aten.add.Tensor)(
  589. make_fast_binary_impl(torch._refs.add)
  590. )
  591. register_fast_op_impl(torch.ops.aten.sub.Tensor)(
  592. make_fast_binary_impl(torch._refs.sub)
  593. )
  594. register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type]
  595. register_fast_op_impl(torch.ops.aten.div.Tensor)(
  596. make_fast_binary_impl(torch._refs.div)
  597. )
  598. return FAST_OP_IMPLEMENTATIONS
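# Sketch (illustrative): the fast table is keyed directly by OpOverload and is
# consulted before decompositions in FakeTensorMode.dispatch, so a symbolic-
# shaped aten.add.Tensor call can be answered by shape/dtype inference alone.
#
#   >>> torch.ops.aten.add.Tensor in get_fast_op_impls()
#   True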
  599. @contextlib.contextmanager
  600. def in_kernel_invocation_manager(fake_mode):
  601. # See: note [Fake Tensor Dispatch Keys]
  602. prev_in_kernel = fake_mode.in_kernel_invocation
  603. meta_in_tls = torch._C._meta_in_tls_dispatch_include()
  604. assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"
  605. guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
  606. fake_mode.in_kernel_invocation = True
  607. torch._C._set_meta_in_tls_dispatch_include(True)
  608. try:
  609. yield
  610. finally:
  611. fake_mode.in_kernel_invocation = prev_in_kernel
  612. torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
  613. del guard
  614. # Returns whether the function allows Python numbers to bind to Tensors
  615. def should_allow_numbers_as_tensors(func: OpOverload):
  616. return torch._C._should_allow_numbers_as_tensors(
  617. func.name().split("::")[-1].split(".")[0]
  618. )
  619. class FakeTensorConfig:
  620. debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", False)
  621. class FakeTensor(torch.Tensor):
  622. """
  623. Meta tensors give you the ability to run PyTorch code without having to
  624. actually do computation through tensors allocated on a `meta` device.
  625. Because the device is `meta`, meta tensors do not model device propagation.
  626. FakeTensor extends MetaTensors to also carry an additional `fake_device`
  627. which tracks devices that would have been used.
  628. """
  629. fake_device: torch.device
  630. fake_mode: "FakeTensorMode"
  631. constant: Optional[torch.Tensor]
  632. @property
  633. def device(self):
  634. if self.fake_mode.in_kernel_invocation:
  635. return torch.device("meta")
  636. else:
  637. return self.fake_device
  638. # Note: [Fake Tensor Dispatch Keys]
  639. # In order to model the behavior of device-specific autocast
  640. # and autograd logic, we update the dispatch keys of FakeTensors
  641. # to reflect their fake device. This includes the BackendComponent
  642. # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent
  643. # related Autocast and Autograd keys. __torch_dispatch__ sits below
  644. # Autocast and Autograd, and is only invoked when we are at the
  645. # kernel for the BackendComponent. Then, we add Meta to the
  646. # thread-local dispatch include set to hit the meta kernel
  647. # instead of the kernel of the BackendComponent for the fake device.
  648. # The `device_for_backend_keys` does that below
  649. @staticmethod
  650. def __new__(cls, fake_mode, elem, device, constant=None):
  651. self = torch.Tensor._make_subclass(
  652. cls,
  653. elem,
  654. elem.requires_grad,
  655. dispatch_device=True,
  656. device_for_backend_keys=device,
  657. )
  658. assert elem.device.type == "meta", elem.device.type
  659. device = device if isinstance(device, torch.device) else torch.device(device)
  660. # NB: it is fine, if a little confusing, for device to be meta
  661. # (we are faking a meta tensor in that case). However, it often
  662. # indicates some sort of confusion (e.g., you accidentally passed
  663. # in a meta tensor when you should have passed in the real tensor).
  664. # So by default we disallow meta, and if you are working in a situation
  665. # where it is helpful (e.g., crossref testing) you can turn it back
  666. # on
  667. if not fake_mode.allow_meta:
  668. assert device.type != "meta"
  669. # normalize cuda device.
  670. if device.type == "cuda" and device.index is None:
  671. device = torch.device(f"cuda:{torch.cuda.current_device()}")
  672. self.fake_device = device # type: ignore[attr-defined]
  673. self.fake_mode = fake_mode # type: ignore[attr-defined]
  674. self.constant = constant # type: ignore[attr-defined]
  675. if FakeTensorConfig.debug:
  676. import traceback
  677. self._debug_trace = traceback.extract_stack() # type: ignore[attr-defined]
  678. return self
  679. # In some circumstances, a conventional torch.Tensor constructor
  680. # will get rewritten to call into FakeTensor. We must provide an
  681. # __init__ method that can accept the Python interpreter's initialization
  682. # in such a situation; we must also be able to handle direct fake
  683. # tensor construction via FakeTensor().
  684. #
  685. # In particular, the __init__ call will look funny in the following case:
  686. #
  687. # with FakeTensorMode():
  688. # x = torch.Tensor([1, 2, 3])
  689. #
  690. # this desugars into:
  691. #
  692. # with FakeTensorMode():
  693. # x = torch.Tensor.__new__([1, 2, 3])
  694. # # NB: x is a fake tensor, because of the mode!
  695. # x.__init__([1, 2, 3]) # not the normal fake tensor args!
  696. #
  697. def __init__(self, *args, **kwargs):
  698. super().__init__()
  699. @staticmethod
  700. def from_tensor(t, fake_mode):
  701. return fake_mode.from_tensor(t)
  702. # TODO: resolve error in default __repr__
  703. def __repr__(self):
  704. with in_kernel_invocation_manager(self.fake_mode):
  705. self_repr = super().__repr__()
  706. return f"FakeTensor({self_repr}, {self.fake_device})"
  707. @classmethod
  708. @count
  709. def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
  710. # need to handle here to avoid infinite recursion
  711. # see [in_kernel_invocation]
  712. if func == torch.ops.prim.device.default:
  713. assert len(args) == 1 and isinstance(args[0], FakeTensor)
  714. if args[0].fake_mode.in_kernel_invocation:
  715. return torch.device("meta")
  716. else:
  717. return args[0].fake_device
  718. # Because fake mode can return NotImplemented (if it sees a subclass
  719. # it doesn't know how to deal with), this test here is important
  720. # because the next dispatch after a fake mode will attempt to use
  721. # subclasses of tensors to dispatch, and any FakeTensor arguments
  722. # will be considered eligible.
  723. if any(not issubclass(t, FakeTensor) and t is not torch.Tensor for t in types):
  724. return NotImplemented
  725. fake_mode = None
  726. for arg in itertools.chain(tree_flatten(args)[0], tree_flatten(kwargs)[0]):
  727. if isinstance(arg, FakeTensor):
  728. if fake_mode is None:
  729. fake_mode = arg.fake_mode
  730. else:
  731. assert fake_mode is arg.fake_mode, "Mixing modes NYI"
  732. assert fake_mode is not None
  733. with fake_mode: # type: ignore[attr-defined]
  734. return func(*args, **kwargs)
  735. @staticmethod
  736. def _find_common_device(func, args, kwargs) -> Tuple[torch.device, bool]:
  737. # Returns: (common_device, has_scalar_only_inputs)
  738. # cpu zero-dim tensors can be called in cuda kernels,
  739. # so overwrite the common_device if the only existing
  740. # device comes from a cpu zero-dim tensor
  741. common_device = None
  742. has_scalar_only_inputs = False
  743. is_cpu_zero_dim = None
  744. def cpu_zero_dim(t):
  745. return t.device.type == "cpu" and t.dim() == 0
  746. def merge_devices(t):
  747. nonlocal common_device
  748. nonlocal is_cpu_zero_dim
  749. if not isinstance(t, FakeTensor):
  750. return
  751. if common_device is None:
  752. common_device = t.device
  753. is_cpu_zero_dim = cpu_zero_dim(t)
  754. return
  755. t_is_cpu_zero_dim = cpu_zero_dim(t)
  756. if t.device == common_device:
  757. if is_cpu_zero_dim:
  758. is_cpu_zero_dim = t_is_cpu_zero_dim
  759. return
  760. # mismatching devices !
  761. # if current tensor is cpu 0 dim, defer to existing device
  762. if t_is_cpu_zero_dim:
  763. return
  764. # current device is from cpu 0 dim tensor, overwrite
  765. if is_cpu_zero_dim:
  766. common_device = t.device
  767. is_cpu_zero_dim = t_is_cpu_zero_dim
  768. return
  769. # mismatching devices of non-zero dim tensors, throw
  770. # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
  771. raise RuntimeError(
  772. f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
  773. )
  774. tree_map(merge_devices, args)
  775. tree_map(merge_devices, kwargs)
  776. # some functions that allow Python numbers to bind to Tensors
  777. # if we have failed to find a device, and we're running one of these operators,
  778. # we must have scalar only inputs
  779. if should_allow_numbers_as_tensors(func) and common_device is None:
  780. # ops with scalar only inputs always have result on cpu
  781. has_scalar_only_inputs = True
  782. common_device = torch.device("cpu")
  783. assert common_device is not None, f"Could not find common device for {func}"
  784. return common_device, has_scalar_only_inputs
  785. __torch_function__ = torch._C._disabled_torch_function_impl
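# Behavior sketch (illustrative): outside of kernel invocation a FakeTensor
# reports its fake_device, so user-level device checks keep working even
# though the underlying storage lives on the meta device.
#
#   >>> mode = FakeTensorMode()
#   >>> f = mode.from_tensor(torch.randn(2, 2))
#   >>> type(f)
#   <class 'torch._subclasses.fake_tensor.FakeTensor'>
#   >>> f.device, f.fake_device
#   (device(type='cpu'), device(type='cpu'))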
  786. # We keep one instantiation of `fake_tensor_converter` active
  787. # for the duration of `with FakeTensorMode()`.
  788. # This allows accurate storage aliasing across invocation of
  789. # different operators. While this will keep all freshly allocated
  790. # tensors alive during `FakeTensorMode`, there will be no
  791. # new allocations of Tensors which have non-meta storage so
  792. # memory should not significantly increase.
  793. class FakeTensorMode(TorchDispatchMode):
  794. def __init__(
  795. self,
  796. *,
  797. allow_fallback_kernels=True,
  798. allow_non_fake_inputs=False,
  799. shape_env=None,
  800. ):
  801. self.allow_fallback_kernels = allow_fallback_kernels
  802. self.fake_tensor_converter = FakeTensorConverter()
  803. import torch._functorch.config
  804. self.allow_meta = torch._functorch.config.fake_tensor_allow_meta
  805. # A flag that controls whether we want to invoke ops on a mix of
  806. # real weights/global variables and fake inputs
  807. self.allow_non_fake_inputs = allow_non_fake_inputs
  808. # [in_kernel_invocation]
  809. # when FakeTensor is invoked in user code, .device should return
  810. # the fake_device of the tensor so that code such as `if x.is_cuda`
  811. # or torch.zeros([10, 10], device=x.device) continues to execute as if
  812. # the FakeTensor were real. However, within kernel execution, we return
  813. # the `Meta` device because all computation within the kernels should
  814. # behave as if the Tensors are on meta devices. Kernels should allocate
  815. # new tensors on meta devices, and checks like `is_meta` should return true.
  816. # within python refs, we always return the real device by defining
  817. # the device property
  818. self.in_kernel_invocation = False
  819. self.shape_env = shape_env
  820. @count
  821. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  822. try:
  823. return self.dispatch(func, types, args, kwargs)
  824. except TypeError:
  825. log.exception("fake tensor raised TypeError")
  826. raise
  827. def dispatch(self, func, types, args=(), kwargs=None):
  828. kwargs = kwargs if kwargs else {}
  829. if func == torch.ops.prim.device.default:
  830. assert len(args) == 1 and isinstance(args[0], FakeTensor)
  831. if args[0].fake_mode.in_kernel_invocation:
  832. return torch.device("meta")
  833. else:
  834. return args[0].fake_device
  835. if log.getEffectiveLevel() <= logging.DEBUG:
  836. log.debug(
  837. f"{' ' * RECURSION_COUNT}FakeTensorMode.__torch_dispatch__: {func}"
  838. )
  839. incr = IncrementRecursionCount()
  840. # Some attribute queries that can be serviced directly
  841. # See Note [is_coalesced is dispatched]
  842. if func in {
  843. torch.ops.aten.is_coalesced.default,
  844. torch.ops.aten.dense_dim.default,
  845. torch.ops.aten.sparse_dim.default,
  846. }:
  847. # NB: no_dispatch is ok here too, this func is very simple
  848. with in_kernel_invocation_manager(self):
  849. return func(*args, **kwargs)
  850. flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
  851. flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs))
  852. has_symbolic_sizes = (
  853. any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors])
  854. or len(flat_symints) > 0
  855. )
  856. converter = self.fake_tensor_converter
  857. # To constant propagate through these functions:
  858. # 1, If this is a lift, the input tensor is guaranteed to be a
  859. # constant, so we keep a copy of the original argument along so
  860. # we can query it if we're asked to item() it at some later point
  861. # 2, Some functions that allow Python numbers to bind to Tensors, e.g, torch.div
  862. if func in self.lift_fns or (
  863. should_allow_numbers_as_tensors(func)
  864. and not has_symbolic_sizes
  865. and not flat_arg_fake_tensors
  866. ):
  867. out = func(*args, **kwargs)
  868. if self.may_turn_const(out):
  869. # NB: not in_kernel_invocation_manager because we're doing real
  870. # compute here
  871. with no_dispatch():
  872. out = out.clone()
  873. return converter(self, out, make_constant=True)
  874. # See [subclass inputs] below
  875. # NB: If you're seeing a mysterious infinite loop involving fake
  876. # tensor, it might be related to this line. Though I'm not sure
  877. # how you'll know to read this comment, as this line won't show up
  878. # in the stack trace.
  879. if self.check_for_subclass(args, kwargs):
  880. return NotImplemented
  881. # if we are in the dispatch mode, we will enter this function even if the inputs
  882. # are not FakeTensors. For now, throw if any non-Fake Tensor inputs
  883. # and just support constructors.
  884. # this is generated from torch.tensor(), which does not use the
  885. # dispatcher, to allow wrapper subclasses to wrap the new tensor
  886. if func in self.lift_fns:
  887. assert (
  888. len(kwargs) == 0 and len(args) == 1 and type(args[0]) is torch.Tensor
  889. ), f"{args} {kwargs}"
  890. return converter(self, args[0])
  891. args, kwargs = self.validate_and_convert_non_fake_tensors(
  892. func, converter, args, kwargs
  893. )
  894. # The current constant handling only support tracing systems
  895. # (aot autograd, torchdynamo) where each operation is run consecutively.
  896. # Because each operation is run in order, we can trace out and support
  897. # sequences like: x = torch.tensor(0.); y = x.add_(1)
  898. # Whenever a constant is written to but with inputs that cannot be evaluated
  899. # statically, such as random_(), we invalidate all constants that alias the input
  900. # We will rely on functionalization for use of fake tensors constants as persistent
  901. # objects on an FX Graph.
  902. # We dispatch size/stride/numel on the FakeTensor not its constant, so bail on inplace_view
  903. all_constant = all(e.constant is not None for e in flat_arg_fake_tensors)
  904. if (
  905. torch.Tag.nondeterministic_seeded not in func.tags # type: ignore[attr-defined]
  906. and torch.Tag.inplace_view not in func.tags # type: ignore[attr-defined]
  907. and all_constant
  908. and len(flat_arg_fake_tensors) != 0
  909. and not has_symbolic_sizes
  910. ):
  911. const_args, const_kwargs = pytree.tree_map_only(
  912. FakeTensor, lambda t: t.constant, (args, kwargs)
  913. )
  914. # NB: not in_kernel_invocation_manager(self) as we want to do REAL
  915. # compute
  916. with no_dispatch():
  917. out = func(*const_args, **const_kwargs)
  918. all_constant = pytree.tree_all_only(
  919. torch.Tensor, lambda t: self.may_turn_const(t), out
  920. )
  921. if all_constant:
  922. return pytree.tree_map_only(
  923. torch.Tensor,
  924. lambda t: converter(self, t, make_constant=True),
  925. out,
  926. )
  927. # we weren't able to turn outputs to constants,
  928. # so invalidate all constants that might be aliases of the outputs
  929. for ten in tree_flatten_only(torch.Tensor, out):
  930. converter.invalidate_constant_aliases(ten)
  931. # we are falling through to running non constant tensors, any input constant that
  932. # is written to must be invalidated
  933. self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  934. # Try for fastpath
  935. if has_symbolic_sizes:
  936. fast_impl = get_fast_op_impls().get(func)
  937. if fast_impl is not None:
  938. return fast_impl(self, *args, **kwargs)
  939. # If there's a Python meta, prefer that over the decomposition
  940. from torch._decomp import meta_table as meta_table
  941. if func not in meta_table and not self.cpp_meta_supports_symint(func):
  942. from torch._decomp import decomposition_table
  943. # Prefer Python decompositions over C++ ones
  944. if func in decomposition_table and (
  945. has_symbolic_sizes
  946. or (
  947. # TODO: Remove these exclusions, so that we can remove
  948. # this leg entirely
  949. torch_decomp_decompositions(func)
  950. and all(not e.is_sparse for e in flat_arg_fake_tensors)
  951. )
  952. ):
  953. with self:
  954. return decomposition_table[func](*args, **kwargs)
  955. with self:
  956. # Decomposes CompositeImplicitAutograd ops
  957. r = func.decompose(*args, **kwargs)
  958. if r is not NotImplemented:
  959. return r
  960. # prims already wrap FakeTensor inputs to FakeTensor outputs
  961. # and do device logic, we don't need to do anything but run them
  962. # and ensure that Meta kernels are dispatched to
  963. # (see Note [Fake Tensor Dispatch Keys])
  964. # TODO - we should use the prim aten impl
  965. if "prims::" in func._schema.name and hasattr(func, "prim_meta_impl"):
  966. with self:
  967. return func.prim_meta_impl(*args, **kwargs)
  968. # special handling for funcs registered through `register_op_impl`,
  969. # e.g., manipulating args on constructor calls to construct meta tensors
  970. # and then afterwards wrapping them to a FakeTensor
  971. for run_impl_check, op_impl in op_implementations:
  972. if run_impl_check(func):
  973. op_impl_out = op_impl(self, func, *args, **kwargs)
  974. if op_impl_out is not NotImplemented:
  975. return op_impl_out
  976. # run kernel registered to meta for func, which include
  977. # python meta registrations, prims, decomps, and c++ meta fns (structured kernels)
  978. try:
  979. with in_kernel_invocation_manager(self):
  980. r = func(*args, **kwargs)
  981. except NotImplementedError as not_implemented_error:
  982. # no meta kernel registered, fallback to kernel for the device
  983. if not self.allow_fallback_kernels:
  984. raise not_implemented_error
  985. return run_fallback_kernel(self, func, args, kwargs, not_implemented_error)
  986. return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs)
  987. # [subclass inputs]
  988. # Suppose we enable fake tensor mode. This means that fake tensor
  989. # mode will run first. But what if we do an operation that
  990. # involves a tensor subclass that will desugar into normal tensor
  991. # operations? Without returning NotImplemented, fake tensor mode will run first,
  992. # decide that a conversion was made (since there was a non fake
  993. # tensor argument), and report an error that converting non
  994. # fake tensor is not supported. What we actually wanted to happen
  995. # was to give the subclass a chance to figure out what it wants to do
  996. # before erroring out. Returning NotImplemented here allows this.
  997. def check_for_subclass(self, args, kwargs):
  998. def check(x):
  999. return (
  1000. not isinstance(x, FakeTensor)
  1001. and type(x) is not torch.Tensor
  1002. and type(x) is not torch.nn.Parameter
  1003. )
  1004. return any([check(x) for x in tree_flatten_only(torch.Tensor, (args, kwargs))])
  1005. def validate_and_convert_non_fake_tensors(self, func, converter, args, kwargs):
  1006. """
  1007. Checks if the list of tensors are fake tensors.
  1008. If not, try to convert them to fake tensors.
  1009. """
  1010. def validate(x):
  1011. if not isinstance(x, FakeTensor):
  1012. if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined]
  1013. raise Exception(
  1014. f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {func}(*{args}, **{kwargs})"
  1015. )
  1016. if not self.allow_non_fake_inputs:
  1017. raise Exception(
  1018. f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode "
  1019. f"with 'allow_non_fake_inputs'. Found in {func}(*{args}, **{kwargs}) "
  1020. )
  1021. return converter(self, x)
  1022. return x
  1023. return tree_map_only(
  1024. torch.Tensor,
  1025. validate,
  1026. (args, kwargs),
  1027. )
  1028. def wrap_meta_outputs_with_default_device_logic(self, r, func, args, kwargs):
  1029. wrap = self.gen_wrap_fn(func, args, kwargs)
  1030. # if device is specified, use that
  1031. if kwargs.get("device", None):
  1032. return tree_map(partial(wrap, device=kwargs["device"]), r)
  1033. return tree_map(partial(wrap), r)
  1034. def gen_wrap_fn(self, func, args, kwargs):
  1035. converter = self.fake_tensor_converter
  1036. # Lazily initialized, in case there are no tensor returns
  1037. common_device = None
  1038. has_scalar_only_inputs = False
  1039. def wrap(e, device=None):
  1040. nonlocal common_device
  1041. nonlocal has_scalar_only_inputs
  1042. if (
  1043. isinstance(e, torch.Tensor)
  1044. and not isinstance(e, FakeTensor)
  1045. and converter is not None
  1046. ):
  1047. if common_device is None:
  1048. (
  1049. common_device,
  1050. has_scalar_only_inputs,
  1051. ) = FakeTensor._find_common_device(func, args, kwargs)
  1052. if has_scalar_only_inputs:
  1053. # Under FakeTensorMode, an op that accepts only scalar inputs, such as aten.add/sub/mul/div,
  1054. # returns a real scalar tensor on CPU. See TensorMeta() in _prims/__init__.py for details.
  1055. # We thus directly convert real tensor to fake tensor.
  1056. return converter(self, e)
  1057. else:
  1058. return converter.from_meta_and_device(
  1059. self, e, device or common_device
  1060. )
  1061. else:
  1062. return e
  1063. return wrap
  1064. def cpp_meta_supports_symint(self, func):
  1065. if torch.Tag.view_copy in func.tags: # type: ignore[attr-defined]
  1066. return True
  1067. return func in [
  1068. aten.empty_strided.default,
  1069. aten.as_strided_scatter.default,
  1070. aten.as_strided.default,
  1071. aten.as_strided_.default,
  1072. aten.zeros.default,
  1073. aten.detach.default,
  1074. aten.view_as_real.default,
  1075. aten.view_as_complex.default,
  1076. aten.set_.source_Storage_storage_offset,
  1077. aten._sparse_coo_tensor_with_dims_and_tensors.default,
  1078. ]
  1079. @property
  1080. def lift_fns(self):
  1081. return (aten.lift_fresh.default, aten.lift_fresh_copy.default)
  1082. def may_turn_const(self, t):
  1083. return (
  1084. t.numel() <= CONSTANT_NUMEL_LIMIT
  1085. and not t.is_sparse
  1086. and not isinstance(t, FakeTensor)
  1087. and not t.device.type == "meta"
  1088. )
  1089. def invalidate_written_to_constants(
  1090. self, func, flat_arg_fake_tensors, args, kwargs
  1091. ):
  1092. any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
  1093. if any_constant and get_schema_info(func).is_mutable():
  1094. schema_info = get_schema_info(func)
  1095. _, new_kwargs = normalize_function(
  1096. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1097. )
  1098. for k, v in new_kwargs.items():
  1099. k = k if (k != "input" or schema_info.has_argument(k)) else "self"
  1100. if (
  1101. isinstance(v, FakeTensor)
  1102. and schema_info.is_mutable(k)
  1103. and v.constant is not None
  1104. ):
  1105. self.fake_tensor_converter.invalidate_constant_aliases(v.constant)
  1106. def from_tensor(
  1107. self,
  1108. tensor,
  1109. static_shapes=False,
  1110. ignore_subclass=False,
  1111. source: Optional[Source] = None,
  1112. ):
  1113. if static_shapes:
  1114. return self.fake_tensor_converter(
  1115. self, tensor, ignore_subclass=ignore_subclass, source=source
  1116. )
  1117. return self.fake_tensor_converter(
  1118. self,
  1119. tensor,
  1120. shape_env=self.shape_env,
  1121. ignore_subclass=ignore_subclass,
  1122. source=source,
  1123. )
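# Usage sketch (illustrative): once the inputs are fake, running ops under the
# mode propagates shapes, dtypes and fake devices without allocating any real
# (non-meta) storage.
#
#   >>> mode = FakeTensorMode()
#   >>> a = mode.from_tensor(torch.randn(8, 4))
#   >>> b = mode.from_tensor(torch.randn(4, 2))
#   >>> with mode:
#   ...     out = a @ b
#   >>> out.shape, out.device
#   (torch.Size([8, 2]), device(type='cpu'))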
  1124. # NB: returns fake tensors
  1125. def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
  1126. # these should all be supported, just to be safe
  1127. # avoid fallback for operators which inplace modify metadata
  1128. # because the input fake tensors would be unmodified
  1129. if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined]
  1130. raise orig_not_implemented_exception
  1131. inp_impls = {}
  1132. # Don't use in_kernel_invocation_manager(fake_mode) as we want to do
  1133. # REAL compute (not with meta device)
  1134. with no_dispatch():
  1135. def to_real_tensor(e):
  1136. if isinstance(e, FakeTensor):
  1137. out = torch.zeros_like(e, device=e.fake_device)
  1138. if e.is_sparse:
  1139. out._coalesced_(e.is_coalesced())
  1140. inp_impls[id(out)] = e
  1141. return out
  1142. return e
  1143. args = tree_map(to_real_tensor, args)
  1144. kwargs = tree_map(to_real_tensor, kwargs)
  1145. r = func(*args, **kwargs)
  1146. tensor_impls = set()
  1147. storages = set()
  1148. for e in tree_flatten((args, kwargs))[0]:
  1149. if isinstance(e, torch.Tensor):
  1150. if not e.is_sparse:
  1151. storages.add(e._typed_storage()._cdata)
  1152. # TODO: also check metadata change on inputs
  1153. # proper aliasing/metadata relationship between outputs and inputs will
  1154. # not be set up, bc of conversion to device, unless we can reuse an
  1155. # input impl
  1156. for e in tree_flatten(r)[0]:
  1157. if id(e) not in inp_impls and (
  1158. isinstance(e, torch.Tensor)
  1159. and not e.is_sparse
  1160. and e._typed_storage()._cdata in storages
  1161. ):
  1162. raise orig_not_implemented_exception
  1163. def map_out(e):
  1164. if isinstance(e, torch.Tensor):
  1165. if id(e) in inp_impls:
  1166. return inp_impls[id(e)]
  1167. else:
  1168. return fake_mode.fake_tensor_converter(fake_mode, e)
  1169. else:
  1170. return e
  1171. return tree_map(map_out, r)
  1172. # Just for use to allow copying a module to fake tensors,
  1173. # does not apply elsewhere
  1174. class FakeCopyMode(TorchFunctionMode):
  1175. def __init__(self, fake_mode):
  1176. self.fake_mode = fake_mode
  1177. def __torch_function__(self, func, types, args=(), kwargs=None):
  1178. kwargs = kwargs if kwargs else {}
  1179. # clone will get called in Parameter deepcopy
  1180. if func == torch._C._TensorBase.clone:
  1181. return func(
  1182. self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs
  1183. )
  1184. elif func == torch.Tensor.__deepcopy__:
  1185. assert len(args) == 2 and len(kwargs) == 0
  1186. tensor, memo = args
  1187. if id(tensor) in memo:
  1188. return memo[id(tensor)]
  1189. out = self.fake_mode.from_tensor(tensor, static_shapes=True)
  1190. memo[id(tensor)] = out
  1191. return out
  1192. else:
  1193. with torch._C.DisableTorchFunctionSubclass():
  1194. return func(*args, **kwargs)
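# Usage sketch (illustrative): FakeCopyMode is intended to wrap a deepcopy of
# a module so the copied parameters and buffers come out as fake tensors; the
# same interception applies to a plain tensor deepcopy.
#
#   >>> import copy
#   >>> mode = FakeTensorMode()
#   >>> with FakeCopyMode(mode):
#   ...     fake_t = copy.deepcopy(torch.randn(3))
#   >>> isinstance(fake_t, FakeTensor)
#   True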