# aot_autograd.py

import collections
import dataclasses
import itertools
import logging
import warnings
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from functorch import make_fx

import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import dynamo_timed
from torch._subclasses import CrossRefFakeMode, FakeTensor, FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn.utils import stateless

from . import config
from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs

log = logging.getLogger(__name__)
MutationType = Enum(
    "MutationType", ("none", "metadata_only", "data", "data_and_metadata")
)
OutputType = Enum(
    "OutputType", (
        # output is not an alias
        "non_alias",
        # output aliases an input
        "alias_of_input",
        # output **is** an input tensor
        "is_input",
        # output has a ._base tensor, which is a graph intermediate.
        # We need to return its ._base as a graph output,
        # so its requires_grad info is populated correctly.
        # Instructs the runtime code to regenerate the current output
        # from a base tensor, graph_intermediates[base_idx]
        "alias_of_intermediate_save_as_output",
        # Same as above; but we don't need to explicitly add its ._base
        # as a graph output, because it already **is** a graph output.
        "alias_of_intermediate",
        # Same as above; but the output's ._base is **already** a user output.
        # Instructs the runtime code to regenerate the current output from
        # a base tensor, user_outputs[base_idx]
        "alias_of_intermediate_base_is_user_output",
    )
)
pytree._register_pytree_node(
    immutable_collections.immutable_list,
    lambda x: (list(x), None),
    lambda x, c: immutable_collections.immutable_list(x),
)
pytree._register_pytree_node(
    immutable_collections.immutable_dict,
    lambda x: (list(x.values()), list(x.keys())),
    lambda x, c: immutable_collections.immutable_dict(
        {key: value for key, value in zip(c, x)}
    ),
)

aten = torch.ops.aten

# This global counter increments every time we compile a graph with
# AOTAutograd. You can use this to correlate runtime error messages
# with compile time (e.g., if you get an error at runtime saying
# "compiled graph 3 failed", you can set a breakpoint at compile time
# for this graph number to investigate further.)
#
# NB: this is different from get_aot_compilation_context, which tracks
# each underlying graph that is compiled. In contrast, AOT_COUNTER
# corresponds to top-level invocations of aot_module/aot_function;
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()

KNOWN_TYPES = tuple(
    [torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
)

@contextmanager
def preserve_rng_state():
    rng_state = torch.clone(torch.random.get_rng_state())
    if torch.cuda.is_available():
        cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
    try:
        yield
    finally:
        torch.random.set_rng_state(rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(cuda_rng_state)
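
# Illustrative usage sketch (fn and example_inputs are hypothetical, not from this file):
# wrap a tracing step so that any RNG consumed while tracing does not perturb the
# caller's RNG stream afterwards.
#
#   with preserve_rng_state():
#       gm = make_fx(fn)(*example_inputs)  # tracing may draw random numbers
#   # CPU (and CUDA, if available) RNG state is restored here.
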
# Set up hooks so that during backward the fx's stack_trace is properly set
callback_set = False


def setup_stacktrace_preservation_hooks(roots: List):
    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def get_callback(saved_stack_):
        def callback():
            global callback_set
            fx_traceback.set_stack_trace(saved_stack_)
            callback_set = False

        return callback

    def get_prehook(stack_):
        def prehook(grad_output):
            global callback_set

            if not callback_set:
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    get_callback(fx_traceback.format_stack())
                )
                callback_set = True

            fx_traceback.set_stack_trace(stack_)

        return prehook

    def get_posthook(special_stack_):
        def posthook(grad_input, grad_output):
            fx_traceback.set_stack_trace(special_stack_)

        return posthook

    for node in iter_graph(roots):
        forward_node_stack = node.metadata.get("traceback_", [])
        node.register_prehook(get_prehook(forward_node_stack))

        special_stack = forward_node_stack.copy()
        special_stack.append(
            "Gradient addition node due to multiple use of tensor around:"
        )
        node.register_hook(get_posthook(special_stack))
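
# Illustrative usage (mirrors the call site in forward_or_joint() further down in this file):
# right before tracing the backward, register hooks on the grad_fn roots of the outputs we are
# about to pass to autograd.grad(), so that the stack traces recorded during the forward trace
# are re-installed on the backward nodes:
#
#   setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
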
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at the `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile,
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
# def f(x):
#     x.mul_(2)
#     out = x.mul(3)
#     return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The output of (a, b, c) are all written below.
#
# def compiled_forward_graph(x):
#     x_updated = x.mul(2)
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated gets a gradient in the compiled backward
# def compiled_backward_graph(grad_x_updated, grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.copy_(x_updated)
#     return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# the compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.
#
# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph.
#
# Example: original user code:
# def f(x):
#     x.t_()
#     out = x.mul(3)
#     return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     x_updated = x.t()
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated does *not* get a gradient in the compiled backward
# def compiled_backward_graph(grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.as_strided_(x_updated)
#     return out
#
# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
#     in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
# You might wonder why we return the aliased output directly in the graph (making the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
#     This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
#     when it has a different set of strides.
#     By including the view op directly in the graph, inductor takes that into account when deciding what memory format
#     the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
#  *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph.
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
# def f(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     # the compiled graph also returns the intermediate
#     return out1, out2, intermediate
#
# # intermediate gets a gradient in the compiled backward.
# # both output aliases (out1 and out2) do not.
# def compiled_backward_graph(grad_intermediate):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     out1, out2, intermediate = compiled_forward_graph(x)
#     return out1, out2, intermediate
#
# def compiled_wrapper(x):
#     out1, out2, intermediate = autograd.Function.apply(x)
#     # regenerate out1 from the input
#     out1_regenerated = out1._view_func(x)
#     # regenerate out2 from the intermediate
#     out2_regenerated = out2._view_func(intermediate)
#     return out1_regenerated, out2_regenerated
#
# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# See merge_view_inputs() for more detailed info.
#
# Example: original user code:
# def f(x, x_view):
#     x.mul_(2)
#     out = x * x_view
#     return out
# f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(base):
#     x = generate_x(base)
#     x_view = generate_x_view(base)
#     x_updated = x.mul(2)
#     x_view_updated = x_updated.view(-1)
#     out = x_updated * x_view_updated
#     return x_updated, out
#
# # The calling convention change from (aliases) -> (base) happens
# # *outside* of the autograd.Function.forward().
# # That means the forward() only has 1 input (base),
# # and the backward() only has 1 output (grad_base)
# def compiled_backward_graph(grad_out):
#     grad_base = ...
#     return grad_base
#
# def autograd.Function.forward(base):
#     x_updated, out = compiled_forward_graph(base)
#     return x_updated, out
#
# # The compiled wrapper is where we create synthetic bases.
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
# def compiled_wrapper(x, x_view):
#     base = merge_view_inputs(x, x_view)
#     x_updated, out = autograd.Function.apply(base)
#     # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
#     x.copy_(x_updated)
#     return out
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


# This class stores info about every user output.
@dataclass(frozen=True)
class OutputAliasInfo:
    # Tells us if this output is:
    # (1) a regular (non-aliased) output
    # (2) an alias of a forward input
    # (3) **is** a forward input (special case of "alias_of_input")
    # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
    # (5) an alias of an intermediate, that explicitly requires returning the intermediate
    #     as a graph output
    # (6) an alias of an intermediate, where that intermediate is also a user output
    output_type: OutputType
    # If (1) above, then
    # - base_idx is None
    # If (2) or (3) above, then
    # - Tells us that the base of this alias is user_fwd_input[base_idx]
    #   (This is an index into the inputs *before* we make synthetic bases)
    # If (4) or (5) above, then
    # - Tells us that the base of this alias is output_graph_intermediates[base_idx]
    #   here, this refers to the index of the *direct* traced
    # If (6) above, then:
    # - Tells us that the base of this alias is output_user_fwds[base_idx]
    #   here, this refers to the index of the *direct* traced
    base_idx: Optional[int]


# This class tells us info about user inputs.
@dataclass(frozen=True)
class InputAliasInfo:
    mutates_data: bool
    mutates_metadata: bool


# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
@dataclass()
class ViewAndMutationMeta:
    # length = # user inputs
    # This gives us info about every input, and what sort of mutation happened to it (if any)
    input_info: List[InputAliasInfo]

    # length = # user outputs
    # This gives us info about every output (mostly around whether it aliases other tensors)
    output_info: List[OutputAliasInfo]

    # length = # mutated inps + # user outputs
    # For every output *and* mutated input returned from the forward,
    # tells us whether or not that output should require gradients
    requires_grad_info: List[bool]

    # length = the number of intermediate bases appended as outputs to the end of the forward graph.
    # Note: this is not necessarily the same thing as:
    #     len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
    # Because outputs might share a ._base, or an output's ._base might itself be
    # another user output (in both cases, we won't redundantly append bases to the end of the graph)
    num_intermediate_bases: int

    # For inference only: instructs us to keep data-only input mutations directly in the graph
    keep_input_mutations: bool

    def __post_init__(self):
        # pre-compute the indices of the inputs that are mutated.
        # When keep_input_mutations is set, we don't need to worry about our epilogue
        # handling data-only mutations, because we keep them directly in the graph.
        mutated_inp_indices = [
            i for i, m in enumerate(self.input_info)
            if m.mutates_metadata or (not self.keep_input_mutations and m.mutates_data)
        ]
        aliased_out_indices = [
            i
            for i, m in enumerate(self.output_info)
            if m.output_type != OutputType.non_alias
        ]

        # This is pre-computed in post_init for perf.
        # It contains the index of every element
        # of input_info that corresponds to a mutation (data or metadata or both)
        self.mutated_inp_indices = mutated_inp_indices
        # This is pre-computed for perf.
        # It contains the index of every element
        # of output_info that corresponds to an alias (either of an input or intermediate)
        self.aliased_out_indices = aliased_out_indices


# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
#   We do not want them to participate in the autograd.Function.
#   We do that by wrapping them in an opaque class, so the autograd.Function
#   does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
    alias: torch.Tensor


def has_same_metadata(t1, t2):
    return (
        t1.size() == t2.size()
        and t1.stride() == t2.stride()
        and t1.storage_offset() == t2.storage_offset()
    )


def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad):
    # Try to do view-replay if possible.
    # fall back to .as_strided() if we can't.
    if target_meta_tensor._base is not None:
        # The base that we want to replay our view off of might have a different shape than the view's original base.
        b = target_meta_tensor._base
        abt = aliased_base_tensor
        # Don't unnecessarily call as_strided if nothing changed; as_strided's
        # backward is poorly implemented and slow
        if abt is not b and (
            abt.size() != b.size() or
            abt.stride() != b.stride() or
            abt.storage_offset() != b.storage_offset()
        ):
            reshaped_base_tensor = aliased_base_tensor.as_strided(
                b.size(), b.stride(), b.storage_offset()
            )
        else:
            reshaped_base_tensor = aliased_base_tensor
        out = target_meta_tensor._view_func(reshaped_base_tensor)
        # This shape mismatch can happen due to a bug in inplace/view handling in autograd.
        # Try putting a breakpoint here and running
        # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
        # Also, https://github.com/pytorch/pytorch/issues/49825
        #
        # As a stopgap, we'll fall back to as_strided.
        if out is not None and out.shape == target_meta_tensor.shape:
            if aliased_base_tensor.requires_grad and not target_requires_grad:
                out = out.detach()
            elif not aliased_base_tensor.requires_grad and target_requires_grad:
                out.requires_grad_(True)
            return out
    size = target_meta_tensor.size()
    stride = target_meta_tensor.stride()
    storage_offset = target_meta_tensor.storage_offset()
    if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex():
        aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex():
        aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    else:
        aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
    # For outputs aliasing inputs, we need to check if the requires-gradness has changed.
    if aliased_base_tensor.requires_grad and not target_requires_grad:
        aliased_out = aliased_out.detach()
    elif not aliased_base_tensor.requires_grad and target_requires_grad:
        aliased_out.requires_grad_(True)
    return aliased_out
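
# Illustrative sketch (plain eager tensors; the variable names are hypothetical):
#
#   base = torch.randn(4, 4)
#   view = base[0]            # view._base is base
#   alias = gen_alias_from_base(base, view, target_requires_grad=False)
#   # `alias` aliases `base` and has view's size/stride/storage_offset;
#   # view-replay via _view_func is attempted first, with as_strided() as the fallback.
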


def to_fun(t):
    if isinstance(t, Tensor):
        return torch._to_functional_tensor(t, mirror_autograd_meta=True)
    else:
        return t


def from_fun(t):
    if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
        return t
    torch._sync(t)
    return torch._from_functional_tensor(t)


# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case.
#
# Unlike functorch's variant, this doesn't use the functorch level system,
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key. In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
#   Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
#   outer tensor                       inner tensor
#
# Returns:
# - ViewAndMutationMeta, telling us metadata about the inputs and outputs
# - The list of outputs from the forward, but **only** the outputs that we need
#   to pass in as tangents into the backward.
#   Specifically, aliased outputs from the forward get regenerated, and don't participate
#   in the compiled backward function.
def run_functionalized_fw_and_collect_metadata(
    f,
    *,
    keep_input_mutations: bool
) -> Tuple[ViewAndMutationMeta, List[Any]]:
    memo = {}

    def to_fun(t):
        if isinstance(t, Tensor):
            if t in memo:
                return memo[t]
            r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
            memo[t] = r
            return r
        else:
            return t

    def from_fun(t):
        if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
            return t
        torch._sync(t)
        return torch._from_functional_tensor(t)
    @wraps(f)
    def inner(*flat_args):
        # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
        assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)

        input_info: List[InputAliasInfo] = []
        output_info: List[OutputAliasInfo] = []
        input_requires_grad_info: List[bool] = []
        output_requires_grad_info: List[bool] = []

        flat_f_args = pytree.tree_map(to_fun, flat_args)

        torch._enable_functionalization(reapply_views=True)
        try:
            # precondition: The passed in function already handles unflattening inputs + flattening outputs
            flat_f_outs = f(*flat_f_args)
        finally:
            torch._disable_functionalization()

        # Inspect the state of the input tensor functional wrapper to detect input mutation info
        # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
        for (i, (arg, f_arg)) in enumerate(zip(flat_args, flat_f_args)):
            if not isinstance(arg, Tensor):
                new_arg = arg
            else:
                torch._sync(f_arg)
                new_arg = torch._from_functional_tensor(f_arg)
            if arg is not new_arg:
                if StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()):
                    mutates_data = False
                    mutates_metadata = True
                else:
                    mutates_data = True
                    mutates_metadata = not has_same_metadata(arg, new_arg)
                # Only track requires_grad info on *mutated* inputs,
                # because they show up in the autograd.Function.forward as outputs
                input_requires_grad_info.append(
                    isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
                )
            else:
                mutates_data = False
                mutates_metadata = False

            input_info.append(InputAliasInfo(
                mutates_data=mutates_data,
                mutates_metadata=mutates_metadata
            ))

        # If a function involves creating a tensor and returning a view of it, such that its _base is the intermediate,
        # we need to make sure our graph returns the _base as a graph output, and we manually recreate the view
        # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad
        # on the base tensor, but we are obligated to properly set requires-gradness on the real output.

        num_mutated_inps = len(
            [x for x in input_info if x.mutates_data or x.mutates_metadata]
        )

        inp_storage_refs = {
            StorageWeakRef(inpt.untyped_storage()): idx
            for idx, inpt in enumerate(flat_f_args)
            if isinstance(inpt, torch.Tensor)
        }

        # We need inp tensor id's to be able to tell if an output **is** an input.
        inp_tensor_ids = {
            id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)
        }
        # We need output tensor id's to tell if any output._base attributes **are** other outputs.
        # (This is also a dict because we need to know that output's index, so we can regenerate
        # the alias from it).
        out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}

        # maps the id of an intermediate base to its index in the output of the compiled forward
        intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
        intermediate_bases: List[torch.Tensor] = []
        for o in flat_f_outs:
            if (
                isinstance(o, torch.Tensor)
                and StorageWeakRef(o.untyped_storage()) in inp_storage_refs
            ):
                base_idx = inp_storage_refs[StorageWeakRef(o.untyped_storage())]
                is_input_tensor = id(o) in inp_tensor_ids
                if is_input_tensor:
                    output_type = OutputType.is_input
                else:
                    output_type = OutputType.alias_of_input

            # We only need to handle the intermediate base case when both
            # the intermediate base and the output require gradients.
            # See Note [AOT Autograd: outputs aliasing inputs or intermediates!]
            elif (
                isinstance(o, torch.Tensor)
                and o._base is not None
                and o.requires_grad
                and o._base.requires_grad
            ):
                # First, check if o's ._base is an existing output
                maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None)
                if maybe_existing_out_idx is not None:
                    # Special case where the output is an alias of a graph intermediate, but that intermediate
                    # is itself also a user output.
                    output_type = OutputType.alias_of_intermediate_base_is_user_output
                    base_idx = maybe_existing_out_idx
                else:
                    # Next, check if o's ._base is an intermediate base that we already returned
                    maybe_existing_base_output_idx = intermediate_base_tensor_id_to_output_idx.get(
                        id(o._base), None
                    )
                    if maybe_existing_base_output_idx is not None:
                        output_type = OutputType.alias_of_intermediate
                        base_idx = maybe_existing_base_output_idx
                    else:
                        # Otherwise, take o._base and explicitly return it as an output in the compiled graph
                        new_out_idx = len(intermediate_bases)
                        base_idx = new_out_idx
                        # Indicate to the logic later on (when we trace the joint)
                        # that this particular output should get its ._base appended to the forward graph outputs
                        output_type = OutputType.alias_of_intermediate_save_as_output
                        intermediate_base_tensor_id_to_output_idx[id(o._base)] = new_out_idx
                        intermediate_bases.append(o._base)
            else:
                output_type = OutputType.non_alias
                base_idx = None

            out_info = OutputAliasInfo(
                output_type=output_type,
                base_idx=base_idx,
            )
            output_info.append(out_info)
            output_requires_grad_info.append(
                isinstance(o, torch.Tensor) and o.requires_grad
            )

        # Our autograd.Function.forward returns both mutated inputs and outputs,
        # so we need grad info on all of them.
        requires_grad_info = input_requires_grad_info + output_requires_grad_info
        assert len(requires_grad_info) == len(output_info) + len(
            [x for x in input_info if x.mutates_data or x.mutates_metadata]
        )

        # This analysis function returns *only* the outputs that are meant to be tangents to the backwards.
        # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
        # is *regenerated* later, and not used directly in the autograd graph
        f_input_tangents = [
            inp
            for inp, info in zip(flat_f_args, input_info)
            if info.mutates_data
        ]
        f_output_tangents = [
            o
            for o, info in zip(flat_f_outs, output_info)
            if info.output_type == OutputType.non_alias
        ]
        # intermediate bases are also included in the backward graph
        f_tangents = f_input_tangents + f_output_tangents + intermediate_bases

        metadata = ViewAndMutationMeta(
            input_info=input_info,
            requires_grad_info=requires_grad_info,
            output_info=output_info,
            num_intermediate_bases=len(intermediate_bases),
            keep_input_mutations=keep_input_mutations,
        )
        return metadata, pytree.tree_map(from_fun, f_tangents)

    return inner
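
# Illustrative usage sketch (the function f below is hypothetical; the exact classifications
# depend on the analysis above):
#
#   def f(x):
#       x.mul_(2)           # data mutation on an input
#       return x.view(-1)   # output that aliases an input
#
#   meta, f_tangents = run_functionalized_fw_and_collect_metadata(
#       f, keep_input_mutations=False
#   )(torch.randn(4))
#   # meta.input_info[0] should report mutates_data=True, and
#   # meta.output_info[0] should be classified as an alias of an input,
#   # so the only tangent needed by the backward is the updated input.
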

def unpack_synthetic_bases(
    primals: List[Any],
    synthetic_base_info: Optional[List[Union[int, Tuple[int, torch.Tensor]]]],
) -> List[Any]:
    # This is only not None if our graph mutates a graph input that aliases another graph input.
    if synthetic_base_info is None:
        return primals

    f_args_inner = []
    for outer_idx_or_tuple in synthetic_base_info:
        if isinstance(outer_idx_or_tuple, int):
            f_args_inner.append(primals[outer_idx_or_tuple])
        else:
            outer_base_idx, view_tensor = outer_idx_or_tuple
            outer_base = primals[outer_base_idx]
            view_arg = gen_alias_from_base(
                outer_base, view_tensor, view_tensor.requires_grad
            )
            f_args_inner.append(view_arg)
    return f_args_inner


# This class contains all the metadata we care about for the current function we're compiling.
# This data is needed both at trace time and at runtime.
@dataclass
class CompiledRuntimeMetadata:
    # This type / object should be cleaned up
    # See Note [Synthetic Base Info Metadata]
    synthetic_base_info: Optional[List[Union[int, Tuple[int, torch.Tensor]]]]
    fw_metadata: ViewAndMutationMeta

    def __post_init__(self):
        self.num_outputs = len(self.fw_metadata.output_info)
        self.num_outputs_non_aliased = len(
            [x for x in self.fw_metadata.output_info if x.output_type == OutputType.non_alias]
        )
        self.num_outputs_aliased_to_inputs = len(
            [
                x
                for x in self.fw_metadata.output_info
                if x.output_type in [
                    OutputType.alias_of_input,
                    OutputType.is_input,
                ]
            ]
        )
        self.num_outputs_aliased_to_intermediates = len(
            [
                x
                for x in self.fw_metadata.output_info
                if x.output_type in [
                    OutputType.alias_of_intermediate,
                    OutputType.alias_of_intermediate_save_as_output,
                    OutputType.alias_of_intermediate_base_is_user_output,
                ]
            ]
        )
        self.num_outputs_aliased = (
            self.num_outputs_aliased_to_inputs + self.num_outputs_aliased_to_intermediates
        )
        self.num_mutated_data_inputs = len(
            [x for x in self.fw_metadata.input_info if x.mutates_data]
        )
        self.num_mutated_metadata_inputs = len(
            [
                x
                for x in self.fw_metadata.input_info
                if x.mutates_metadata
            ]
        )
        self.num_mutated_metadata_only_inputs = len(
            [
                x
                for x in self.fw_metadata.input_info
                if not x.mutates_data and x.mutates_metadata
            ]
        )
        self.num_mutated_inputs = self.num_mutated_data_inputs + self.num_mutated_metadata_only_inputs


# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
#     to its index in the pre-synthetic-base calling convention.
# (2) There could be multiple such indices, if this index corresponds to a synthetic base
#     that has multiple input aliases.
# (3) If any of those corresponding inputs get data mutations, we clone the base
#     (a metadata-only mutation just needs a fresh view).
def maybe_to_fresh_input(idx, t, meta):
    if not isinstance(t, Tensor):
        return t
    if meta.synthetic_base_info is None:
        outer_aliased_indices_of_current_base_arg = [idx]
    else:
        outer_aliased_indices_of_current_base_arg = [
            # For every argument index in the outer calling convention (before synthetic bases)
            # find its index in the inner calling convention.
            # if it matches the index of our current arg (idx), track the outer argument's index (i)
            i
            for i, outer_idx_or_tuple in enumerate(meta.synthetic_base_info)
            if (isinstance(outer_idx_or_tuple, int) and outer_idx_or_tuple == idx)
            or (
                isinstance(outer_idx_or_tuple, tuple)
                and outer_idx_or_tuple[0] == idx
            )
        ]
    if any(
        meta.fw_metadata.input_info[i].mutates_data
        for i in outer_aliased_indices_of_current_base_arg
    ):
        # Make sure the primal we pass to autograd.grad()
        # sees the tensor before the mutation
        return t.clone()
    if any(
        meta.fw_metadata.input_info[i].mutates_metadata and not meta.fw_metadata.input_info[i].mutates_data
        for i in outer_aliased_indices_of_current_base_arg
    ):
        # Make sure the primal we pass to autograd.grad()
        # sees the tensor before the metadata mutation
        return t.view(t.shape)
    return t
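
# Illustrative sketch (a hypothetical `meta` where input 0 has a data mutation and
# input 1 has a metadata-only mutation):
#
#   maybe_to_fresh_input(0, t0, meta)  # returns t0.clone()
#   maybe_to_fresh_input(1, t1, meta)  # returns t1.view(t1.shape)
#   maybe_to_fresh_input(2, t2, meta)  # returns t2 unchanged
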

# This function takes in a forward fn, runs it, and (optionally) runs autograd to compute the joint.
# When maybe_tangents is None, we only run the forward. Otherwise we run the "joint" forward + backward.
# Preconditions:
# - fn corresponds to the flattened user fw function, with duplicate inputs removed
# - functionalization is turned on (and inputs are wrapped in functional tensors)
# - Synthetic bases have been *removed* (we've taken views on them corresponding to the user argument views).
# - primals_after_cloning are what we run our forward function on. It is identical to primals_before_cloning,
#   except that every input we know will be mutated in the forward has been cloned.
#   We run our forward on primals_after_cloning (potentially mutating some inputs), and then compute our gradients
#   w.r.t. primals_before_cloning (so we properly capture the mutation in our gradient computation).
# Importantly, due to functionalization + some autograd.Function constraints, this function can return EXTRA outputs
# compared to what the original user forward returns.
#
# If we are only running the forward (and not computing the joint):
# - Our function will return (updated_inputs, fw_outs)
#
# If we are running the forward + backward (computing the joint):
# - Our function will return (updated_inputs, fw_outs, intermediate_bases), (gradients)
#
# Finally, if keep_input_mutations is set, then we will explicitly *not* return updated inputs, for any inputs
# that experienced data-only mutations.
# Instead, we are relying on the logic in create_forward_or_joint_functionalized to manually perform the input mutations,
# keeping them directly in the traced graph.
def forward_or_joint(
    fn: Callable,
    primals_before_cloning: List[Any],
    primals_after_cloning: List[Any],
    maybe_tangents: Optional[List[Any]],
    meta: CompiledRuntimeMetadata,
    keep_input_mutations: bool,
) -> Any:
    outs = fn(*primals_after_cloning)
    assert len(meta.fw_metadata.output_info) == len(outs)

    # The compiled fw will return mutated input tensors, *including* metadata-only mutations.
    # However, if keep_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs.
    # (because data-only input mutations are handled directly in the compiled graph)
    if keep_input_mutations:
        mutated_inputs_to_return = [
            x
            for (i, x) in enumerate(primals_after_cloning)
            if meta.fw_metadata.input_info[i].mutates_metadata
        ]
    else:
        mutated_inputs_to_return = [
            x
            for (i, x) in enumerate(primals_after_cloning)
            if meta.fw_metadata.input_info[i].mutates_data or meta.fw_metadata.input_info[i].mutates_metadata
        ]

    # Case 1: We are just tracing the forward; not the joint forward + backward.
    if maybe_tangents is None:
        return *mutated_inputs_to_return, *outs
    else:
        tangents = maybe_tangents

    # Case 2: We are tracing the joint forward + backward.
    # This also requires us to:
    # - update the graph to return intermediate bases
    # - Figure out what grad_outputs to pass into the backward
    #   (this includes intermediate bases in the forward, and forward inputs that had data mutations)
    # - actually call autograd.grad to trace the backward.
    intermediate_bases = []
    for o, info in zip(outs, meta.fw_metadata.output_info):
        if info.output_type == OutputType.alias_of_intermediate_save_as_output:
            intermediate_bases.append(o._base)

    assert meta.fw_metadata.num_intermediate_bases == len(intermediate_bases)

    # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw
    # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
    # which we *should* send to grad()
    outputs_for_grad = [
        x
        for (i, x) in enumerate(outs)
        if meta.fw_metadata.output_info[i].output_type == OutputType.non_alias
    ]
    # Pass any (non-aliased) mutated inputs in as tangents, since they'll be returned as outputs in the fw
    # Important: the traced joint fw/bw will return updated inputs with data mutations,
    # but *not* with metadata mutations.
    # Instead, we shunt the updated metadata around externally
    # and update the input's metadata outside of the autograd.Function
    mutated_inputs_for_grad = [
        x
        for (i, x) in enumerate(primals_after_cloning)
        if meta.fw_metadata.input_info[i].mutates_data
    ]
    # The tensors that we include in the backward graph are:
    # - inputs that receive *data* mutations (not metadata-only; those are recomputed later)
    # - outputs that are not aliased (aliased outputs are recomputed later)
    # - intermediate ._base tensors of aliased outputs (we use those later to recompute the aliased outputs)
    fw_outs_to_grad = mutated_inputs_for_grad + outputs_for_grad + intermediate_bases
    assert len(tangents) == len(fw_outs_to_grad)

    # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases)
    fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases

    # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!)
    # and not primals_before_cloning (the preserved inputs, pre-mutation, that we pass to grad())
    for i, arg in enumerate(primals_after_cloning):
        if not isinstance(arg, Tensor):
            continue
        torch._sync(arg)

    # Get the inputs that need gradients
    grad_primals = []
    inputs_needs_grads = []
    # Note that we iterate over primals_before_cloning (and not primals_after_cloning) here,
    # being careful not to pass any mutated inputs into autograd.grad()
    for p in primals_before_cloning:
        is_grad_tensor = isinstance(p, Tensor) and p.requires_grad
        inputs_needs_grads.append(is_grad_tensor)
        if is_grad_tensor:
            grad_primals.append(p)

    # Get the outputs that need gradients
    needed_outs = []
    needed_tangents = []
    for out, tangent in zip(fw_outs_to_grad, tangents):
        if isinstance(out, Tensor) and out.requires_grad:
            # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
            # The issue is that we are sensitive to decomps that don't accurately maintain
            # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
            needed_outs.append(
                out if out.shape == tangent.shape else out.view(tangent.shape)
            )
            needed_tangents.append(tangent.requires_grad_(True))

    setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])

    backward_out = []
    # Call the backwards pass
    if grad_primals:
        with fx_traceback.preserve_node_meta():
            backward_out = torch.autograd.grad(
                needed_outs,
                grad_primals,
                grad_outputs=needed_tangents,
                allow_unused=True,
            )
    backward_out_iter = iter(backward_out)
    return fw_outs_to_return, [
        next(backward_out_iter) if i else None for i in inputs_needs_grads
    ]


# This function expands synthetic base arguments into the original aliased inputs that the user passed in.
# Preconditions:
# - fn corresponds to the flattened user fw function, with duplicate inputs removed
# - functionalization is turned on (and inputs are wrapped in functional tensors)
# - both primals args **include** synthetic bases.
#   "primals_after_cloning" just corresponds to "primals_before_cloning", but with some inputs (optionally) cloned.
#   "primals_before_cloning" is not used here directly; it is only forwarded so that we can pass the correct
#   leaf tensors into autograd.
def flat_fn_with_synthetic_bases_expanded(
    fn: Callable,
    primals_before_cloning: List[Any],
    primals_after_cloning: List[Any],
    maybe_tangents: Optional[List[Any]],
    meta: CompiledRuntimeMetadata,
    keep_input_mutations: bool
):
    # This is where we handle the calling convention around synthetic bases.
    # We need to make sure that we convert any synthetic base arguments into views
    # *after* we clone inputs for autograd (see below), to preserve the view relationship.
    primals = unpack_synthetic_bases(primals_after_cloning, meta.synthetic_base_info)
    assert len(meta.fw_metadata.input_info) == len(primals)
    outs = forward_or_joint(fn, primals_before_cloning, primals, maybe_tangents, meta, keep_input_mutations)
    return outs
  957. # This function adds extra clone() calls on any inputs in the forward that get mutated.
  958. # It *only* does this if we plan on performing autograd on fn.
959. # The idea here is that when computing gradients w.r.t. inputs, we need to compute them
960. # w.r.t. the inputs *before* they were mutated!
  961. # Preconditions:
  962. # - fn corresponds to the flattened user fw function, with duplicate inputs removed
  963. # - primals **includes** synthetic bases. Importantly, if a synthetic base is mutated,
  964. # we need to clone it *before* taking views off of it (if we clone the views they won't be views anymore)
  965. # - functionalization is turned on (and inputs are wrapped in functional tensors)
  966. def flat_fn_no_input_mutations(
  967. fn: Callable,
  968. primals: List[Any],
  969. maybe_tangents: Optional[List[Any]],
  970. meta: CompiledRuntimeMetadata,
  971. keep_input_mutations: bool
  972. ):
973. # When tracing the joint fwd + bwd, make sure to clone any inputs that are mutated first.
  974. # We need to ensure that the inputs we pass to autograd.grad() are the *original*
  975. # inputs, and not their mutated values.
  976. if maybe_tangents is not None:
  977. primals_after_cloning = [
  978. maybe_to_fresh_input(i, t, meta) for i, t in enumerate(primals)
  979. ]
  980. else:
  981. primals_after_cloning = primals
  982. outs = flat_fn_with_synthetic_bases_expanded(fn, primals, primals_after_cloning, maybe_tangents, meta, keep_input_mutations)
  983. return outs
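# Conceptually, the cloning above works like this (an illustrative standalone sketch, not this module's API):
#
#   x = torch.ones(3, requires_grad=True)
#   x_fresh = x.clone()            # the clone is what the traced fn actually mutates
#   x_fresh.mul_(2)                # the mutation lands on the clone, not on x
#   out = (x_fresh + 1).sum()
#   torch.autograd.grad(out, [x])  # -> (tensor([2., 2., 2.]),)
#                                  # gradients flow through clone() back to the original,
#                                  # pre-mutation leaf x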
  984. # This creates the final function that we want to trace using make_fx(),
  985. # in both aot_dispatch_autograd and aot_dispatch_base.
  986. # Preconditions:
  987. # - fn corresponds to the user's fw function
  988. # - fn arguments have been flattened, duplicate arguments have been handled
  989. # - In the returned function, the "primals" arguments *includes* synthetic bases.
  990. # This function does the work of functionalizing the input function,
  991. # and performing copy_() calls at the end of the function if `keep_input_mutations` is set.
  992. # The function returned has signature that is either:
  993. # (1) "traced_fn(primals: List[Any])" if trace_joint is False
  994. # (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True
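# Illustrative usage (a sketch based on how this is called further below; variable names are assumed):
#
#   joint_fn = create_forward_or_joint_functionalized(
#       flat_fn, meta=metadata, trace_joint=True, keep_input_mutations=False)
#   fx_joint = make_fx(joint_fn, decompositions)(primals, tangents)
#
#   fwd_fn = create_forward_or_joint_functionalized(
#       flat_fn, meta=metadata, trace_joint=False, keep_input_mutations=True)
#   fx_fwd = make_fx(fwd_fn, decompositions)(*primals)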
  995. def create_forward_or_joint_functionalized(
  996. fn,
  997. *,
  998. meta: CompiledRuntimeMetadata,
  999. trace_joint: bool,
  1000. keep_input_mutations: bool
  1001. ):
  1002. def functionalized_f_helper(primals, maybe_tangents=None):
  1003. # Convention: this function is used to trace both the joint, and just the forward (for inference).
  1004. # When trace_joint is set, tangents should be passed in.
  1005. assert (maybe_tangents is not None) == trace_joint
  1006. # Wrap inputs into functional wrappers
  1007. f_primals = pytree.tree_map(to_fun, primals)
  1008. f_tangents = None if maybe_tangents is None else pytree.tree_map(to_fun, maybe_tangents)
  1009. torch._enable_functionalization(reapply_views=True)
  1010. try:
  1011. # Run the joint
  1012. f_outs = flat_fn_no_input_mutations(fn, f_primals, f_tangents, meta, keep_input_mutations)
  1013. finally:
  1014. torch._disable_functionalization()
  1015. if keep_input_mutations:
  1016. # Note: This is a bit annoying. There's a layering issue here, where:
  1017. # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
  1018. # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
  1019. # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
1020. # because inductor (and backends in general) would prefer not to see these (e.g. as_strided_(), resize_()).
  1021. # This makes it pretty difficult for this logic to operate on synthetic bases.
  1022. # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
  1023. # (unpacked) input aliases, instead of the synthetic base.
  1024. # The result is that ideally this function shouldn't have to worry about synthetic bases
  1025. # (unpacking them happens underneath this function),
  1026. # but we actually do need to unpack the synthetic bases when performing the copy_'s to keep input mutations around.
  1027. # Example case where this could be important:
  1028. #
  1029. # def f(x, y):
  1030. # x.mul_(2)
  1031. # y.mul_(3)
  1032. # return x, y
1033. # a = torch.ones(1_000_000)
1034. # x, y = f(a[0:9], a[1:10])
  1035. #
  1036. # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
  1037. # a giant "updated synthetic base" and copying into a's entire storage.
  1038. primals_unpacked = unpack_synthetic_bases(primals, meta.synthetic_base_info)
  1039. f_primals_unpacked = unpack_synthetic_bases(f_primals, meta.synthetic_base_info)
  1040. assert len(meta.fw_metadata.input_info) == len(f_primals_unpacked)
  1041. for i, (inpt_old, inpt_f) in enumerate(zip(primals_unpacked, f_primals_unpacked)):
  1042. if not isinstance(inpt_f, torch.Tensor):
  1043. continue
  1044. torch._sync(inpt_f)
  1045. inpt_new = torch._from_functional_tensor(inpt_f)
  1046. if meta.fw_metadata.input_info[i].mutates_data and not meta.fw_metadata.input_info[i].mutates_metadata:
  1047. # We found an input that had a (data-only) mutation.
  1048. # Since keep_input_mutations is set, we need to faithfully apply a copy_()
  1049. # so the compiler will see the input mutation in the graph.
  1050. assert inpt_new is not inpt_old
  1051. assert has_same_metadata(inpt_new, inpt_old)
  1052. inpt_old.copy_(inpt_new)
  1053. return pytree.tree_map(from_fun, f_outs)
1054. # the joint needs to have args named "primals" and "tangents",
  1055. # which are hardcoded into the partitioning logic.
  1056. def traced_joint(primals, tangents):
  1057. return functionalized_f_helper(primals, tangents)
  1058. def traced_forward(*primals):
  1059. return functionalized_f_helper(primals)
  1060. if trace_joint:
  1061. return traced_joint
  1062. else:
  1063. return traced_forward
  1064. def normalize_as_list(x):
  1065. if isinstance(x, tuple):
  1066. return list(x)
  1067. elif isinstance(x, list):
  1068. return x
  1069. return [x]
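# e.g. normalize_as_list((a, b)) -> [a, b]; normalize_as_list([a, b]) -> [a, b]; normalize_as_list(a) -> [a]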
  1070. aot_autograd_decompositions = {}
  1071. # This is a list since looking forward, we can have this arbitrarily nested.
  1072. graph_being_compiled: List[str] = []
  1073. # TODO: It would be nice to reset the numbering every time aot_id goes
  1074. # up, but this is annoying to do right now (because we don't know if
  1075. # an aot_id will come back from the dead), so right now this also happens
  1076. # to be a globally unique number too (at the cost of wobbling if you change
  1077. # how the graphs compile)
  1078. nth_graph: int = 0
  1079. model_name: str = "model"
  1080. def set_model_name(name):
  1081. global model_name
  1082. model_name = name
  1083. def get_aot_compilation_context() -> Tuple[List[str], str, int]:
  1084. return list(graph_being_compiled), model_name, nth_graph
  1085. def get_aot_graph_name() -> str:
  1086. """
  1087. Returns the name of the graph being compiled.
  1088. """
  1089. global model_name, graph_being_compiled, nth_graph
  1090. return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}"
  1091. get_graph_being_compiled = get_aot_graph_name
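# Illustrative example: with model_name == "resnet", graph_being_compiled == ["0_forward"],
# and nth_graph == 1, get_aot_graph_name() returns "resnet__0_forward_1".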
  1092. @contextmanager
  1093. def track_graph_compiling(aot_config, graph_name):
  1094. global graph_being_compiled
  1095. # TODO: Don't shove the aot_id in here; set it in the context
  1096. graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
  1097. yield
  1098. global nth_graph
  1099. nth_graph += 1
  1100. graph_being_compiled = []
  1101. def make_boxed_func(f):
  1102. def g(args):
  1103. return f(*args)
  1104. g._boxed_call = True
  1105. return g
  1106. def make_boxed_compiler(compiler):
  1107. @wraps(compiler)
  1108. def f(fx_g, inps):
  1109. out_f = compiler(fx_g, inps)
  1110. fx_g = make_boxed_func(out_f)
  1111. return fx_g
  1112. return f
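# Illustrative usage (a sketch; `my_compiler` is a hypothetical pass-through compiler):
#
#   def my_compiler(fx_g, example_inputs):
#       return fx_g                      # a no-op "compiler": just run the traced graph
#
#   boxed_compiler = make_boxed_compiler(my_compiler)
#   compiled_fw = boxed_compiler(fx_g, example_inputs)
#   out = compiled_fw([t1, t2])          # boxed convention: a single list of args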
  1113. def call_func_with_args(f, args, steal_args=False, disable_amp=False):
  1114. if not steal_args:
  1115. args = list(args)
  1116. assert isinstance(args, list)
  1117. if disable_amp:
  1118. guard = torch._C._DisableAutocast()
  1119. try:
  1120. if hasattr(f, "_boxed_call"):
  1121. out = normalize_as_list(f(args))
  1122. else:
  1123. # TODO: Please remove soon
  1124. # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
  1125. warnings.warn(
1126. "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
  1127. "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
  1128. "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
  1129. )
  1130. out = normalize_as_list(f(*args))
  1131. finally:
  1132. if disable_amp:
  1133. del guard
  1134. return out
  1135. @dataclasses.dataclass
  1136. class AOTConfig:
  1137. """
  1138. Configuration for AOTDispatcher
  1139. """
  1140. fw_compiler: Callable
  1141. bw_compiler: Callable
  1142. partition_fn: Callable
  1143. decompositions: Dict[Callable, Callable]
  1144. num_params_buffers: int
  1145. aot_id: int
  1146. keep_inference_input_mutations: bool
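# Illustrative construction (a sketch; the compiler/partitioner choices here are placeholders):
#
#   aot_config = AOTConfig(
#       fw_compiler=make_boxed_compiler(lambda gm, inps: gm),
#       bw_compiler=make_boxed_compiler(lambda gm, inps: gm),
#       partition_fn=default_partition,            # assumed available, e.g. from functorch.compile
#       decompositions=aot_autograd_decompositions,
#       num_params_buffers=0,
#       aot_id=0,
#       keep_inference_input_mutations=False,
#   )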
  1147. def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
  1148. with enable_python_dispatcher():
  1149. _fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  1150. flat_fn,
  1151. keep_input_mutations=aot_config.keep_inference_input_mutations,
  1152. )(
  1153. *flat_args
  1154. )
  1155. _input_info = _fw_metadata.input_info
  1156. flat_args_with_views_handled, _synthetic_base_info = merge_view_inputs(
  1157. flat_args, _input_info, is_inference=True
  1158. )
  1159. metadata_ = CompiledRuntimeMetadata(
  1160. synthetic_base_info=_synthetic_base_info,
  1161. fw_metadata=_fw_metadata,
  1162. )
  1163. # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
  1164. # The cases that aot_dispatch_base doesn't need to handle include:
  1165. # - outputs that are aliases of graph intermediates
  1166. # - outputs that are aliases of graph inputs
  1167. # While cases that it does need to handle include:
  1168. # - input mutations (including when inputs are aliases of each other)
  1169. # - input metadata mutations
  1170. trace_fn = create_forward_or_joint_functionalized(
  1171. flat_fn,
  1172. meta=metadata_,
  1173. trace_joint=False,
  1174. keep_input_mutations=aot_config.keep_inference_input_mutations
  1175. )
  1176. with enable_python_dispatcher():
  1177. fw_module = make_fx(trace_fn, aot_config.decompositions)(*flat_args_with_views_handled)
  1178. if not aot_config.keep_inference_input_mutations:
  1179. # As long as we opted to remove input mutations, then
  1180. # there should be *NO* mutating ops in the graph at this point.
  1181. assert_functional_graph(fw_module.graph)
  1182. fw_module.graph.eliminate_dead_code()
  1183. fw_module.recompile()
  1184. if config.debug_graphs:
  1185. log.debug(f"====== Forward (only) graph {aot_config.aot_id} ======")
  1186. log.debug(fw_module.print_readable(print_output=False))
  1187. disable_amp = torch._C._is_any_autocast_enabled()
  1188. context = disable_autocast_manager if disable_amp else nullcontext
  1189. with context(), track_graph_compiling(aot_config, "inference"):
  1190. compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  1191. compiled_fn = create_runtime_wrapper(
  1192. compiled_fw,
  1193. runtime_metadata=metadata_,
  1194. trace_joint=False,
  1195. keep_input_mutations=aot_config.keep_inference_input_mutations
  1196. )
  1197. return compiled_fn
  1198. def assert_functional_graph(fx_g: torch.fx.Graph):
  1199. for n in fx_g.nodes:
  1200. if isinstance(n.target, torch._ops.OpOverload):
  1201. assert not n.target._schema.is_mutable, \
  1202. f'aot_autograd expected to have an entirely functional graph, but found {n.format_node()}'
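# For example (illustrative): a node calling aten.add.Tensor passes this check, while a node
# calling aten.add_.Tensor (an in-place op with a mutable schema) would trip the assert.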
  1203. @contextmanager
  1204. def disable_autocast_manager():
  1205. guard = torch._C._DisableAutocast()
  1206. try:
  1207. yield
  1208. finally:
  1209. del guard
  1210. def are_differentiable_views(view1, view2):
  1211. if view1 is view2:
  1212. return True
  1213. if view1._base is None and view2._base is None:
  1214. return False
  1215. if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
  1216. return True
  1217. return False
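# Illustrative examples (hypothetical tensors):
#   base = torch.ones(4); a = base.view(2, 2); b = base[1:]
#   are_differentiable_views(a, b)      # True: a._base is b._base (both are `base`)
#   are_differentiable_views(a, base)   # True: a._base is base
#   are_differentiable_views(torch.ones(2), torch.ones(2))  # False: both have _base None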
  1218. def same_dtype_views(view1, view2):
  1219. if view1.dtype != view2.dtype:
  1220. return False
  1221. if view1._base is not None and view1.dtype != view1._base.dtype:
  1222. return False
  1223. if view2._base is not None and view2.dtype != view2._base.dtype:
  1224. return False
  1225. return True
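# Illustrative example (hypothetical tensors): for c = torch.zeros(2, dtype=torch.complex64)
# and r = torch.view_as_real(c), r.dtype is float32 while c.dtype is complex64,
# so same_dtype_views(r, c) returns False. Two float32 views of one float32 base return True.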
  1226. # Note [Handling mutations on an input that aliases other inputs]
  1227. # The easiest example to show-case this edge case is here:
  1228. #
  1229. # def f(a, b):
  1230. # a.mul_(2)
  1231. # out = a + b
  1232. # return out
  1233. # b = torch.ones(...)
  1234. # a = b.view(-1)
  1235. # f(a, b)
  1236. #
1237. # In this situation, since a and b are aliased, we need to trace something different!
1238. # In the example above we have a = b.view(-1)
1239. # (which means that `a._base is b`)
  1240. #
  1241. # We need to ensure that the aliasing relationship between a and b is preserved.
1242. # We do that by detecting the specific situation above (mutating an input that aliases another input),
  1243. # and when we do that, we create a synthetic base argument. Then inside of the traced forward,
  1244. # we regenerate a and b off of that base.
  1245. # The complete example of the transformed function looks like this:
  1246. #
  1247. # // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
  1248. # // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
  1249. # def traced_forward(base):
  1250. # a = base.as_strided(...)
  1251. # b = base.as_strided(...)
  1252. # a_updated = a.mul(2)
  1253. # base_updated = torch.as_strided_scatter(base, a_updated, ...)
  1254. # b_updated = base_updated.as_strided(...)
  1255. # out = a_updated + b_updated
  1256. # return a_updated, out
  1257. #
  1258. # def compiled_fn(a, b):
  1259. # // we detect that a is the "differentiable base" here
  1260. # base = a
  1261. # // In other situations, we might do either:
  1262. # // (1) a and b are both views off of some larger differentiable base
  1263. # // assert a._base is b._base and a._base is not None
  1264. # // base = a._base
  1265. # // (2) a and b both don't require gradients. Create a base from the storage
  1266. # // assert a._base is None and b._base is None
  1267. # // base = torch.Tensor(a.storage())
  1268. # a_updated, out = traced_forward(base)
  1269. # a.copy_(a_updated)
  1270. # return out
  1271. #
  1272. # This function:
  1273. # (1) Merges input views into a synthetic base argument, when any of those input views are mutated
  1274. # (2) Returns metadata telling the autograd.Function how to modify their arguments properly,
  1275. # to respect the new calling convention.
  1276. #
  1277. # The calling convention is as follows.
  1278. # Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
  1279. # The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
1280. # where the ordering of the bases is determined from the ordering of the original view args.
  1281. # baseA will come before baseB if the earliest original argument coming from baseA
  1282. # showed up earlier in the argument list than the earliest original argument coming from baseB.
  1283. #
  1284. # Example, given some tensors a, b, c, d
  1285. # call site:
  1286. # f(a, c.view(-1), b.view(-1), b, c, d)
  1287. # Modified argument list:
  1288. # c_base comes first because the first c view came earlier in arg list than the first b view
  1289. # a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases
  1290. # b_base = torch.Tensor(b.storage())
  1291. # c_base = torch.Tensor(c.storage())
  1292. # f(c_base, b_base, a, d)
  1293. def merge_view_inputs(
  1294. fwd_inputs: List[Any], mutated_input_info: List[InputAliasInfo],
  1295. *,
  1296. # The autograd case currently has more restrictions than the inference case.
  1297. is_inference: bool,
  1298. ) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
  1299. assert len(fwd_inputs) == len(mutated_input_info)
  1300. storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
  1301. base_args = []
  1302. other_args = []
  1303. for i, inpt in enumerate(fwd_inputs):
  1304. if isinstance(inpt, Tensor):
  1305. storage_ref = StorageWeakRef(inpt.untyped_storage())
  1306. storage_ref_to_idx[storage_ref].append(i)
  1307. else:
  1308. other_args.append(inpt)
  1309. # Note [Synthetic Base Info Metadata]
  1310. # This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
  1311. # It's either:
  1312. # - another int (corresponding to the index in the argument list of the element from the outer calling convention)
  1313. # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
  1314. # idx corresponds to which synthetic base from the outer calling context to view
  1315. inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
  1316. for aliased_input_indices in storage_ref_to_idx.values():
  1317. if len(aliased_input_indices) <= 1 or not any(
  1318. # We only care about mutations that affect all aliases,
1319. # so metadata mutations on an input don't require us to do synthetic base handling.
  1320. mutated_input_info[inpt_idx].mutates_data
  1321. for inpt_idx in aliased_input_indices
  1322. ):
  1323. for curr_idx in aliased_input_indices:
  1324. other_args.append(fwd_inputs[curr_idx])
  1325. continue
  1326. # We detected an input that was mutated, AND aliases with another input.
  1327. # we need to replace this set of aliased inputs with a single synthetic base.
  1328. # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
  1329. # and error out. We can fix them later.
  1330. # These checks are transitive, so we don't need to check every pair.
  1331. for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:]):
  1332. view1 = fwd_inputs[idx1]
  1333. view2 = fwd_inputs[idx2]
  1334. # The "inputs that are aliased but have different differentiable bases" case
  1335. # is more complicated and hopefully pretty rare. Not currently handled.
  1336. if not is_inference:
  1337. assert are_differentiable_views(
  1338. view1, view2
  1339. ), "aot_autograd() does not yet handle non-differentiable view input mutations."
  1340. # Regenerating views when reinterpreting complex / real tensors seems non-trivial,
  1341. # not handling for now
  1342. assert same_dtype_views(
  1343. view1, view2
  1344. ), "aot_autograd() does not yet handle input mutations on views with different dtypes."
  1345. non_none_bases = [
  1346. fwd_inputs[i]._base
  1347. for i in aliased_input_indices
  1348. if fwd_inputs[i]._base is not None
  1349. ]
  1350. aliases_with_none_bases = [
  1351. fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
  1352. ]
  1353. if len(non_none_bases) == 0:
  1354. # Case where none of the aliases have a ._base
  1355. # we generate a synthetic base without gradients, and generate views off of it
  1356. # We hit this case when we have input tensors to the graph that share a storage,
  1357. # but do not have a ._base field.
  1358. # Wondering when we hit this case?
  1359. # The _base field simply says that autograd knows about the aliasing relationship,
  1360. # but sometimes we create tensors which are aliased out of the same storage but guaranteed
  1361. # to be disjoint. In these cases, we will skip setting up the _base relationship
  1362. # for performance reasons (because the fact that the tensors share the same storage
  1363. # is unobservable unless you (1) do naughty things with resize_/as_strided
  1364. # or (2) look at the storage--as we are doing here.)
  1365. # One particular example of this is optimizer steps on the LSTM module:
  1366. # LSTM parameters are packed into a contiguous storage for efficiency reasons when
  1367. # calling cuDNN kernels, so when these parameters get passed to the optimizer we will
  1368. # find they share the same storage, but do not have _base set since they are all disjoint.
  1369. #
  1370. # NOTE: There is one case where this is unsafe:
  1371. # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
  1372. # the same shape as the "actual" base that the tensor came from.
  1373. # For the most part this is fine, because we always use as_strided()
  1374. # to generate the original aliased inputs again.
  1375. # If we were to use view-replay though, this could cause the aliased views
  1376. # to have incorrect sizes.
  1377. example_idx = aliased_input_indices[0]
  1378. example_alias = fwd_inputs[example_idx]
1379. # Note that this function is re-used at both trace time and runtime.
  1380. # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
  1381. synthetic_base = torch.empty((0,), dtype=example_alias.dtype, device=example_alias.device)
  1382. # We don't actually have a convenient way of going from storage -> tensor,
  1383. # So using set_() here (we suffer some minor overhead, but this case is rare).
  1384. synthetic_base.set_(example_alias.untyped_storage())
  1385. else:
  1386. # Case where all of the aliases require gradients, and have the same _base.
  1387. synthetic_base = non_none_bases[0]
  1388. for other_base in non_none_bases[1:]:
  1389. assert (
  1390. other_base is synthetic_base
  1391. ), "aot_autograd() does not yet handle non-differentiable view input mutations."
  1392. for alias in aliases_with_none_bases:
  1393. assert (
  1394. alias is synthetic_base
  1395. ), "aot_autograd() does not yet handle non-differentiable view input mutations."
  1396. base_args.append(synthetic_base)
  1397. for curr_view_idx in aliased_input_indices:
  1398. curr_view = fwd_inputs[curr_view_idx]
  1399. base_idx = len(base_args) - 1
  1400. # We store just enough info here so that we can regenerate the view later.
  1401. # Regeneration: curr_view._view_func(args[base_idx])
  1402. inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
  1403. if len(base_args) == 0:
  1404. assert len(other_args) == len(fwd_inputs)
  1405. # If no synthetic bases are necessary, just return the original inputs.
  1406. return fwd_inputs, None
  1407. else:
  1408. # Otherwise, return:
  1409. # (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
  1410. # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
  1411. # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
  1412. args_to_functionalization = base_args + other_args
  1413. arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)}
  1414. for i, other_arg in enumerate(other_args):
  1415. new_idx = len(base_args) + i
  1416. old_idx = arg_to_old_idx_map[other_arg]
  1417. inner_calling_convention_meta[old_idx] = new_idx
  1418. # post process into a list
  1419. post_processed_calling_convention_meta: List[Union[int, Callable]] = [
  1420. -1 for _ in range(len(inner_calling_convention_meta))
  1421. ]
  1422. for k, v in inner_calling_convention_meta.items():
  1423. post_processed_calling_convention_meta[k] = v
  1424. # Quick assert: every argument in the inner calling convention should be accounted for.
  1425. for x in post_processed_calling_convention_meta:
  1426. assert x != -1
  1427. return args_to_functionalization, post_processed_calling_convention_meta
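# Illustrative sketch (hypothetical tensors; `input_info` is assumed to come from
# run_functionalized_fw_and_collect_metadata):
#
#   base = torch.ones(4)
#   a, b = base[0:2], base[2:4]      # two views sharing one storage
#   new_args, meta = merge_view_inputs([a, b], input_info, is_inference=True)
#   # If neither a nor b has its data mutated, new_args == [a, b] and meta is None.
#   # If a's data is mutated, new_args == [base] (the synthetic base) and meta records how to
#   # regenerate a and b as views of args[0] under the inner calling convention.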
  1428. def format_guard_bug_msg(aot_config, expected):
  1429. return (
  1430. f"At compilation time, graph {aot_config.aot_id} was compiled under the "
  1431. f"assumption that {expected}, but at runtime this was not the case. "
  1432. "This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch."
  1433. )
  1434. # MOTIVATION:
  1435. #
  1436. # When tracing functions for future execution, one must be careful not to pass
1437. # in the same input tensor multiple times (e.g., f(x, x)), as this can result
  1438. # in graphs that are ONLY valid if you later pass a new tensor in exactly the
  1439. # same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct
  1440. # tensors that alias each other is a different situation that is covered by
  1441. # aot_dispatch_deduplicated_autograd). Here are two examples:
  1442. #
  1443. # (1) Suppose you have a function:
  1444. #
  1445. # def f(x, y):
  1446. # return x + y
  1447. #
  1448. # If you make_fx(f)(x, x), you will trace out:
  1449. #
  1450. # def f(x, y):
  1451. # return y + y
  1452. #
  1453. # Oops!
  1454. #
  1455. # (2) For most tensors x and y, you can compute f's gradient with respect to
1456. # these two inputs by saying torch.autograd.grad(f(x, y), (x, y)). However,
  1457. # if x is y, you will trace out a program that gets incorrect gradients:
  1458. #
  1459. # >>> x = torch.randn(1, requires_grad=True)
  1460. # >>> torch.autograd.grad(x + x, (x, x))
  1461. # (tensor([2.]), tensor([2.]))
  1462. #
  1463. # In other words, the gradient is double-counted. Deduplicating the arguments
  1464. # gives you an appropriate gradient:
  1465. #
  1466. # >>> y = torch.randn(1, requires_grad=True)
  1467. # >>> torch.autograd.grad(x + y, (x, y))
  1468. # (tensor([1.]), tensor([1.]))
  1469. #
  1470. # HOW TO DEDUPLICATE:
  1471. #
  1472. # There are a few strategies, in order of preference:
  1473. #
  1474. # 1. For every duplicate argument to the function, detach it into
  1475. # a separate leaf tensor, so that it is no longer duplicated.
  1476. #
  1477. # PRO: The resulting compiled graph works for any configuration
  1478. # of duplicated arguments.
  1479. #
  1480. # CON: It does not (naively) work if you mutate the metadata of inputs:
  1481. #
  1482. # def f(x, y):
  1483. # x.transpose_(0, 1)
  1484. # y.transpose_(0, 2)
  1485. #
  1486. # x = torch.randn(2, 3, 4)
  1487. # f(x, x)
  1488. #
  1489. # The ordering of the transposes inside f dictates whether or not
  1490. # you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute
  1491. # what metadata mutations should get applied to each input; you need to
  1492. # assume they aren't duplicates (what we do today) or preserve
  1493. # the original metadata mutations exactly in order, so that they work
  1494. # for any duplicate configuration.
  1495. #
  1496. # CON: It does not (naively) work if you mutate the data of inputs.
  1497. # In particular, leaf tensors that require grad cannot be mutated,
  1498. # this makes it impossible to differentiate with respect to the original
  1499. # base.
  1500. #
  1501. # 2. For every duplicate argument to the function, remove it, so it is
  1502. # no longer part of the "true" signature:
  1503. #
  1504. # PRO: Implemented naively, it still works for metadata/data mutation.
  1505. #
  1506. # CON: The resulting compiled graph is duplicate-specialized: it only
  1507. # works if future calls duplicate arguments in exactly the same way.
  1508. # Horribly, Dynamo doesn't guard on this at the moment. But even if
1509. # it did, you could still end up recompiling a bunch of times for each duplicate configuration.
  1510. #
  1511. # Our strategy is to do (1) if we can, and do (2) otherwise, erroring if
  1512. # Dynamo's guards are not enough. In practice, this seems to cover
  1513. # everything.
  1514. #
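# Illustrative sketch of strategy (1) (a standalone example mirroring the MOTIVATION note above):
#
#   x = torch.randn(1, requires_grad=True)
#   x_leaf = x.detach().requires_grad_(True)      # leafified duplicate of x
#   torch.autograd.grad(x + x_leaf, (x, x_leaf))  # -> (tensor([1.]), tensor([1.]))
#
# i.e. after detaching one copy, each argument gets its own (non-double-counted) gradient.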
  1515. def aot_wrapper_dedupe(
  1516. flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, compiler_fn
  1517. ):
  1518. # Get information about whether or not flat_fn mutates its arguments
  1519. # or not
  1520. try:
  1521. with enable_python_dispatcher():
  1522. fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  1523. flat_fn,
  1524. # For the purpose of checking for dupes that are mutated,
  1525. # we always want our metadata to correctly reflect input mutations
  1526. keep_input_mutations=False,
  1527. )(
  1528. *flat_args
  1529. )
  1530. except RuntimeError as e:
  1531. log.warning(
  1532. "Failed to collect metadata on function, produced code may be suboptimal. "
  1533. "Known situations this can occur are inference mode only compilation involving "
  1534. "resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); "
  1535. "if your situation looks different please file a bug to PyTorch.",
  1536. exc_info=True,
  1537. )
  1538. # Analysis failed, fall back to duplicate specialize
  1539. # TODO: Known analysis problems:
  1540. # - resize_: TestInductorOpInfoCPU.test_comprehensive_resize__cpu_bool
  1541. # - prims: test_tmp_not_defined_issue1_cpu
  1542. pass
  1543. else:
  1544. # Strategy 1: For any input that is not mutated, we can leafify it if we
  1545. # need to remove a duplicate.
  1546. leaf_flat_args = []
  1547. args_set = set()
  1548. ok = True
  1549. for i, a in enumerate(flat_args):
  1550. if a not in args_set:
  1551. args_set.add(a)
  1552. leaf_flat_args.append(a)
  1553. elif not fw_metadata.input_info[i].mutates_data and not fw_metadata.input_info[i].mutates_metadata:
  1554. leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
  1555. else:
  1556. ok = False
  1557. break
  1558. if ok:
  1559. return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  1560. # Strategy 2: Duplicate specialize.
  1561. #
  1562. # In Haskell types, suppose you have:
  1563. #
  1564. # add_dupe_args :: DedupedArgs -> Args
  1565. # remove_dupe_args :: Args -> DedupedArgs
  1566. #
  1567. # compiler_fn
  1568. # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
1569. # deduped_compiler_fn
  1570. # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
  1571. #
  1572. # Then the code below can be written in point-free style as:
  1573. #
  1574. # deduped_compiler_fn f a c =
  1575. # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
  1576. #
  1577. # Suppose you have:
  1578. #
  1579. # [a, b, a, c]
  1580. #
  1581. # We want:
  1582. #
  1583. # remove_dupe_args([a, b, a, c]) == [a, b, c]
  1584. # add_dupe_args([a, b, c]) == [a, b, a, c]
  1585. #
  1586. # This is done via (respectively):
  1587. #
  1588. # seen_args = {a: 0, b: 1, c: 2}
  1589. # add_dupe_map = { # how to get args from the deduped list
  1590. # 0: 0,
  1591. # 1: 1,
  1592. # 2: 0,
  1593. # 3: 2,
  1594. # }
  1595. # keep_arg_mask = [True, True, False, True]
  1596. seen_args = {}
  1597. keep_arg_mask = []
  1598. add_dupe_map = {}
  1599. duped_arg_len = len(flat_args)
  1600. j = 0 # index into deduped_flat_args
  1601. for i, t in enumerate(flat_args):
  1602. if t in seen_args:
  1603. keep_arg_mask.append(False)
  1604. add_dupe_map[i] = seen_args[t]
  1605. continue
  1606. keep_arg_mask.append(True)
  1607. seen_args[t] = j
  1608. add_dupe_map[i] = j
  1609. j += 1
  1610. unique_args = j
  1611. # NB: Hot path, avoid set lookups here
  1612. # TODO: Can avoid the zip here too, probably
  1613. def remove_dupe_args(args):
  1614. return [t for t, keep in zip(args, keep_arg_mask) if keep]
  1615. def add_dupe_args(args):
  1616. return [args[add_dupe_map[i]] for i in range(duped_arg_len)]
  1617. deduped_flat_args = remove_dupe_args(flat_args)
  1618. tracing_context = TracingContext.get()
  1619. if tracing_context:
  1620. # TODO(voz): This structure is 1:1, we could consider an alternate structure like
  1621. # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
  1622. # which feels like needless complexity for a tiny bit of efficiency at this point.
  1623. for dupe_arg_pos, kept_pos in add_dupe_map.items():
  1624. dupe_arg_dict = flat_args[dupe_arg_pos].__dict__
  1625. kept_arg_dict = flat_args[kept_pos].__dict__
  1626. if 'graph_arg_pos' in dupe_arg_dict and 'graph_arg_pos' in kept_arg_dict:
  1627. d_positions = dupe_arg_dict['graph_arg_pos']
  1628. k_positions = kept_arg_dict['graph_arg_pos']
1629. assert d_positions == k_positions
  1630. if len(d_positions) > 1:
  1631. for i in range(1, len(d_positions)):
  1632. pos = d_positions[i]
  1633. pre_pos = d_positions[i - 1]
  1634. tracing_context.guards_context.aotautograd_guards.append(DuplicateInputs(pre_pos, pos))
  1635. @wraps(flat_fn)
  1636. def wrapped_flat_fn(*args):
  1637. return flat_fn(*add_dupe_args(args))
  1638. compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config)
  1639. if not hasattr(compiled_fn, "_boxed_call"):
  1640. compiled_fn = make_boxed_func(compiled_fn)
  1641. @wraps(compiled_fn)
  1642. def wrapped_compiled_fn(args):
  1643. deduped_args = remove_dupe_args(args)
  1644. args.clear()
  1645. return compiled_fn(deduped_args)
  1646. wrapped_compiled_fn._boxed_call = True
  1647. # This can be uncommented when we properly guard for duplicates,
  1648. # but right now we must not do it.
  1649. # if not config.debug_assert:
  1650. # return wrapped_compiled_fn
  1651. @wraps(wrapped_compiled_fn)
  1652. def debugged_compiled_fn(args):
  1653. # Test that the computed remove/add arg functions are an inverse
  1654. new_args = add_dupe_args(remove_dupe_args(args))
  1655. seen = {}
  1656. for i, (x, y) in enumerate(zip(new_args, args)):
  1657. seen[y] = None
  1658. assert x is y, format_guard_bug_msg(
  1659. aot_config,
  1660. f"{describe_input(i, aot_config)} would be a duplicate of "
  1661. f"{describe_input(add_dupe_map[i], aot_config)}",
  1662. )
  1663. # This is only an error if there is metadata mutation on both of
  1664. # the duped arguments; in this case, we need to know what order
  1665. # the metadata mutation applies in. You'll get the correct result
  1666. # otherwise, because a graph that assumes distinct inputs works if
  1667. # you dupe the inputs (the gradient contributions from each input
  1668. # will get summed up appropriately.)
  1669. #
  1670. # TODO: work out how to setup this assert correctly
  1671. """
  1672. assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
  1673. f"there would be {unique_args} distinct arguments"
  1674. )
  1675. """
  1676. return wrapped_compiled_fn(args)
  1677. debugged_compiled_fn._boxed_call = True
  1678. return debugged_compiled_fn
  1679. def describe_input(i, aot_config):
  1680. if i < aot_config.num_params_buffers:
  1681. return f"parameter/buffer {i}"
  1682. else:
  1683. return f"input {i - aot_config.num_params_buffers}"
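# e.g. with aot_config.num_params_buffers == 2 (illustrative): describe_input(1, aot_config)
# returns "parameter/buffer 1", and describe_input(3, aot_config) returns "input 1".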
  1684. # The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic
  1685. # that needs to run after the compiled function.
  1686. #
  1687. # This function accepts a trace_joint flag, indicating whether or not we're generating the runtime
  1688. # epilogue for a forward-only inference graph, or for an autograd.Function.apply function.
  1689. # This is because there are some minor differences in how we treat these cases at runtime:
  1690. # - resize_() is currently handled in the inference case, but not fully handled in the autograd case.
1691. # - the autograd case inserts TensorAlias wrapper objects for outputs that alias inputs
  1692. def create_runtime_wrapper(
  1693. compiled_fn,
  1694. *,
  1695. runtime_metadata: CompiledRuntimeMetadata,
  1696. trace_joint: bool,
  1697. keep_input_mutations: bool,
  1698. ):
  1699. if not hasattr(compiled_fn, "_boxed_call"):
  1700. compiled_fn = make_boxed_func(compiled_fn)
  1701. def runtime_wrapper(*args):
  1702. # Step 2: remove aliased inputs that are mutated, replace with synthetic bases
  1703. # Only happens if our graph mutates an input that aliases another input.
  1704. if runtime_metadata.synthetic_base_info is not None:
  1705. # Given: the original args, including at least one pair of inputs that are aliased
  1706. # and get subsequently mutated.
  1707. # Generate: the updated args, including (potentially multiple) synthetic bases
  1708. # that replace the views. The input views are regenerated manually in the compiled function.
  1709. # TODO: think harder about what happens if (a view of) one of these mutated input views is ALSO returned
  1710. new_inputs, metadata = merge_view_inputs(
  1711. args, runtime_metadata.fw_metadata.input_info, is_inference=not trace_joint,
  1712. )
  1713. # We're just re-running the original-args-to-synthetic-base transformation
  1714. # that we ran during compilation.
  1715. # This returns metadata that we use during tracing to recover the input views,
  1716. # which we don't actually need at runtime.
  1717. assert metadata is not None
  1718. args_with_synthetic_bases = new_inputs
  1719. else:
  1720. args_with_synthetic_bases = args
  1721. with torch.autograd._force_original_view_tracking(True):
  1722. all_outs = call_func_with_args(
  1723. compiled_fn,
  1724. args_with_synthetic_bases,
  1725. disable_amp=True,
  1726. )
  1727. num_mutated_inps = runtime_metadata.num_mutated_inputs
  1728. num_metadata_mutated_inps = runtime_metadata.num_mutated_metadata_inputs
  1729. num_intermediate_bases = runtime_metadata.fw_metadata.num_intermediate_bases
  1730. if keep_input_mutations:
  1731. assert (
  1732. len(all_outs)
  1733. == num_metadata_mutated_inps + runtime_metadata.num_outputs + num_intermediate_bases
  1734. )
  1735. assert (
  1736. len(runtime_metadata.fw_metadata.mutated_inp_indices) == num_metadata_mutated_inps
  1737. )
  1738. else:
  1739. assert (
  1740. len(all_outs)
  1741. == num_mutated_inps + runtime_metadata.num_outputs + num_intermediate_bases
  1742. )
  1743. assert (
  1744. len(runtime_metadata.fw_metadata.mutated_inp_indices) == num_mutated_inps
  1745. )
  1746. # Step 3: After running the compiled fw, apply updates to mutated inputs
  1747. num_mutations_to_apply = len(runtime_metadata.fw_metadata.mutated_inp_indices)
  1748. if num_mutations_to_apply > 0:
  1749. updated_inputs = all_outs[: num_mutations_to_apply]
  1750. fw_outs = all_outs[num_mutations_to_apply :]
  1751. for i, inpt_idx in enumerate(
  1752. runtime_metadata.fw_metadata.mutated_inp_indices
  1753. ):
  1754. meta = runtime_metadata.fw_metadata.input_info[inpt_idx]
  1755. if not meta.mutates_data and not meta.mutates_metadata:
  1756. continue
  1757. original_inpt = args[inpt_idx]
  1758. updated_inpt = updated_inputs[i]
  1759. # TODO: add better resize_() support for autograd case.
  1760. # Check for the case when an input has been resized.
  1761. # Note: One important thing to check for is user code that calls inpt.storage().resize_().
  1762. # We can't trace operations on storage into the graph, so we should get dynamo to graph break.
  1763. # TODO: handle resize_() on inputs to a larger size.
  1764. # This is actually non-trivial to detect, so we should probably just handle it
  1765. # (or make dynamo detect).
1766. # We can't just check whether original_inpt.storage_size != updated_inpt.storage_size,
1767. # because the original_inpt might be a view of some larger tensor,
  1768. # and updated_inpt is always densely packed.
  1769. if not trace_joint and original_inpt.storage().size() != updated_inpt.storage().size():
  1770. original_inpt.resize_(updated_inpt.size())
  1771. if meta.mutates_metadata and not meta.mutates_data:
  1772. if trace_joint:
  1773. assert isinstance(updated_inpt, TensorAlias)
  1774. updated_inpt = updated_inpt.alias
  1775. # We need to grab the size/stride/storage_offset from the compiled forward,
  1776. # and use that to mutate the metadata of the input
  1777. original_inpt.as_strided_(
  1778. updated_inpt.size(),
  1779. updated_inpt.stride(),
  1780. updated_inpt.storage_offset(),
  1781. )
  1782. else:
  1783. if meta.mutates_data and meta.mutates_metadata:
  1784. original_inpt.as_strided_(
  1785. updated_inpt.size(),
  1786. updated_inpt.stride(),
  1787. updated_inpt.storage_offset(),
  1788. )
  1789. else:
  1790. assert meta.mutates_data
  1791. original_inpt.copy_(updated_inpt)
  1792. else:
  1793. fw_outs = all_outs
  1794. # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of
  1795. # compiling them.
  1796. if runtime_metadata.num_outputs_aliased > 0:
  1797. # The compiled forward also returned intermediate bases. We don't want to return them to the user.
  1798. if runtime_metadata.fw_metadata.num_intermediate_bases > 0:
  1799. fw_outs_no_intermediate_bases = fw_outs[
  1800. : -runtime_metadata.fw_metadata.num_intermediate_bases
  1801. ]
  1802. intermediate_bases = fw_outs[-runtime_metadata.fw_metadata.num_intermediate_bases:]
  1803. else:
  1804. fw_outs_no_intermediate_bases = fw_outs
  1805. intermediate_bases = []
  1806. assert len(fw_outs_no_intermediate_bases) == len(runtime_metadata.fw_metadata.output_info)
  1807. fw_outs_including_aliases = []
  1808. for i, (o, info) in enumerate(zip(
  1809. fw_outs_no_intermediate_bases, runtime_metadata.fw_metadata.output_info
  1810. )):
  1811. if info.output_type == OutputType.non_alias:
  1812. fw_outs_including_aliases.append(o)
  1813. continue
  1814. if trace_joint:
  1815. assert isinstance(o, TensorAlias)
  1816. o_ = o.alias
  1817. else:
  1818. o_ = o
  1819. o_grad = runtime_metadata.fw_metadata.requires_grad_info[runtime_metadata.num_mutated_inputs + i]
  1820. if info.output_type == OutputType.alias_of_input:
  1821. aliased_base_tensor = args[info.base_idx]
  1822. regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
  1823. fw_outs_including_aliases.append(regenerated_out)
  1824. continue
  1825. elif info.output_type == OutputType.is_input:
  1826. aliased_base_tensor = args[info.base_idx]
  1827. regenerated_out = aliased_base_tensor
  1828. fw_outs_including_aliases.append(regenerated_out)
  1829. continue
  1830. elif info.output_type == OutputType.alias_of_intermediate:
  1831. base_tensor_list = intermediate_bases
  1832. elif info.output_type == OutputType.alias_of_intermediate_save_as_output:
  1833. base_tensor_list = intermediate_bases
  1834. else:
  1835. assert info.output_type == OutputType.alias_of_intermediate_base_is_user_output
  1836. base_tensor_list = fw_outs_no_intermediate_bases
  1837. aliased_base_tensor = base_tensor_list[info.base_idx]
  1838. # TODO: handle the custom autograd function case here.
  1839. # We need a way to check whether a tensor came from a custom autograd fn from python,
  1840. # AND a way to replay that custom view fn.
  1841. regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
  1842. fw_outs_including_aliases.append(regenerated_out)
  1843. return fw_outs_including_aliases
  1844. else:
  1845. return fw_outs
  1846. return runtime_wrapper
  1847. # Has the precondition that there
  1848. # are no duplicate arguments in flat_args (e.g., the same Tensor
  1849. # object never shows up twice. However, two tensor inputs MAY alias
  1850. # the same storage, so long as they have separate TensorImpls.)
  1851. def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
  1852. with enable_python_dispatcher():
  1853. _fw_metadata, out = run_functionalized_fw_and_collect_metadata(
  1854. flat_fn,
  1855. # Note: in the non-inference path, we are currently not passing input mutations into the graph directly.
  1856. # This is mainly difficult due to the partitioner, but we are leaving (a bit of) perf on the table.
  1857. keep_input_mutations=False,
  1858. )(
  1859. *flat_args
  1860. )
  1861. # out here corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
  1862. # It includes outputs of the original forward, *and* any updated inputs due to input mutations.
  1863. # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
  1864. out = pytree.tree_map(
  1865. lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
  1866. out,
  1867. )
  1868. # merge_view_inputs() is used again at runtime to create synthetic bases out of aliased inputs.
  1869. # This code only executes at runtime if we have graph inputs that alias each other, and one of those inputs
  1870. # gets its data mutated.
  1871. # When that happens, we replace the aliased inputs with a synthetic base, and in the traced forward
  1872. # we later generate the input views
  1873. flat_args_with_views_handled, _synthetic_base_info = merge_view_inputs(
  1874. flat_args, _fw_metadata.input_info, is_inference=False,
  1875. )
  1876. # pre-compute, so we can bail out quickly in the hotpath
  1877. metadata_ = CompiledRuntimeMetadata(
  1878. synthetic_base_info=_synthetic_base_info,
  1879. fw_metadata=_fw_metadata,
  1880. )
  1881. assert len(_fw_metadata.requires_grad_info) == metadata_.num_mutated_inputs + metadata_.num_outputs
  1882. joint_forward_backward = create_forward_or_joint_functionalized(
  1883. flat_fn,
  1884. meta=metadata_,
  1885. trace_joint=True,
  1886. # For now in the autograd case, we NEVER keep input mutations (we could eventually fix this for slightly better perf
  1887. # in some cases, but it's annoying to fix the partitioner)
  1888. keep_input_mutations=False,
  1889. )
  1890. joint_inputs = (flat_args_with_views_handled, out)
  1891. disable_amp = torch._C._is_any_autocast_enabled()
  1892. if config.use_functionalize:
  1893. with enable_python_dispatcher():
  1894. flattened_joints, _ = pytree.tree_flatten(joint_inputs)
  1895. fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(
  1896. *joint_inputs
  1897. )
  1898. # There should be *NO* mutating ops in the graph at this point.
  1899. assert_functional_graph(fx_g.graph)
1900. # Redundant with the check above, but worth having in case tracing introduced
  1901. # a fake tensor. Unlikely.
  1902. # See Note: [Fake Modules and AOTAutograd]
  1903. torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
  1904. fx_g.graph.eliminate_dead_code()
  1905. fx_g.recompile()
  1906. else:
  1907. # joint_forward_backward() now always runs with functionalization, and factoring it out
  1908. # to make that toggleable is a bit painful.
  1909. # aot autograd without functionalization is wrong anyway, so we error.
  1910. raise AssertionError(
  1911. "Graph partitioning without functionalization is not sound, we may introduce errors"
  1912. )
  1913. if config.debug_joint:
  1914. log.debug(f"====== Joint graph {aot_config.aot_id} ======")
  1915. log.debug(fx_g.print_readable(print_output=False))
  1916. with torch.no_grad():
  1917. with track_graph_compiling(aot_config, "joint"):
  1918. num_inner_fwd_outputs = metadata_.num_mutated_inputs + metadata_.num_outputs + _fw_metadata.num_intermediate_bases
  1919. fw_module, bw_module = aot_config.partition_fn(
  1920. fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
  1921. )
  1922. fw_outs = [n for n in fw_module.graph.nodes if n.op == "output"][0].args[0]
  1923. # we only need to bookkeep the symints that are saved for bw, not any symints
  1924. # the user forward might have returned in its own output
  1925. fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
  1926. symint_outs_saved_for_bw = [
  1927. n for n in fw_outs_saved_for_bw if is_sym_node(n)
  1928. ]
  1929. _num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
  1930. if config.debug_graphs:
  1931. log.debug(f"====== Forward graph {aot_config.aot_id} ======")
  1932. log.debug(fw_module.print_readable(print_output=False))
  1933. log.debug(f"====== Backward graph {aot_config.aot_id} ======")
  1934. log.debug(bw_module.print_readable(print_output=False))
  1935. with track_graph_compiling(aot_config, "forward"):
  1936. compiled_fw_func = aot_config.fw_compiler(
  1937. fw_module, flat_args_with_views_handled
  1938. )
  1939. class CompiledFunction(torch.autograd.Function):
  1940. compiled_fw = compiled_fw_func
  1941. compiled_bw = None
  1942. metadata = metadata_
  1943. num_symints_saved_for_bw = _num_symints_saved_for_bw
  1944. @staticmethod
  1945. def forward(ctx, *deduped_flat_tensor_args):
  1946. # There is a pretty complicated calling convention around what the compiled fw returns.
  1947. # The full list of outputs and their relative order is:
  1948. # (*mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
  1949. # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version
  1950. # of the original view, and not the synthetic base
  1951. fw_outs = call_func_with_args(
  1952. CompiledFunction.compiled_fw,
  1953. deduped_flat_tensor_args,
  1954. disable_amp=disable_amp,
  1955. )
  1956. num_outputs = CompiledFunction.metadata.num_outputs
  1957. num_outputs_aliased_to_inputs = (
  1958. CompiledFunction.metadata.num_outputs_aliased_to_inputs
  1959. )
  1960. num_outputs_aliased_to_intermediates = (
  1961. CompiledFunction.metadata.num_outputs_aliased_to_intermediates
  1962. )
  1963. num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
  1964. num_intermediate_bases = CompiledFunction.metadata.fw_metadata.num_intermediate_bases
  1965. num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw
  1966. num_mutated_inputs = CompiledFunction.metadata.num_mutated_inputs
  1967. num_mutated_metadata_only_inputs = (
  1968. CompiledFunction.metadata.num_mutated_metadata_only_inputs
  1969. )
1970. # Our forward() returns (mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints)
  1971. num_forward_returns = num_mutated_inputs + num_outputs + num_intermediate_bases
  1972. assert num_forward_returns == len(
  1973. CompiledFunction.metadata.fw_metadata.requires_grad_info
  1974. ) + num_intermediate_bases
  1975. # Partitioners must put symint arguments at the end separate from tensor arguments
  1976. if num_symints_saved_for_bw > 0:
  1977. tensors_saved_for_backwards = fw_outs[
  1978. num_forward_returns:-num_symints_saved_for_bw
  1979. ]
  1980. assert all(
  1981. [isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards]
  1982. )
  1983. # See Note [Detaching saved tensors in AOTAutograd]
  1984. ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
  1985. symint_outs = fw_outs[-num_symints_saved_for_bw:]
  1986. assert all(
  1987. [
  1988. isinstance(x, (int, float, torch.SymInt, torch.SymFloat))
  1989. for x in symint_outs
  1990. ]
  1991. )
  1992. ctx.symints = symint_outs
  1993. else:
  1994. tensors_saved_for_backwards = fw_outs[num_forward_returns:]
  1995. # See Note [Detaching saved tensors in AOTAutograd]
  1996. ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
  1997. ctx.symints = []
  1998. raw_returns = fw_outs[0:num_forward_returns]
  1999. # Wrap all autograd.Function.forward() outputs that are aliases
  2000. # so that autograd.Function doesn't treat them as tensors
  2001. if num_mutated_metadata_only_inputs > 0:
  2002. for i, idx in enumerate(
  2003. CompiledFunction.metadata.fw_metadata.mutated_inp_indices
  2004. ):
  2005. # We could make this faster by only looping over inputs with metadata-only mutations
  2006. # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many.
  2007. info = CompiledFunction.metadata.fw_metadata.input_info[idx]
  2008. if info.mutates_metadata and not info.mutates_data:
  2009. raw_returns[i] = TensorAlias(raw_returns[i])
  2010. if config.debug_assert:
  2011. user_mutated_inputs_raw = raw_returns[0:num_mutated_inputs]
  2012. mut_inp_infos = [
  2013. x for x in CompiledFunction.metadata.fw_metadata.input_info if x.mutates_data or x.mutates_metadata
  2014. ]
  2015. assert len(user_mutated_inputs_raw) == len(mut_inp_infos)
  2016. if num_outputs_aliased > 0:
  2017. for idx in CompiledFunction.metadata.fw_metadata.aliased_out_indices:
  2018. raw_return_idx = num_mutated_inputs + idx
  2019. raw_returns[raw_return_idx] = TensorAlias(raw_returns[raw_return_idx])
  2020. if config.debug_assert:
  2021. intermediates_raw = raw_returns[num_mutated_inputs + num_outputs:]
  2022. assert not any(isinstance(x, TensorAlias) for x in intermediates_raw)
  2023. # invariant: intermediate bases always require gradients, so we don't have to
  2024. # consider marking them as non-differentiable.
  2025. raw_returns_not_including_intermediate_bases = raw_returns[:num_mutated_inputs + num_outputs]
  2026. fw_outs_not_requiring_grad = [
  2027. x
  2028. for (i, x) in enumerate(raw_returns_not_including_intermediate_bases)
  2029. if isinstance(x, torch.Tensor)
  2030. and not CompiledFunction.metadata.fw_metadata.requires_grad_info[i]
  2031. ]
  2032. ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
  2033. return tuple(raw_returns)
  2034. @staticmethod
  2035. def backward(ctx, *flat_args):
  2036. # Calling convention: we expect a grad_out passed to the backward:
  2037. # - for every output of the fw that does *not* alias an input or graph intermediate
  2038. # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)
  2039. # - for every graph intermediate that we need to use to generate an output later.
  2040. # The other outputs in the autograd.Function.forward that do *not* show up in the backward include:
  2041. # - outputs that alias inputs or graph intermediates
  2042. # - updated inputs due to metadata-only mutations.
  2043. # We need to return them in the forward, but ensure that they all do not get gradients in the backward,
  2044. # and we filter them out here before passing the remaining grad_outputs into the compiled backward.
            num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs
            num_intermediate_bases = CompiledFunction.metadata.fw_metadata.num_intermediate_bases
            expected_grad_outs = (
                CompiledFunction.metadata.num_outputs + num_mutated_inps + num_intermediate_bases
            )

            assert len(flat_args) == expected_grad_outs
            if (
                CompiledFunction.metadata.num_mutated_metadata_only_inputs > 0
                or CompiledFunction.metadata.num_outputs_aliased > 0
            ):
                inp_tangents, out_tangents, intermediate_base_tangents = (
                    flat_args[0:num_mutated_inps],
                    flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
                    flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
                )
                # input_info contains info on *every* input,
                # But in the backward(), we are only given grad outputs for every mutated input.
                # We then need to filter out the grad outputs that correspond to metadata-only mutations.
                mutated_inp_indices = CompiledFunction.metadata.fw_metadata.mutated_inp_indices
                input_info = CompiledFunction.metadata.fw_metadata.input_info
                assert len(inp_tangents) == len(mutated_inp_indices)
                inp_tangents_filtered = [
                    x
                    for x, info_idx in zip(inp_tangents, mutated_inp_indices)
                    if input_info[info_idx].mutates_data
                ]
                # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
                out_info = CompiledFunction.metadata.fw_metadata.output_info
                out_tangents_filtered = [
                    x
                    for x, info in zip(out_tangents, out_info)
                    if info.output_type == OutputType.non_alias
                ]
                # intermediate bases always require gradients, and always participate in the backward graph.
                flat_bw_args = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents)

                # sanity asserts
                # metadata_only_inps = [
                #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
                #     if not input_info[info_idx].mutates_data
                # ]
                # aliased_outputs = [
                #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
                # assert all(x is None for x in metadata_only_inps)
                # assert all(x is None for x in aliased_outputs)
            else:
                flat_bw_args = flat_args

            contiguous_args = [
                t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args
            ]
            all_args = (
                list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args)
            )
            del contiguous_args

            def call_compiled_backward():
                if CompiledFunction.compiled_bw is None:
                    # TODO - pass in fake tensors ?
                    context = disable_autocast_manager if disable_amp else nullcontext
                    with context(), track_graph_compiling(aot_config, "backward"):
                        CompiledFunction.compiled_bw = aot_config.bw_compiler(
                            bw_module, all_args
                        )

                ctx.maybe_clear_saved_tensors()
                out = call_func_with_args(
                    CompiledFunction.compiled_bw,
                    all_args,
                    steal_args=True,
                    disable_amp=disable_amp,
                )
                return tuple(out)

            if torch.is_grad_enabled() and any(t.requires_grad for t in all_args if isinstance(t, torch.Tensor)):
                # Ensure that the graph is connected, and error if double backward is performed.
                # See comment for why once_differentiable is not sufficient:
                # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
                class CompiledFunctionBackward(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, *unused_args):
                        return call_compiled_backward()

                    @staticmethod
                    def backward(ctx, *args):
                        raise RuntimeError("torch.compile with aot_autograd does not currently support double backward")

                # Pass args even though they're unused, so that the graph is built
                out = CompiledFunctionBackward.apply(*all_args)
            else:
                out = call_compiled_backward()

            return out

    compiled_function = create_runtime_wrapper(
        CompiledFunction.apply,
        runtime_metadata=metadata_,
        trace_joint=True,
        keep_input_mutations=False,
    )

    if not config.debug_assert:
        return compiled_function

    flat_requires_grad = [
        a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
    ]

    @wraps(compiled_function)
    def debug_compiled_function(*args):
        # TODO: Check aliasing relationships
        # TODO: Check strides for metadata mutation
        # (NB: ideally, this logic is factored out of this function and
        # you move these debug checks there)

        # Check requires grad. Bad case is when we compiled with
        # requires_grad = False, but input requires_grad = True
        # (vice versa is OK; we compute a gradient and then throw
        # it away when it hits the input.)
        for i, a in enumerate(args):
            can_require_grad = flat_requires_grad[i]
            if can_require_grad is None:
                assert not isinstance(a, Tensor)
            elif not can_require_grad:
                assert not a.requires_grad, format_guard_bug_msg(
                    aot_config,
                    f"{describe_input(i, aot_config)} would not require grad",
                )

        return compiled_function(*args)

    return debug_compiled_function


@dynamo_timed
def create_aot_dispatcher_function(
    flat_fn, flat_args: List[Any], aot_config: AOTConfig
):
    """
    Traces the forward and backward graphs of :attr:`flat_fn` to generate a
    joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
    the tracing mechanism to understand the graph capturing details.

    The joint graph is then passed through :attr:`partition_fn` to isolate the
    forward and backward portions, which are then respectively compiled via the
    provided :attr:`fw_compiler` and :attr:`bw_compiler`.

    The resulting compiled forward and backward graphs are then wrapped up in a
    ``torch.autograd.Function`` object.

    The calling convention here is that the first aot_config.num_params_buffers
    inputs in flat_args are parameters and buffers, and the rest are inputs.
    We use this to assume that parameters'/buffers' shapes don't change.
    """
    # This is the main entry point.
    # TODO: Chillee argues that dynamo itself should pass in fake tensors to
    # the list of arguments when compiling; at the moment we do not do this
    if aot_config.decompositions is None:
        aot_config.decompositions = {}

    aot_config.decompositions = {
        **aot_autograd_decompositions,
        **aot_config.decompositions,
    }

    log.setLevel(config.log_level)

    # NB: don't bother setting allow_fallback_kernels; this should not actually
    # be configurable in fake tensor, we should automatically do the right
    # thing
    if config.debug_fake_cross_ref:
        # This is a little messy but TorchDynamo directly changes `use_fake_tensor`
        # so it's not enough for user to change the config manually
        # TODO: have TorchDynamo read in `use_fake_tensor` from os environ /
        # coordinate flags
        config.use_fake_tensor = False

    # Check flat_args to see if they're already fake. If so, use that fake
    # mode instead.
    for x in flat_args:
        if isinstance(x, FakeTensor):
            fake_mode = x.fake_mode
            shape_env = fake_mode.shape_env
            break
    else:
        shape_env = ShapeEnv() if config.use_dynamic_shapes else None
        fake_mode = (
            FakeTensorMode(shape_env=shape_env)
            if config.use_fake_tensor
            else nullcontext()
        )

    cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext()
    python_dispatcher_mode = (
        enable_python_dispatcher() if shape_env is not None else nullcontext()
    )

    with torch.autograd.set_multithreading_enabled(
        False
    ), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:

        def process_inputs(flat_args):
            if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode):

                def convert(idx, x):
                    if not isinstance(x, torch.Tensor):
                        return x
                    if isinstance(x, FakeTensor):
                        assert x.fake_mode is fake_mode
                        return x
                    if (
                        idx < aot_config.num_params_buffers
                        and config.static_weight_shapes
                    ):
                        return fake_mode.from_tensor(x, static_shapes=True)
                    return fake_mode.from_tensor(x, static_shapes=False)

                return [convert(idx, x) for idx, x in enumerate(flat_args)]
            else:
                return flat_args

        fake_flat_args = process_inputs(flat_args)

        needs_autograd = (
            any([x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)])
            and torch.is_grad_enabled()
        )
        # crappy version of dispatcher
        # TODO: Do this properly
        if needs_autograd:
            compiler_fn = aot_dispatch_autograd
        else:
            compiler_fn = aot_dispatch_base

        compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)
        # You can put more passes here

        compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)

        if not hasattr(compiled_fn, "_boxed_call"):
            compiled_fn = make_boxed_func(compiled_fn)

        return compiled_fn


# Inspired by autodidax (thanks!)
class PytreeThunk:
    spec = None
    # These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
    is_simple = (
        None  # if the output spec is a tuple/list, we won't bother unflattening it.
    )
    is_really_simple = None  # if the output spec is a LeafSpec

    def set(self, spec):
        assert self.spec is None or self.spec == spec
        self.spec = spec
        if type(self.spec) in [tuple, list] and all(
            isinstance(i, pytree.LeafSpec) for i in spec.children_specs
        ):
            self.is_simple = True
        if isinstance(self.spec, pytree.LeafSpec):
            self.is_really_simple = True

    def unflatten(self, x):
        if self.is_really_simple:
            return x[0]
        if self.is_simple:
            return x
        return pytree.tree_unflatten(x, self.spec)
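

# Illustrative sketch of how PytreeThunk is used by aot_function below (not from the
# original source; `loss`/`logits` are hypothetical outputs):
#
#   out_spec = PytreeThunk()
#   flat_out, spec = pytree.tree_flatten({"loss": loss, "logits": logits})
#   out_spec.set(spec)               # record the output spec the first time we trace
#   ...
#   result = out_spec.unflatten(flat_out)  # fast path when the spec is a leaf/tuple/list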


def aot_function(
    fn: Callable,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    num_params_buffers: int = 0,
    hasher_type=None,  # deprecated
    static_argnums: Optional[Tuple[int]] = None,  # deprecated
    keep_inference_input_mutations: bool = False,
) -> Callable:
    """
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
    mechanism, and then compiles the generated forward and backward graphs
    through :attr:`fw_compiler` and :attr:`bw_compiler`.

    :func:`aot_function` traces the forward and backward graph ahead of time,
    and generates a joint forward and backward graph. :attr:`partition_fn` is
    then used to separate out forward and backward graphs. The partitioner
    function can be used to perform optimizations such as recomputation. One can
    set the `decompositions` dictionary to decompose the operators into a sequence
    of core or simpler operators supported by the backend compilers.

    :func:`aot_function` uses a compilation cache, based on input tensor
    properties, to detect when there is a need for recompilation.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        fw_compiler (Callable): A Python function that accepts an Fx graph with
            Aten ops and input args, and returns a Callable that semantically is
            equivalent to the input Fx graph.
        bw_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
        partition_fn (Callable): A Python function that takes a joint forward
            and backward graph, and partitions it into separate forward and
            backward graphs.
        decompositions (Dict): A dictionary to define the decomposition of
            larger Aten ops into simpler or core Aten ops.

    Returns:
        Returns a ``Callable`` that retains the eager behavior of the original
        :attr:`fn`, but with forward and backward graph compiled via
        :attr:`fw_compiler` and :attr:`bw_compiler`.

    A simple example usage of :func:`aot_function` is as follows. This example
    will print the forward and backward graphs of the function ``fn``

        >>> fn = lambda x : x.sin().cos()
        >>> def print_compile_fn(fx_module, args):
        >>>     print(fx_module)
        >>>     return fx_module
        >>> aot_fn = aot_function(fn, print_compile_fn)
        >>> x = torch.randn(4, 5, requires_grad=True)
        >>> aot_fn(x)
    """
    if static_argnums is not None:
        raise RuntimeError(
            "static_argnums has been deprecated - manually wrap your function or use torchdynamo."
        )

    if bw_compiler is None:
        bw_compiler = fw_compiler
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
    )
    cached_res = None

    @wraps(fn)
    def returned_function(*args, **kwargs):
        nonlocal cached_res
        # Now flatten the tensor args
        flat_args, _ = pytree.tree_flatten((args, kwargs))

        # Compile the function and save it in the cache
        if cached_res is None:
            # Save the args_spec for flat_tensor_args to unflatten while tracing
            _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
            out_spec = PytreeThunk()

            def flat_fn(*flat_args):
                # The inputs are flattened tensor args. Prepare the args in the
                # order that the original function expects. Add static args as well.
                # They will appear as tensor constants in the traced graph.
                nonlocal out_spec
                args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
                tree_out = fn(*args, **kwargs)
                flat_out, spec = pytree.tree_flatten(tree_out)
                for i in flat_out:
                    is_known_type = False
                    for j in KNOWN_TYPES:
                        if isinstance(i, j):
                            is_known_type = True
                            break
                    if not is_known_type:
                        raise RuntimeError(
                            f"Found {type(i)} in output, which is not a known type. "
                            "If this type holds tensors, you need to register a pytree for it. "
                            "See https://github.com/pytorch/functorch/issues/475 for a brief "
                            "explanation why. If you don't need to register a pytree, please "
                            "leave a comment explaining your use case and we'll make this more "
                            "ergonomic to deal with"
                        )
                out_spec.set(spec)
                return flat_out

            compiled_fn = create_aot_dispatcher_function(
                flat_fn,
                flat_args,
                aot_config,
            )
            cached_res = (compiled_fn, out_spec)

        cached_fn, out_spec = cached_res
        out = cached_fn(flat_args)
        return out_spec.unflatten(out)

    return returned_function


def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
    """
    Traces the forward and backward graph of :attr:`mod` using the torch dispatch
    tracing mechanism. It is a wrapper function that underneath uses
    :func:`aot_function` to perform tracing and compilation.

    :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
    to a new callable which is then compiled through :func:`aot_function`.

    .. warning::
        This API is experimental and likely to change.

    Args:
        mod (Callable): A ``nn.Module`` module.
        args : args to be passed to :func:`aot_function`
        kwargs : kwargs to be passed to :func:`aot_function`

    Returns:
        Returns a ``nn.Module`` that retains the eager behavior of the original
        :attr:`mod`, but with forward and backward graph compiled.
    """
    # See Note: [Fake Modules and AOTAutograd]
    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)

    def functional_call(named_params, named_buffers, *args, **kwargs):
        params_and_buffers = {**named_params, **named_buffers}
        return torch.func.functional_call(mod, params_and_buffers, args, kwargs)

    named_params = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
    num_params_buffers = len(named_params) + len(named_buffers)
    compiled_f = aot_function(
        functional_call, num_params_buffers=num_params_buffers, *args, **kwargs
    )

    class AOTModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.orig_module = mod

        def forward(self, *args, **kwargs):
            return compiled_f(
                named_params,
                named_buffers,
                *args,
                **kwargs,
            )

    return AOTModule()


def aot_module_simplified(
    mod: nn.Module,
    args,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    hasher_type=None,
    static_argnums=None,
    keep_inference_input_mutations=False,
) -> nn.Module:
    """
    This is the simplified or low overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove the
    (1) pytree overhead to parse inputs/outputs,
    (2) AOT Autograd cache,
    (3) reading of params/buffers in every forward call.

    :func:`aot_module_simplified` removes these overheads.
    """
    #########################################################
    # Redundant with dynamo, but worth having in case this gets invoked elsewhere.

    # Note [Fake Modules and AOTAutograd]
    #
    # A simple heuristic for when to use fake versus real tensors is that fake tensors are for compile time
    # (when we don't want to actually run the compute, but we do want to know about metadata),
    # and real tensors are for runtime (when we actually want to do the compute.) However, in AOTAutograd,
    # modules are the exception: we always pass AOTAutograd modules with real tensors.
    # This is because AOTAutograd will produce a compiled function which needs to directly access any
    # parameters the compiled function may need, but these parameters will NOT be passed in by the caller (aka Dynamo).
    # So at compile time, the compiled function we produce must close over any parameters, and those parameters must be
    # real parameters, and we cannot do this unless at compile time we get a module with real tensors.
    # Even if Dynamo did pass all parameters explicitly at runtime, which would eliminate the need to close over
    # the parameters, it would still be profitable to pass real tensor parameters to the compiler at compile time,
    # because some compilation strategies like CUDA graphs want to burn in the pointer addresses where the parameter data live,
    # and of course we can't do that unless we give the backend a real tensor.
    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)

    params = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = tuple(params_flat)
    params_len = len(params_flat)

    def functional_call(*args, **kwargs):
        with stateless._reparametrize_module(
            mod, pytree.tree_unflatten(args[:params_len], params_spec)
        ):
            if isinstance(mod, torch.fx.GraphModule):
                with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore", "Anomaly Detection has been enabled."
                    )
                    with torch.autograd.detect_anomaly(check_nan=False):
                        out = Interpreter(mod).run(*args[params_len:], **kwargs)
            else:
                out = mod(*args[params_len:], **kwargs)

        if not isinstance(out, (tuple, list)):
            raise RuntimeError(
                "Graph output must be a tuple(). This is so that we can avoid "
                "pytree processing of the outputs. Please change the module to "
                "have tuple outputs or use aot_module instead."
            )
        return out

    assert static_argnums is None
    if bw_compiler is None:
        bw_compiler = fw_compiler
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=params_len,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
    )

    full_args = []
    full_args.extend(params_flat)
    full_args.extend(args)

    compiled_fn = create_aot_dispatcher_function(
        functional_call,
        full_args,
        aot_config,
    )

    # TODO: There is something deeply wrong here; compiled_fn is running with
    # the boxed calling convention, but aot_module_simplified somehow
    # historically returned a function that was not the boxed calling
    # convention. This should get fixed...
    def forward(*runtime_args):
        full_args = []
        full_args.extend(params_flat)
        full_args.extend(runtime_args)
        return compiled_fn(full_args)

    # Just for convenience
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    forward.named_buffers = mod.named_buffers

    return forward


compiled_function = aot_function
compiled_module = aot_module