torch.py

import logging
import math
import re
import types
from typing import Dict, List

import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import SymNodeVariable
from torch._guards import GuardsCheckpointState

from .. import config, variables
from ..allowed_functions import torch_get_name
from ..exc import unimplemented
from ..source import GetItemSource, NNModuleSource
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    HAS_NUMPY,
    istype,
    np,
    product,
    proxy_args_kwargs,
    specialize_args_kwargs,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .lists import ListVariable, TupleVariable
from .misc import AutocastModeVariable, NullContextVariable
from .tensor import TensorWithTFOverrideVariable

log = logging.getLogger(__name__)
  33. # TODO(voz): Maybe rename these later
  34. tensor_dunder_fns = [
  35. torch.Tensor.__rmatmul__,
  36. torch.Tensor.__rmod__,
  37. torch.Tensor.__rpow__,
  38. torch.Tensor.__rsub__,
  39. torch._C._TensorBase.__radd__,
  40. torch._C._TensorBase.__rmul__,
  41. torch._C._TensorBase.__ror__,
  42. torch._C._TensorBase.__rxor__,
  43. torch._C._TensorBase.__rand__,
  44. ]
  45. torch_special_class_types = (torch._C.Generator,)
  46. REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
  47. torch.onnx.operators.shape_as_tensor,
  48. torch._shape_as_tensor,
  49. ]
  50. constant_fold_functions = [
  51. torch._assert,
  52. torch._utils._get_device_index,
  53. torch.cuda.is_available,
  54. torch.device,
  55. torch.distributed.is_available,
  56. torch.finfo,
  57. torch.get_default_dtype,
  58. torch.iinfo,
  59. torch.is_floating_point,
  60. torch.nn.functional._Reduction.get_enum,
  61. ]
  62. if torch.distributed.is_available():
  63. constant_fold_functions.append(torch.distributed.is_initialized)
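# Calls in constant_fold_functions (plus anything from the math module, see
# TorchVariable.can_constant_fold_through) are evaluated eagerly at trace time when
# their arguments are constants, so e.g. torch.cuda.is_available() ends up as a plain
# Python constant in the trace rather than as a node in the FX graph.
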
# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
def remap_as_fn___radd__(*args):
    return torch._C._TensorBase.__radd__(*args)


def remap_as_fn___rmul__(*args):
    return torch._C._TensorBase.__rmul__(*args)


def remap_as_fn___ror__(*args):
    return torch._C._TensorBase.__ror__(*args)


def remap_as_fn___rxor__(*args):
    return torch._C._TensorBase.__rxor__(*args)


def remap_as_fn___rand__(*args):
    return torch._C._TensorBase.__rand__(*args)


tensor_dunder_fns_remap = {
    torch._C._TensorBase.__radd__: remap_as_fn___radd__,
    torch._C._TensorBase.__rmul__: remap_as_fn___rmul__,
    torch._C._TensorBase.__ror__: remap_as_fn___ror__,
    torch._C._TensorBase.__rxor__: remap_as_fn___rxor__,
    torch._C._TensorBase.__rand__: remap_as_fn___rand__,
}

try:
    # We need to monkeypatch transformers here, sadly.
    # TODO(voz): Upstream to transformers lib
    import transformers

    def _dynamo_overriden_transformers_eq(self, other):
        if not hasattr(other, "__dict__"):
            return False
        return self.__dict__ == other.__dict__

    transformers.configuration_utils.PretrainedConfig.__eq__ = (
        _dynamo_overriden_transformers_eq
    )
except ImportError:
    pass


class TorchVariable(VariableTracker):
    """Points to a module or method in torch.*"""

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)

        if value in tensor_dunder_fns_remap:
            value = tensor_dunder_fns_remap[value]
        self.value = value

        # the remainder of this is just optional debug checks
        try:
            self_should_be_none = getattr(self.value, "__self__", None)
        except RuntimeError as e:
            assert "No such operator" in str(e), str(e)
            self_should_be_none = None

        # assert "_ntuple.<locals>.parse" not in str(value)

        if self_should_be_none is None:
            pass
        elif isinstance(self_should_be_none, types.ModuleType):
            # weird ones like torch.nn.functional.avg_pool2d have __self__
            name = self_should_be_none.__name__
            assert re.match(r"^(torch|math)([.]|$)", name), f"__self__ set to {name}"
        elif isinstance(
            self_should_be_none, type(torch._C._get_tracing_state.__self__)
        ):
            # some _C functions have __self__ as a null capsule
            pass
        elif isinstance(self_should_be_none, torch_special_class_types):
            pass
        else:
            raise AssertionError(f"{value} found with __self__ set")

    def __repr__(self):
        return f"TorchVariable({self.value})"

    def unique_var_name(self):
        name = torch_get_name(self.value, f"allowed_fn_{id(self.value)}")
        return "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)

    def reconstruct(self, codegen):
        return codegen.setup_globally_cached(self.unique_var_name(), self.value)

    def as_proxy(self):
        return self.value

    def python_type(self):
        if isinstance(self.value, (torch.Tensor, torch.nn.Module)):
            return type(self.value)
        return super().python_type()

    def as_python_constant(self):
        return self.value

    def can_constant_fold_through(self):
        if self.value in constant_fold_functions:
            return True
        return getattr(self.value, "__module__", None) == "math"

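    # Rough illustration of the effect: with all-constant arguments, a call such as
    # math.sqrt(4.0) is evaluated at trace time and call_function below returns
    # ConstantVariable(2.0) instead of emitting a graph node.
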
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import (
            ConstantVariable,
            CUDAStreamContextVariable,
            CUDAStreamVariable,
            GradModeVariable,
            SymNodeVariable,
            TensorVariable,
            UserDefinedObjectVariable,
        )
        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

        constant_args = check_constant_args(args, kwargs)
        unspec_python_args = check_unspec_python_args(args, kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())

        if self.value in config.constant_functions:
            assert not args and not kwargs
            return ConstantVariable(config.constant_functions[self.value], **options)
        elif self.can_constant_fold_through() and (constant_args or unspec_python_args):
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
            # constant fold
            return ConstantVariable(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
                **options,
            )
        elif istype(self.value, type) and issubclass(self.value, torch.nn.Module):
            if self.value is torch.nn.Softmax:
                return self._call_softmax(tx, args, kwargs, options)
            if self.value is torch.nn.CrossEntropyLoss:
                return self._call_cross_entropy_loss(tx, args, kwargs, options)
            else:
                unimplemented(f"construct nn.Module: {self.value.__name__}")
        elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
            assert len(args) == 1
            if isinstance(args[0], TensorVariable) or (
                self.value is torch.overrides.is_tensor_like
                and isinstance(args[0], UserDefinedObjectVariable)
                and hasattr(args[0].value, "__torch_function__")
            ):
                return ConstantVariable(True, **options)
            else:
                return ConstantVariable(False, **options)
        elif (
            self.value
            in (
                torch.is_floating_point,
                torch.is_complex,
            )
            and isinstance(args[0], TensorVariable)
            and args[0].dtype is not None
        ):
            if self.value is torch.is_floating_point:
                return ConstantVariable(args[0].dtype.is_floating_point, **options)
            elif self.value is torch.is_complex:
                return ConstantVariable(args[0].dtype.is_complex, **options)
            else:
                raise AssertionError()
        elif (
            self.value is torch.numel
            and isinstance(args[0], TensorVariable)
            and args[0].size is not None
        ):
            return ConstantVariable(product(args[0].size), **options)
        elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD:
            assert len(args) == 1
            assert isinstance(args[0], TensorVariable)
            return args[0].call_method(tx, "size", [], {})
        elif self.value in (
            torch.nn.modules.utils._single,
            torch.nn.modules.utils._pair,
            torch.nn.modules.utils._triple,
            torch.nn.modules.utils._quadruple,
            torch.nn.modules.utils._ntuple,
        ):
            return self._call_ntuple(tx, args, kwargs, options)
        elif self.value is torch.no_grad:
            return GradModeVariable.create(tx, False, **options)
        elif self.value is torch.enable_grad:
            return GradModeVariable.create(tx, True, **options)
        elif self.value is torch.set_grad_enabled and len(args) == 1:
            return GradModeVariable.create(tx, args[0].as_python_constant(), **options)
        elif self.value is torch.is_grad_enabled:
            assert not (args or kwargs)
            return ConstantVariable(torch.is_grad_enabled(), **options).add_guards(
                GradModeVariable._guards_singleton
            )
        elif self.value is torch.cuda.stream:
            log.warning(
                "torch.cuda.stream() not fully supported, streams may be ignored"
            )
            assert len(args) == 1
            return CUDAStreamContextVariable.create(tx, args[0], **options)
        elif self.value is torch.cuda.streams.Stream:
            return wrap_fx_proxy_cls(
                CUDAStreamVariable,
                tx,
                tx.output.create_proxy(
                    "call_function",
                    torch.cuda.streams.Stream,
                    (),
                    {},
                ),
                **options,
            )
        elif not config.dynamic_shapes and self.is_dynamic_shapes(args, kwargs):
            unimplemented(f"dynamic shapes: {self.value.__name__}")
        elif len(args) > 0 and isinstance(args[0], TensorWithTFOverrideVariable):
            # This code block implements inlining the __torch_function__
            # override of a tensor.
            tensor_with_tf_override = args[0]

            # TODO(future PR): make this implement the full __torch_function__ API
            # instead of assuming the relevant override is in the first argument.
            args[0] = args[0].tensor_variable

            unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
                tx,
                self,
                tensor_with_tf_override.orig_tensor_variable_source,
                tensor_with_tf_override.subclass_torch_function__func,
                tensor_with_tf_override.subclass_type,
                options,
                args,
                kwargs,
            )

            # The wrapping here follows the logic in
            # `torch.Tensor.__torch_function__`.
            if self.value in torch.overrides.get_default_nowrap_functions():
                return unwrapped
            return TensorWithTFOverrideVariable(
                unwrapped,
                tensor_with_tf_override.orig_tensor_variable_source,
                tensor_with_tf_override.subclass_torch_function__func,
                tensor_with_tf_override.subclass_type,
            )
        elif self.value is torch.amp.autocast_mode.autocast:
            return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
        elif self.value in (
            torch.profiler.profile,
            torch.profiler.record_function,
            torch.autograd.profiler.profile,
            torch.autograd.profiler.record_function,
        ):
            log.warning("Profiler will be ignored")
            return NullContextVariable(**options)
        elif self.value is torch.autograd._profiler_enabled:
            unimplemented("torch.autograd._profiler_enabled not supported yet")
        elif self.value is torch.jit.annotate:
            assert len(args) == 2
            return args[1]
        elif self.value is torch.backends.cudnn.is_acceptable:
            # is_acceptable(tensor) returns true if
            #   (a) tensor dtype/device are supported by cudnn
            #   (b) cudnn is available
            #   (c) some initialization has completed
            # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
            assert (
                len(args) == 1 or "tensor" in kwargs
            ), "Expect 1 input to cudnn.is_acceptable"
            tensor_variable = args[0] if len(args) > 0 else kwargs["tensor"]
            assert isinstance(
                tensor_variable, TensorVariable
            ), "Expect input to cudnn.is_acceptable to be a tensor"
            tensor_inp = torch.tensor(
                0, dtype=tensor_variable.dtype, device=tensor_variable.device
            )
            return ConstantVariable(
                torch.backends.cudnn.is_acceptable(tensor_inp), **options
            )
        if (
            self.value.__name__ == "get_state"
            and hasattr(self.value, "__self__")
            and isinstance(self.value.__self__, torch._C.Generator)
        ):

            def get_state_from_generator():
                return self.value()

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    get_state_from_generator,
                    *proxy_args_kwargs(args, kwargs),
                ),
                example_value=self.value(),
                **options,
            )

        if (
            self.value.__name__ == "set_state"
            and hasattr(self.value, "__self__")
            and isinstance(self.value.__self__, torch._C.Generator)
        ) or self.value == torch.random.set_rng_state:
            assert len(args) == 1
            assert isinstance(args[0], TensorVariable)

            unimplemented(
                "TODO: make torch.random.set_rng_state work with FakeTensor/aot_autograd"
            )
            # In fake tensor case, this state doesn't matter, but
            # it needs to be valid to not segfault. Pull a real tensor out.
            # The value won't matter since we are running with fake tensors anyway, so rng doesn't matter.
            # However, it is imperative to record the call_function in the graph with the true args
            # (Not the fake example_value) - for the sake of graph correctness.
            if self.value == torch.random.set_rng_state:
                example_value = torch.random.get_rng_state()
            else:
                example_value = self.value.__self__.get_state()

            self.value.__module__ = self.__module__
            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    self.value,
                    *proxy_args_kwargs(args, kwargs),
                ),
                example_value=example_value,
                **options,
            )
        elif (
            self.value == torch.numel
            and len(args) == 1
            and isinstance(args[0], TensorVariable)
            and len(kwargs) == 0
        ):
            # TODO(voz): This is rewritten as a call_method because
            # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not
            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method",
                    "numel",
                    *proxy_args_kwargs(args, kwargs),
                ),
                **options,
            )
        elif (
            self.value == torch.addcdiv
            and len(args) == 3
            and "value" in kwargs
            and len(kwargs) == 1
        ):
            # decompose addcdiv into constituent ops, prevents a graph break due to converting
            # value to a scalar
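            # i.e. torch.addcdiv(input, t1, t2, value=v) == input + v * (t1 / t2),
            # which the three call_function invocations below rebuild op by op.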
            result = TorchVariable(torch.div, **options).call_function(tx, args[1:], {})
            result = TorchVariable(torch.mul, **options).call_function(
                tx, [result, kwargs["value"]], {}
            )
            return TorchVariable(torch.add, **options).call_function(
                tx, [args[0], result], {}
            )
        else:
            any_symints_or_symfloats = any(
                [isinstance(x, SymNodeVariable) for x in args]
            )
            all_ints_or_floats = all(
                [
                    isinstance(
                        x, (variables.ConstantVariable, variables.SymNodeVariable)
                    )
                    for x in args
                ]
            )
            bin_ops = {"add", "sub", "mul", "div", "sqrt"}
            if (
                getattr(self.value, "__module__", "") == "torch"
                and self.value.__name__ in bin_ops
                and any_symints_or_symfloats
                and all_ints_or_floats
            ):
                msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
                log.warning(msg)
                raise unimplemented(msg)

            # Handle sth like torch.LongTensor(list(np.int64, np.int64, ...)),
            # as FX symbolic trace doesn't support numpy int/float as base types.
            if (
                HAS_NUMPY
                and self.value in tensortype_to_dtype
                and len(args) == 1
                and isinstance(args[0], ListVariable)
                and args[0].is_python_constant()
            ):
                for x in args[0].items:
                    if isinstance(x.value, np.generic):
                        x.value = x.value.item()
            if self.value == torch._C._nn.scaled_dot_product_attention:
                # See: [Note] SDPA_flash's meta function returns incorrect Philox seed and offset
                # in pytorch/torch/_meta_registrations.py
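                # The block below reconstructs the full (query, key, value, attn_mask,
                # dropout_p, is_causal) argument set, builds fake inputs for it, and asks
                # torch._fused_sdp_choice which SDPA backend would be selected, so that
                # unsupported flash-attention-with-dropout configs graph break up front.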
                all_kwargs = kwargs.copy()
                all_kwargs.update(
                    dict(
                        zip(
                            (
                                "query",
                                "key",
                                "value",
                                "attn_mask",
                                "dropout_p",
                                "is_causal",
                            ),
                            args,
                        )
                    )
                )
                fake_query = all_kwargs["query"].as_proxy().node.meta["example_value"]
                fake_key = all_kwargs["key"].as_proxy().node.meta["example_value"]
                fake_value = all_kwargs["value"].as_proxy().node.meta["example_value"]
                fake_mask = all_kwargs.get("attn_mask")
                if isinstance(fake_mask, TensorVariable):
                    fake_mask = fake_mask.as_proxy().node.meta["example_value"]
                else:
                    fake_mask = None
                dropout_p = kwargs.get("dropout_p")
                dropout_p = dropout_p.value if dropout_p is not None else 0.0
                is_causal = kwargs.get("is_causal")
                is_causal = is_causal.value if is_causal is not None else False
                # We look through the stack to find a cuda autocast context.
                # If we find one, we convert the fake tensors to torch.float16.
                is_cuda_autocast_context = False
                for block in tx.block_stack:
                    if (
                        isinstance(block.with_context, AutocastModeVariable)
                        and block.with_context.target_values[0] == "cuda"
                    ):
                        is_cuda_autocast_context = True
                        break
                if is_cuda_autocast_context and fake_query.device.type == "cuda":
                    amp_dtype = torch.float16
                    fake_query = fake_query.clone().to(amp_dtype)
                    fake_key = fake_key.clone().to(amp_dtype)
                    fake_value = fake_value.clone().to(amp_dtype)
                backend_choice = torch._fused_sdp_choice(
                    fake_query, fake_key, fake_value, fake_mask, dropout_p, is_causal
                )
                if backend_choice == torch.backends.cuda.SDPBackend.FLASH_ATTENTION:
                    if dropout_p is not None and dropout_p != 0.0:
                        unimplemented(
                            "FlashAttention with dropout is not supported in cuda graphs"
                        )
            # TODO(voz): Replace w/ dynamic shape rewrite table.
            # Ideally, we would be able to do this at ctor time, but alas we need a combination
            # of value + args to determine this.
            fn_ = self.value
            if any([isinstance(x, SymNodeVariable) for x in args]):
                if self.value == math.sqrt:
                    from torch.fx.experimental.symbolic_shapes import sym_sqrt

                    fn_ = sym_sqrt

            tensor_variable = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    *proxy_args_kwargs(args, kwargs),
                ),
                **options,
            )
  502. if "out" in kwargs and not (
  503. isinstance(kwargs["out"], variables.ConstantVariable)
  504. and kwargs["out"].as_python_constant() is None
  505. ):
  506. # out variants of torch operators like torch.sort and
  507. # torch.sigmoid mutate the tensors in the out field. Track such
  508. # tensors and rewrite the symbolic locals.
  509. if isinstance(tensor_variable, TupleVariable):
  510. assert isinstance(kwargs["out"], TupleVariable)
  511. output_tensor_names = [
  512. tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
  513. ]
  514. for idx, name in enumerate(output_tensor_names):
  515. if name in tx.symbolic_locals:
  516. tx.symbolic_locals[name] = tensor_variable.items[idx]
  517. elif isinstance(tensor_variable, TensorVariable):
  518. assert isinstance(kwargs["out"], TensorVariable)
  519. name = tx.find_symbolic_locals_name(kwargs["out"])
  520. if name in tx.symbolic_locals:
  521. tx.symbolic_locals[name] = tensor_variable
  522. else:
  523. unimplemented(f"out variant of {type(kwargs['out'])}")
  524. return tensor_variable
    def is_dynamic_shapes(self, args, kwargs):
        """Check for dynamic shapes when shape specialization is enabled"""
        # TODO(jansel): need to get a complete list
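        # These are data-dependent ops: e.g. torch.nonzero(x) returns a tensor whose
        # first dimension depends on the values in x, so its output shape cannot be
        # specialized at trace time.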
        if self.value in (
            torch.nonzero,
            torch.unique,
            torch.unique_consecutive,
        ) or self.value.__name__ in ("nms",):
            return True

        if self.value is torch.where and len(args) + len(kwargs) == 1:
            return True

        if self.value in (
            torch.arange,
            torch.repeat_interleave,
        ):
            none = variables.ConstantVariable(None)

            def has_non_const(it):
                return not all(x.is_python_constant() for x in it)

            def arange(start=none, end=none, step=none, **kwargs):
                return has_non_const([start, end, step])

            def repeat_interleave(input, repeats, dim=none, **kwargs):
                return has_non_const([repeats])

            return locals()[self.value.__name__](*args, **kwargs)

        return False

    def _call_softmax(self, tx, args, kwargs, options):
        """rewrite the pattern nn.Softmax(dim=-1)(x) to F.softmax(x, -1)"""
        dim = args[0] if args else kwargs.get("dim", variables.ConstantVariable(None))

        def fake_softmax(input):
            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.nn.functional.softmax,
                    *proxy_args_kwargs([input, dim], {}),
                ),
                **VariableTracker.propagate([self, dim, input]),
            )

        return variables.LambdaVariable(fake_softmax, **options)

    def _call_cross_entropy_loss(self, tx, args, kwargs, options):
        """
        functional: input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean',
        label_smoothing=0.0

        non functional ctor: weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean',
        label_smoothing=0.0

        non functional loss call: input, target, optional_output
        """
        from . import ConstantVariable

        def normalize_args(
            weight=ConstantVariable(None),
            size_average=ConstantVariable(None),
            ignore_index=ConstantVariable(-100),
            reduce=ConstantVariable(None),
            reduction=ConstantVariable("mean"),
            label_smoothing=ConstantVariable(0.0),
        ):
            return (
                weight,
                size_average,
                ignore_index,
                reduce,
                reduction,
                label_smoothing,
            )

        (
            weight,
            size_average,
            ignore_index,
            reduce_arg,
            reduction,
            label_smoothing,
        ) = normalize_args(*args, **kwargs)

        def fake_cross_entropy_loss(input, target):
            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.nn.functional.cross_entropy,
                    *proxy_args_kwargs(
                        [
                            input,
                            target,
                            weight,
                            size_average,
                            ignore_index,
                            reduce_arg,
                            reduction,
                            label_smoothing,
                        ],
                        {},
                    ),
                ),
                **VariableTracker.propagate(
                    [
                        self,
                        weight,
                        size_average,
                        ignore_index,
                        reduce_arg,
                        reduction,
                        label_smoothing,
                        input,
                        target,
                    ]
                ),
            )

        return variables.LambdaVariable(fake_cross_entropy_loss, **options)

    def _call_ntuple(self, tx, args, kwargs, options):
        """inline behavior of torch.nn.modules.utils._ntuple"""
        if self.value is torch.nn.modules.utils._ntuple:
            count = args[0].as_python_constant()
        else:
            count = self.value.__closure__[0].cell_contents
        assert isinstance(count, int)

        def handle_ntuple(value):
            if value.has_unpack_var_sequence(tx):
                return variables.TupleVariable(
                    list(value.unpack_var_sequence(tx)),
                    **VariableTracker.propagate(self, value, args, kwargs.values()),
                )
            elif value.is_python_constant():
                # constant prop through it
                return variables.ConstantVariable(
                    torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
                    **VariableTracker.propagate(self, value, args, kwargs.values()),
                )
            else:
                unimplemented(f"torch.nn.modules.utils._ntuple({value})")

        if self.value is torch.nn.modules.utils._ntuple:
            return variables.LambdaVariable(handle_ntuple, **options)
        else:
            return handle_ntuple(args[0])


class TorchPyOperator(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import (
            ListVariable,
            NestedUserFunctionVariable,
            TensorVariable,
            UserFunctionVariable,
        )
        from .builder import wrap_fx_proxy

        assert kwargs is None or len(kwargs) == 0, "kwargs are not supported, yet"

        def make_attr(name):
            node = tx.output.create_proxy(
                "get_attr",
                name,
                (),
                {},
            )
            return node
        def add_subgraph(name, gm):
            next_name = None
            i = 0
            while not next_name:
                candidate = f"cond_{name}_{i}"
                if candidate in tx.output.nn_modules:
                    i += 1
                else:
                    next_name = candidate

            gm.__name__ = next_name
            src = NNModuleSource(GetItemSource(self.source, next_name))
            gm.torchdynamo_force_dynamic = False
            tx.output.register_attr_or_module(gm, next_name, source=src)
            return next_name

        def get_comparable_state(state):
            # Nub out bits of state that we don't require to be
            # equal
            return state._replace(
                output=state.output._replace(
                    guard_state=GuardsCheckpointState(set()),
                    nn_modules=None,
                    # Timestamp is monotonically increasing so we don't
                    # care about divergence
                    timestamp=0,
                    # Unused in branches
                    graphargs=[],
                )
            )
        def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint):
            # Setup the subgraph we're going to capture into
            tx.output.graph = torch.fx.Graph()
            tx.output.graphargs = []
            tx.output.name_to_input.clear()

            args = []
            # One argument to graph per sub_args
            for a in sub_args:
                if isinstance(a, TensorVariable):
                    tx.output.create_graph_input(a.as_proxy().node.name)
                    args.append(a)
                else:
                    # call_function() needs a TensorVariable, therefore we construct
                    # one with inner graph proxy.
                    assert isinstance(a, torch.Tensor)
                    proxy = tx.output.create_graph_input("arg")
                    args.append(wrap_fx_proxy(tx=tx, proxy=proxy, example_value=a))

            # NB: we don't bother populating graphargs, as
            # they won't actually get used by anything
            output = f.call_function(tx, args, {})

            # Register output to graph
            # Modeled off of compile_and_call_fx_graph
            # TODO: support non single Tensor output
            assert isinstance(output, TensorVariable)
            tx.output.guards.update(output.guards)
            tx.output.create_node(
                "output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
            )

            tx.output.side_effects.prune_dead_object_new(tx)
            state = tx.copy_graphstate()

            guards = state.output.guards
            nn_modules = state.output.nn_modules
            comparable_state = get_comparable_state(state)

            graph = tx.output.graph
            tx.output.graph = graph_checkpoint
            tx.restore_graphstate(checkpoint)

            return output, graph, guards, nn_modules, comparable_state

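        # Both the cond() and map() branches below use speculate_subgraph to trace their
        # function arguments into standalone FX graphs, install those graphs as child
        # GraphModules via add_subgraph, and then emit a single higher-order call into
        # the outer graph.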
        if self.value.__name__ == "cond":
            # TODO(voz): Support fake tensor dispatch for recursive
            # ops - see torch/dispatch/_dispatcher.py
            assert len(args) == 4
            assert type(args[0]) in (TensorVariable, SymNodeVariable), str(
                type(args[0])
            )  # predicate
            assert isinstance(
                args[1], (UserFunctionVariable, NestedUserFunctionVariable)
            ), str(
                type(args[1])
            )  # true_fn
            assert isinstance(
                args[2], (UserFunctionVariable, NestedUserFunctionVariable)
            ), str(
                type(args[2])
            )  # false_fn
            assert type(args[3]) is ListVariable, str(type(args[3]))  # args

            # Our strategy for tracing the true/false branches of cond
            # is to checkpoint our graphstate, run the true branch,
            # roll it back to the checkpoint, and run the false
            # branch, and then merge the graphstates. Well, perhaps
            # "merge" is too strong a word: we mostly assert that
            # the resulting graphstates have to be the same.
            #
            # We only permit guards to diverge (we union the guards from
            # both branches). In particular, this means that side
            # effects are NOT permitted inside true/false branches; this
            # would be difficult to implement, because of the path
            # explosion problem.
            graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()

            sub_args = args[3].unpack_var_sequence(tx)

            def speculate_branch(branch):
                # NB: 0 is predicate
                ix = 1 if branch else 2
                return speculate_subgraph(
                    args[ix], sub_args, graph_checkpoint, checkpoint
                )

            (
                true_r,
                true_graph,
                true_guards,
                true_nn_modules,
                true_cmp,
            ) = speculate_branch(True)
            (
                false_r,
                false_graph,
                false_guards,
                false_nn_modules,
                false_cmp,
            ) = speculate_branch(False)

            if true_cmp != false_cmp:
                unimplemented(true_cmp.diff(false_cmp))

            # Add guards
            tx.output.tracing_context.guards_context.dynamo_guards |= false_guards
            tx.output.tracing_context.guards_context.dynamo_guards |= true_guards

            true_name = add_subgraph(
                "true", torch.fx.GraphModule(true_nn_modules, true_graph)
            )
            false_name = add_subgraph(
                "false", torch.fx.GraphModule(false_nn_modules, false_graph)
            )

            # Apply side effects (guaranteed to be equal)
            tx.output.side_effects = true_cmp.output.side_effects

            true_node = make_attr(true_name)
            false_node = make_attr(false_name)
            p_args = (
                args[0].as_proxy(),
                true_node,
                false_node,
                [a.as_proxy() for a in sub_args],
            )
            # TODO: assert that the true/false return values are
            # consistent
            example_value = true_r.as_proxy().node.meta["example_value"]
        elif self.value.__name__ == "map":
            assert type(args[0]) in (UserFunctionVariable, NestedUserFunctionVariable)
            assert type(args[1]) is TensorVariable

            sample_shape = args[1].get_real_value().size()
            if len(sample_shape) < 1 or sample_shape[0] == 0:
                unimplemented(
                    "map() operator doesn't support scalar or zero-sized tensors during tracing."
                )

            checkpoint = tx.copy_graphstate()
            # To get the example output from map() we will need to provide at least one sample to
            # the loop body. In our case we will always use xs[0], and our map() won't support zero
            # sized tensor during tracing.
            (
                body_r,
                body_graph,
                body_guards,
                body_nn_modules,
                body_cmp,
            ) = speculate_subgraph(
                args[0],
                [
                    get_fake_value(args[1].as_proxy().node, tx)[0],
                    *args[2:],
                ],
                tx.output.graph,
                checkpoint,
            )

            # We don't support side effects inside a map loop body for simplicity.
            parent_cmp = get_comparable_state(checkpoint)
            if parent_cmp != body_cmp:
                diff = parent_cmp.diff(body_cmp)
                raise unimplemented(
                    f"Graph state change detected in map() loop body. Diagnostics: {diff}"
                )

            # Add guards
            tx.output.tracing_context.guards_context.dynamo_guards |= body_guards

            body_name = add_subgraph(
                "body", torch.fx.GraphModule(body_nn_modules, body_graph)
            )

            body_node = make_attr(body_name)
            p_args = (body_node, *(arg.as_proxy() for arg in args[1:]))
            r = body_r.as_proxy().node.meta["example_value"]
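            # Model map()'s example output as a tensor of shape
            # [xs.shape[0], *body_output.shape], i.e. one body result per leading slice of xs.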
            example_value = r.new_empty(
                [get_fake_value(args[1].as_proxy().node, tx).shape[0], *r.shape]
            )
        else:
            unimplemented(f"PyOperator {self.value.__name__}")

        # Store the invocation as a call
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(p_args),
                kwargs={},
            ),
            example_value=example_value,
        )