# tensor.py

import inspect
import itertools
import operator
import types
from typing import Dict, List

import torch.fx
import torch.random
from torch.fx.experimental.symbolic_shapes import guard_scalar

from .. import config, variables
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
    fqn,
    get_fake_value,
    get_real_value,
    HAS_NUMPY,
    np,
    product,
    proxy_args_kwargs,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable

supported_tensor_comparison_ops = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}
supported_const_comparison_ops = {
    "is": operator.is_,
    "is not": operator.is_not,
    "==": operator.eq,
    "!=": operator.ne,
}


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = [
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
    ]

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.

        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        dtype=None,
        device=None,
        layout=None,
        ndim=None,
        size=None,
        stride=None,
        requires_grad=None,
        is_quantized=None,
        is_contiguous=None,
        is_sparse=None,
        class_type=torch.Tensor,
        specialized_value=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.specialized_value = specialized_value

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    def call_isinstance(self, tensor_type):
        def check_type(ty):
            if ty not in tensortype_to_dtype:
                return issubclass(self.python_type(), ty)

            dtypes = tensortype_to_dtype[ty]
            return self.dtype in dtypes

        if type(tensor_type) is tuple:
            return any([check_type(ty) for ty in tensor_type])
        else:
            return check_type(tensor_type)

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        if not config.dynamic_shapes:
            props["size"] = tuple(value.size())
            props["stride"] = tuple(value.stride())
            props["is_contiguous"] = tuple(
                [
                    x
                    for x in torch._prims_common._memory_formats
                    if value.is_contiguous(memory_format=x)
                ]
            )
        return props
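
    # Illustrative sketch only (not part of the original module): with
    # dynamic_shapes disabled, specialize(torch.ones(2, 3)) would return roughly
    #   {"dtype": torch.float32, "device": device(type="cpu"),
    #    "layout": torch.strided, "ndim": 2, "requires_grad": False,
    #    "is_quantized": False, "is_sparse": False, "class_type": torch.Tensor,
    #    "size": (2, 3), "stride": (3, 1),
    #    "is_contiguous": (<the memory formats the tensor satisfies>)}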

    def var_getattr(self, tx, name):
        from . import ConstantVariable, TorchVariable

        result = None
        options = VariableTracker.propagate(self)
        if name == "ndim" and self.ndim is not None:
            result = ConstantVariable(self.ndim, **options)
        elif name == "dtype" and self.dtype is not None:
            result = TorchVariable(self.dtype, **options)
        elif name == "device" and self.device is not None:
            result = TorchVariable(self.device, **options)
        elif name == "layout" and self.layout is not None:
            result = TorchVariable(self.layout, **options)
        elif name == "is_cuda" and self.device is not None:
            result = ConstantVariable(self.device.type == "cuda", **options)
        elif name == "shape" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            result = ShapeVariable(sizes, **options)
        elif name == "requires_grad" and self.requires_grad is not None:
            result = ConstantVariable(self.requires_grad, **options)
        elif name == "is_quantized" and self.is_quantized is not None:
            result = ConstantVariable(self.is_quantized, **options)
        elif name == "is_sparse" and self.is_sparse is not None:
            result = ConstantVariable(self.is_sparse, **options)
        elif name == "shape" and self.size is None:
            result = self.call_method(tx, "size", [], {})
        elif name == "ndim" and self.ndim is None:
            result = self.call_method(tx, "dim", [], {})
        elif name == "data":
            result = self.call_method(tx, "detach", [], {})

        if name == "__class__":
            return TorchVariable(self.python_type(), **options)

        # Add a guard for type matching, these guards are checked before tensor guards
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if result is not None and self.source is not None:
            result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

        # For attributes (not methods) that were not caught in the special handling above,
        # (e.g. tensor.real), we handle these generically, assuming that the output type is
        # a tensor.
        if result is None:

            def try_generic_attr_handling():
                from .builder import wrap_fx_proxy
                from .misc import GetAttrVariable

                try:
                    static_attr = inspect.getattr_static(torch.Tensor, name)
                except AttributeError:
                    return None
                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor".
                # This is because of the CPython implementation (see THPVariableType):
                # these attributes are implemented under tp_getset, so they appear
                # as `getset_descriptor`s (compared to, say, methods, which appear
                # as `method_descriptor`s).
                if type(static_attr) != types.GetSetDescriptorType:
                    return None

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=GetAttrVariable.create_getattr_proxy(self.as_proxy(), name),
                    **options,
                )

            result = try_generic_attr_handling()

        if result is None:
            raise NotImplementedError()

        return result

    def has_unpack_var_sequence(self, tx):
        return (self.size is not None and len(self.size) > 0) or (
            self.size is None and config.dynamic_shapes
        )

    def unpack_var_sequence(self, tx, idxes=None):
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self)
        if idxes is None:
            if self.size:
                length = self.size[0]
            else:
                dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
                assert isinstance(dyn_length, SymNodeVariable)
                length = dyn_length.evaluate_expr(tx.output)
            idxes = range(length)
        return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable, TorchVariable, TupleVariable
        from .builder import wrap_fx_proxy

        kwargs = dict(kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())

        if name == "stride" and self.stride is not None:
            constant_result = ConstantVariable(self.stride, **options)
        elif name == "size" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            constant_result = SizeVariable(sizes, **options)
        elif name == "size" and self.size is None and config.dynamic_shapes:
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )
        elif name in ("numel", "nelement") and self.size is not None:
            constant_result = ConstantVariable(product(self.size), **options)
        elif name in ("ndimension", "dim") and self.ndim is not None:
            constant_result = ConstantVariable(self.ndim, **options)
        elif name == "is_floating_point" and self.dtype is not None:
            constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
        elif name == "is_contiguous" and self.is_contiguous is not None:
            if "memory_format" in kwargs:
                memory_format = kwargs.pop("memory_format").as_python_constant()
            else:
                memory_format = torch.contiguous_format
            constant_result = ConstantVariable(
                memory_format in self.is_contiguous, **options
            )
        elif (
            name == "type"
            and self.dtype is not None
            and len(args) == 0
            and isinstance(self.device, torch.device)
        ):
            tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
                0
            ]
            if self.device.type == "cuda":
                constant_result = ConstantVariable(
                    f"torch.cuda.{tensortype.__name__}", **options
                )
            else:
                constant_result = ConstantVariable(
                    f"torch.{tensortype.__name__}", **options
                )
        elif (
            name == "type"
            and len(args) == 1
            and fqn(type(args[0].as_python_constant())) == "torch.tensortype"
        ):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
            tensor_type = args[0].as_python_constant()
            tensor_type_const = ConstantVariable(fqn(tensor_type), **options)
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
                ),
                **options,
            )
        elif name == "get_device" and isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            constant_result = ConstantVariable(index, **options)
        else:
            constant_result = None

        if constant_result:
            assert not kwargs, f"Tensor.{name}() unhandled kwargs"
            if len(args) == 1:
                return constant_result.getitem_const(args[0])
            elif args:
                return TupleVariable(
                    [constant_result.getitem_const(a) for a in args], **options
                )
            return constant_result
        elif (
            name == "repeat"
            and not all(
                x.is_python_constant() for x in itertools.chain(args, kwargs.values())
            )
            and not config.dynamic_shapes
        ):
            unimplemented("dynamic Tensor.repeat")
        elif name in ("tolist", "numpy", "backward", "data_ptr"):
            unimplemented(f"Tensor.{name}")
        elif name == "nonzero" and not config.dynamic_shapes:
            unimplemented(f"Tensor.{name}")
        elif name == "item" and not config.capture_scalar_outputs:
            unimplemented(f"Tensor.{name}")
        elif (
            name == "item"
            and config.capture_scalar_outputs
            and not config.dynamic_shapes
        ):
            raise AssertionError(
                "To capture_scalar_outputs, you must also set dynamic_shapes = True"
            )
        elif name == "__len__":
            return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
        elif name == "__setitem__":
            tx.output.guards.update(options["guards"])
            tx.output.create_proxy(
                "call_function",
                operator.setitem,
                *proxy_args_kwargs([self] + list(args), kwargs),
            )
            return ConstantVariable(None, **options)
        elif name in ("resize_", "resize_as_"):
            if "memory_format" in kwargs:
                memory_format = kwargs["memory_format"].as_python_constant()
            else:
                memory_format = torch.contiguous_format

            if name == "resize_":
                self.size = args[0].as_python_constant()
                self.is_contiguous = (memory_format,)
            else:
                assert isinstance(args[0], TensorVariable)
                if self.size and args[0].size:
                    if (
                        self.size == args[0].size
                        or memory_format is torch.preserve_format
                    ):
                        self.is_contiguous = args[0].is_contiguous
                    else:
                        self.size = args[0].size
                        self.stride = args[0].stride
                        self.ndim = args[0].ndim
                        self.is_contiguous = (memory_format,)

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )
        elif (
            name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
        ):
            result = TorchVariable(torch.mul, **options).call_function(
                tx, args + [kwargs["alpha"]], {}
            )
            return self.call_method(tx, "add_", [result], {})
        elif (
            name == "addcdiv_"
            and len(args) == 2
            and len(kwargs) == 1
            and "value" in kwargs
        ):
            result = TorchVariable(torch.div, **options).call_function(tx, args, {})
            result = TorchVariable(torch.mul, **options).call_function(
                tx, [result, kwargs["value"]], {}
            )
            return self.call_method(tx, "add_", [result], {})
        else:
            # Convert x.new(torch.Size) into x.new_empty(torch.Size),
            # as Tensor.new acts differently with a Size input versus a tuple input.
            if (
                name == "new"
                and len(args) == 1
                and isinstance(args[0], (SizeVariable, ShapeVariable))
                and not config.dynamic_shapes
            ):
                name = "new_empty"
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )


class SymNodeVariable(VariableTracker):
    """
    Represents a symbolic size, e.g., as returned by tensor.size(0)
    """

    @classmethod
    def create(cls, tx, proxy, sym_num, **options):
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == sym_num
        if sym_num is None:
            sym_num = get_fake_value(proxy.node, tx)
        proxy.node.meta["example_value"] = sym_num

        return SymNodeVariable(proxy, sym_num, **options)

    def __init__(self, proxy, sym_num, **kwargs):
        super().__init__(**kwargs)
        self.proxy = proxy
        self.sym_num = sym_num

    def python_type(self):
        return type(self.sym_num)

    def unpack_var_sequence(self, tx):
        super().unpack_var_sequence(tx)

    def as_proxy(self):
        return self.proxy

    def evaluate_expr(self, output_graph):
        return guard_scalar(self.sym_num)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())
        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self] + list(args), kwargs),
            ),
            **options,
        )


class TensorWithTFOverrideVariable(VariableTracker):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """

    def __init__(
        self,
        tensor_variable,
        orig_tensor_variable_source,
        subclass_torch_function__func,
        subclass_type,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tensor_variable = tensor_variable
        self.orig_tensor_variable_source = orig_tensor_variable_source
        self.subclass_torch_function__func = subclass_torch_function__func
        self.subclass_type = subclass_type

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        from . import GetAttrVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        # insert unwrapped version of self as the first argument
        # TODO: This is wrong! When you call the internal __torch_function__,
        # you still get the wrapped version of self, and if you call functions
        # inside __torch_function__, they should come back here. If we unwrap
        # the tensor immediately, that will not happen.
        # See https://github.com/pytorch/torchdynamo/issues/1951
        args = list(args)
        args.insert(0, self.tensor_variable)

        func_var = GetAttrVariable(self.tensor_variable, name)
        unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
            tx,
            func_var,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
            options,
            args,
            kwargs,
        )

        # TODO(future PR): implement rewrapping conditional on method presence
        # in `torch.overrides.get_default_nowrap_function()`. It's unclear how
        # to do this easily in the current codebase since the resolution of
        # `GetAttrVariable` depends on the type of the underlying object.
        return TensorWithTFOverrideVariable(
            unwrapped,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
        )

    @staticmethod
    def inline_torch_function_unwrapped(
        tx,
        original_func_var,
        tensor_with_tf_override_source,
        tf_func,
        subclass_type,
        options,
        args,
        kwargs,
    ):
        """
        This function inlines the `__torch_function__` override for `original_func_var`.

        For example, if the user code is

            x1 = torch.sigmoid(x0)

        and `x0` has an override, then:
        * `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
        * `tensor_with_tf_override_source` will be the `Source` object from
          the original tensor override instance in the beginning of the program
        * `tf_func` will be the custom `__torch_function__` function
        * `subclass_type` will be `type(x0)`

        The caller is expected to properly massage args and kwargs before
        passing them into this function.

        The caller is responsible for wrapping the return value, if needed.
        """
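        # Illustrative sketch only (assumed user code, not part of this module):
        # the override being inlined typically comes from a subclass such as
        #
        #     class MyTensor(torch.Tensor):
        #         @classmethod
        #         def __torch_function__(cls, func, types, args=(), kwargs=None):
        #             if kwargs is None:
        #                 kwargs = {}
        #             return super().__torch_function__(func, types, args, kwargs)
        #
        # For `x1 = torch.sigmoid(x0)` with `x0` an instance of MyTensor,
        # `subclass_type` is MyTensor, `tf_func` is the function underlying
        # MyTensor.__torch_function__, and `original_func_var` wraps
        # torch.sigmoid; the code below builds the override's argument tuple
        # and inlines its body instead of tracing torch.sigmoid directly.
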
        from . import UserDefinedClassVariable
        from .builder import TupleVariable, VariableBuilder

        source = AttrSource(
            AttrSource(tensor_with_tf_override_source, "__torch_function__"),
            "__func__",
        )
        tf_func_var = VariableBuilder(tx, source)(tf_func)
        type_var = UserDefinedClassVariable(subclass_type, **options)

        # signature:
        # def __torch_function__(cls, func, types, args=(), kwargs=None):
        tf_args = (
            type_var,  # cls
            original_func_var,  # func
            (type_var,),  # types
            TupleVariable(args),  # args
            kwargs,  # kwargs
        )

        # Disable __torch_function__ here to prevent the clone of the
        # example tensor from going into the override.
        with torch._C.DisableTorchFunctionSubclass():
            return tx.inline_user_function_return(tf_func_var, tf_args, {})


class UnspecializedPythonVariable(TensorVariable):
    """
    A 1-element tensor that represents an unspecialized Python float/int.
    """

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        raw_value = kwargs.pop("raw_value", None)
        if HAS_NUMPY and isinstance(raw_value, np.number):
            # unwrap a numpy scalar into a plain Python number
            raw_value = raw_value.item()
        need_unwrap = kwargs.pop("need_unwrap", True)
        super().__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )

    def as_specialized(self, tx):
        for graph_arg in tx.output.graphargs:
            if graph_arg.source is self.source:
                graph_arg.erase()

        for g in self.guards:
            if g.is_volatile:
                g.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable(value=self.raw_value, guards=self.guards)


class FakeItemVariable(TensorVariable):
    """An unspecialized Python variable which prevents access to the underlying raw value.
    This is needed if item is called on a FakeTensor."""

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        need_unwrap = kwargs.pop("need_unwrap", False)
        super().__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))