pattern_matcher.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. import dataclasses
  2. import functools
  3. import inspect
  4. import itertools
  5. import logging
  6. import operator
  7. import os
  8. from collections import defaultdict
  9. from typing import Any, Callable, List, Union
  10. import torch
  11. import torch._inductor as inductor
  12. import torch.fx
  13. import torch.utils._pytree as pytree
  14. from torch._dynamo.utils import counters
  15. from torch.fx.immutable_collections import immutable_dict, immutable_list
  16. from . import config, ir
  17. from .lowering import lowerings as L
  18. from .virtualized import V
  19. log = logging.getLogger(__name__)
  20. aten = torch.ops.aten
  21. Constant = Any
  22. NodeOrConstant = Union[Constant, torch.fx.Node]
  23. class Match:
  24. """
  25. Represents a successfully matched pattern.
  26. """
  27. def __init__(self, pattern, args=None, kwargs=None):
  28. super().__init__()
  29. self.pattern = pattern
  30. # The input nodes that must be passed in to the result
  31. self.args = args or []
  32. self.kwargs = kwargs or {}
  33. # The nodes matched in this expression
  34. self.nodes = []
  35. # Mapping CallFunction to the node.target
  36. self.targets = {}
  37. def extend(self, other):
  38. if self.kwargs:
  39. for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
  40. if self.kwargs[key] != other.kwargs[key]:
  41. raise FailedMatch(f"kwarg mismatch: {key}")
  42. self.args.extend(other.args)
  43. self.nodes.extend(other.nodes)
  44. self.kwargs.update(other.kwargs)
  45. self.targets.update(other.targets)
  46. def bundle(self):
  47. # Wrap args in an extra list
  48. self.args = [tuple(self.args)]
  49. return self
  50. def __repr__(self):
  51. return f"Match(..., {self.args}, {self.kwargs})"
  52. def erase_nodes(self, graph: torch.fx.Graph):
  53. for n in reversed(self.nodes):
  54. graph.erase_node(n)
  55. class FailedMatch(RuntimeError):
  56. def __bool__(self):
  57. return False
  58. class MatchContext:
  59. """
  60. State needed while running PatternExpr._match().
  61. """
  62. def __init__(self, outputs: List["PatternExpr"]):
  63. self.outputs = outputs
  64. self.pattern_to_node = {}
  65. def match(self, pattern, node):
  66. """wrapper to check reused nodes in patterns"""
  67. if pattern in self.pattern_to_node:
  68. if self.pattern_to_node[pattern] == node:
  69. return Match(pattern) # already checked this node
  70. else:
  71. return FailedMatch("repeated pattern differs")
  72. m = pattern._match(node, self)
  73. assert pattern not in self.pattern_to_node
  74. self.pattern_to_node[pattern] = node if m else None
  75. return m
  76. class PatternExpr:
  77. """
  78. Base class for types of patterns
  79. """
  80. def _match(self, node: torch.fx.Node, outputs) -> Union[Match, FailedMatch]:
  81. raise NotImplementedError()
  82. def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
  83. try:
  84. return MatchContext([self]).match(self, node)
  85. except FailedMatch as e:
  86. return e
  87. def __repr__(self):
  88. return self.__class__.__name__ + "()"
  89. class Arg(PatternExpr):
  90. """
  91. Capture an arg which will become an input to the handler. Args are
  92. passed in depth first order.
  93. """
  94. def _match(self, node: NodeOrConstant, ctx: MatchContext):
  95. return Match(self, args=[node]) # matches anything
  96. class KeywordArg(PatternExpr):
  97. """
  98. Capture a kwarg which will become an input to the handler.
  99. """
  100. def __init__(self, name):
  101. super().__init__()
  102. self.name = name
  103. def _match(self, node: NodeOrConstant, ctx: MatchContext):
  104. return Match(self, kwargs={self.name: node}) # matches anything
  105. class CallFunction(PatternExpr):
  106. """
  107. Matches a call_function node in the FX graps: `fns[i](*args, **kwargs)`
  108. """
  109. def __init__(self, fns, *args, _users=1, **kwargs):
  110. super().__init__()
  111. fns = [fns] if callable(fns) else list(fns)
  112. for fn in list(fns):
  113. if isinstance(fn, torch._ops.OpOverloadPacket):
  114. fns.extend([getattr(fn, overload) for overload in fn.overloads()])
  115. self.fns = fns
  116. self.fns_set = set(fns)
  117. self.args = tuple(args)
  118. self.kwargs = dict(kwargs)
  119. self.users = _users
  120. if any(
  121. isinstance(x, (dict, list, tuple))
  122. for x in itertools.chain(args, kwargs.values())
  123. ):
  124. self.flatten = self.pytree_flatten
  125. else:
  126. self.flatten = self.simple_flatten
  127. self.flat_args_kwargs = self.flatten(self.args, self.kwargs)
  128. @staticmethod
  129. def simple_flatten(args, kwargs):
  130. return (*args, *kwargs.values()), (len(args), *kwargs.keys())
  131. @staticmethod
  132. def pytree_flatten(args, kwargs):
  133. def norm_spec(s: pytree.TreeSpec):
  134. if s.type is None:
  135. return s
  136. mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
  137. return pytree.TreeSpec(
  138. mapping.get(s.type, s.type),
  139. s.context,
  140. list(map(norm_spec, s.children_specs)),
  141. )
  142. flat, spec = pytree.tree_flatten([args, kwargs])
  143. spec = norm_spec(spec)
  144. return flat, spec
  145. def __repr__(self):
  146. args = [
  147. f"[{self.fns[0].__name__}, ...]",
  148. *map(repr, self.args),
  149. *[f"{k}={v}" for k, v in self.kwargs.items()],
  150. ]
  151. return f"{self.__class__.__name__}({', '.join(args)})"
  152. def _match(self, node: torch.fx.Node, ctx: MatchContext):
  153. if (
  154. not isinstance(node, torch.fx.Node)
  155. or node.op != "call_function"
  156. or node.target not in self.fns_set
  157. or len(node.args) != len(self.args)
  158. or len(node.kwargs) != len(self.kwargs)
  159. ):
  160. return FailedMatch("function_mismatch")
  161. if self not in ctx.outputs and len(node.users) != self.users:
  162. return FailedMatch("multiple_users")
  163. node_items, node_spec = self.flatten(node.args, node.kwargs)
  164. self_items, self_spec = self.flat_args_kwargs
  165. if node_spec != self_spec:
  166. return FailedMatch(f"args_stucture {node_spec} {self_spec}")
  167. assert len(node_items) == len(self_items)
  168. m = Match(self)
  169. for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
  170. if isinstance(pattern, PatternExpr):
  171. child_match = ctx.match(pattern, child_node)
  172. if not child_match:
  173. return FailedMatch(f"arg[{i}]: {child_match}")
  174. m.extend(child_match)
  175. elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
  176. return FailedMatch("constant_args")
  177. m.nodes.append(node)
  178. m.targets[self] = node.target
  179. return m
  180. class ListOf(PatternExpr):
  181. """
  182. Matches a repeated pattern
  183. """
  184. def __init__(self, pattern):
  185. super().__init__()
  186. assert isinstance(pattern, PatternExpr)
  187. self.pattern = pattern
  188. def __repr__(self):
  189. return f"{self.__class__.__name__}({self.pattern})"
  190. def _match(self, node: List[torch.fx.Node], ctx: MatchContext):
  191. if not isinstance(node, (list, tuple)) or len(node) == 0:
  192. return FailedMatch("non_list")
  193. m = Match(self)
  194. for i, child_node in enumerate(node):
  195. child_match = MatchContext(ctx.outputs).match(self.pattern, child_node)
  196. if not child_match:
  197. return FailedMatch(f"list[{i}]: {child_match}")
  198. m.extend(child_match.bundle())
  199. return m.bundle()
  200. pass_patterns = [
  201. defaultdict(list),
  202. defaultdict(list),
  203. defaultdict(list),
  204. ]
  205. @dataclasses.dataclass
  206. class PatternEntry:
  207. pattern: PatternExpr
  208. extra_check: Callable[[Match], bool]
  209. def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
  210. raise NotImplementedError()
  211. def register(self, pass_number, target):
  212. if isinstance(pass_number, int):
  213. pass_patterns[pass_number][target].append(self)
  214. else:
  215. for x in pass_number:
  216. self.register(x, target)
  217. @dataclasses.dataclass
  218. class LoweringPatternEntry(PatternEntry):
  219. handler: Any
  220. def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
  221. handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
  222. with graph.inserting_before(node):
  223. replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
  224. replacement.meta.update(node.meta)
  225. node.replace_all_uses_with(replacement)
  226. assert match.nodes[-1] is node
  227. match.erase_nodes(graph)
  228. @dataclasses.dataclass
  229. class ReplacementPatternEntry(PatternEntry):
  230. replacement_graph: torch.fx.GraphModule
  231. signature: inspect.Signature
  232. propagate: bool = False
  233. def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
  234. class Replacer(torch.fx.Interpreter):
  235. call_method = None
  236. call_module = None
  237. get_attr = None
  238. def call_function(self, target, args, kwargs):
  239. result = graph.call_function(target, args, kwargs)
  240. if propagate and V.fake_mode:
  241. fargs, fkwargs = torch.fx.map_arg(
  242. (args, kwargs), lambda n: n.meta["val"]
  243. )
  244. with V.fake_mode:
  245. result.meta["val"] = target(*fargs, **fkwargs)
  246. return result
  247. propagate = self.propagate
  248. norm_args = self.signature.bind(*match.args, **match.kwargs)
  249. with graph.inserting_before(node):
  250. replacement = Replacer(self.replacement_graph).run(
  251. *norm_args.arguments.values()
  252. )
  253. replacement.meta.update(node.meta)
  254. node.replace_all_uses_with(replacement)
  255. assert match.nodes[-1] is node
  256. match.erase_nodes(graph)
  257. def _return_true(match):
  258. return True
  259. def register_replacement_pattern(pattern, extra_check=_return_true, pass_number=1):
  260. """
  261. Register an aten to aten replacement pattern
  262. """
  263. def decorator(handler):
  264. signature = inspect.signature(handler)
  265. replacement_graph = torch.fx.symbolic_trace(handler)
  266. for target in pattern.fns:
  267. ReplacementPatternEntry(
  268. pattern=pattern,
  269. extra_check=extra_check,
  270. replacement_graph=replacement_graph,
  271. signature=signature,
  272. ).register(pass_number, target)
  273. return handler
  274. assert isinstance(pattern, CallFunction)
  275. return decorator
  276. def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
  277. """
  278. Register an aten to inductor IR replacement pattern
  279. """
  280. def decorator(handler):
  281. assert callable(handler)
  282. for target in pattern.fns:
  283. LoweringPatternEntry(
  284. pattern=pattern, extra_check=extra_check, handler=handler
  285. ).register(pass_number, target)
  286. handler._inductor_lowering_function = True
  287. return handler
  288. assert isinstance(pattern, CallFunction)
  289. return decorator
  290. register_pattern = register_lowering_pattern
  291. def replace_matched_patterns(graph: torch.fx.Graph):
  292. # the actual replacement work
  293. for patterns in pass_patterns:
  294. if not patterns:
  295. continue
  296. for node in reversed(graph.nodes):
  297. if node.op == "call_function" and node.target in patterns:
  298. for entry in patterns[node.target]:
  299. if node._erased:
  300. break
  301. m = entry.pattern.match(node)
  302. if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
  303. log.warning(f"{node}{node.args} {m} {entry.pattern}")
  304. if m and entry.extra_check(m):
  305. entry.apply(m, graph, node)
  306. counters["inductor"]["pattern_matcher_count"] += 1
  307. counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
  308. def reorder_for_locality(graph: torch.fx.Graph):
  309. def visit(other_node):
  310. if (
  311. other_node.op == "call_function"
  312. and other_node.target != operator.getitem
  313. and all((n in seen_nodes) for n in other_node.users)
  314. ):
  315. # move node's producers right before it
  316. node.prepend(other_node)
  317. seen_nodes = set()
  318. for node in reversed(graph.nodes):
  319. seen_nodes.add(node)
  320. torch.fx.map_arg((node.args, node.kwargs), visit)
  321. def fx_passes(gm: torch.fx.GraphModule):
  322. if config.dce:
  323. # has some issues with mutation in inference mode
  324. gm.graph.eliminate_dead_code()
  325. if config.reordering:
  326. # has some issues with mutation in inference mode
  327. reorder_for_locality(gm.graph)
  328. if config.pattern_matcher:
  329. replace_matched_patterns(gm.graph)
  330. gm.graph.lint()
  331. ################################################################################
  332. # Actual patterns below this point.
  333. # Priority of patterns is:
  334. # - later output nodes first
  335. # - order patterns are defined in
  336. ################################################################################
  337. @register_lowering_pattern(
  338. CallFunction(
  339. aten.add,
  340. CallFunction(aten.mm, Arg(), Arg()),
  341. CallFunction(aten.mm, Arg(), Arg()),
  342. )
  343. )
  344. def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
  345. return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)
  346. @register_lowering_pattern(
  347. CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
  348. )
  349. def cat_mm(match, inputs, dim):
  350. def shape_of(a, b):
  351. m, _ = a.get_size()
  352. _, n = b.get_size()
  353. return [m, n]
  354. return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of)
  355. @register_lowering_pattern(
  356. CallFunction(
  357. aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg()
  358. ),
  359. )
  360. def cat_addmm(match, inputs, dim):
  361. def shape_of(bias, a, b):
  362. m, _ = a.get_size()
  363. _, n = b.get_size()
  364. return [m, n]
  365. return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of)
  366. def cat_tuned_op(match, inputs, dim, *, op, shape_of):
  367. """
  368. Memory planning to remove cat. We can't use the stock memory
  369. planner since autotuning matmauls needs to know the output layout.
  370. """
  371. # TODO(jansel): rewrite this as a bmm?
  372. if dim < 0:
  373. dim += len(shape_of(*inputs[0]))
  374. assert dim in (0, 1)
  375. notdim = 1 - dim
  376. new_size = None
  377. offsets_start = []
  378. offsets_end = []
  379. # compute output sizes
  380. for i in range(len(inputs)):
  381. shape = shape_of(*inputs[i])
  382. if new_size is None:
  383. new_size = shape
  384. else:
  385. new_size[notdim] = V.graph.sizevars.guard_equals(
  386. shape[notdim], new_size[notdim]
  387. )
  388. new_size[dim] += shape[dim]
  389. offsets_start.append(new_size[dim] - shape[dim])
  390. offsets_end.append(new_size[dim])
  391. dtype = functools.reduce(
  392. torch.promote_types, [x.get_dtype() for x in itertools.chain(*inputs)]
  393. )
  394. device = inputs[0][0].get_device()
  395. kernel = ir.ConcatKernel(
  396. name=None,
  397. layout=ir.FixedLayout(device, dtype, new_size),
  398. inputs=[],
  399. )
  400. kernel_tensor = ir.TensorBox.create(kernel)
  401. for i in range(len(inputs)):
  402. dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i])
  403. src = op(*inputs[i], layout=dst.get_layout()).data.data
  404. assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer))
  405. src.layout = ir.AliasedLayout(dst)
  406. kernel.inputs.append(src)
  407. kernel.name = V.graph.register_buffer(kernel)
  408. kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs)
  409. return kernel_tensor
  410. _cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)
  411. @register_lowering_pattern(
  412. CallFunction(
  413. aten.cat,
  414. [
  415. _cat_1,
  416. CallFunction(
  417. aten.slice,
  418. CallFunction(aten.slice, _cat_1, 0, 0, 9223372036854775807),
  419. 1,
  420. 0,
  421. KeywordArg("size"),
  422. ),
  423. ],
  424. 1,
  425. )
  426. )
  427. def cat_slice_cat(match, cat_input, size, dim=1):
  428. """
  429. This is an example of a more complex pattern where cat_1 is used
  430. multiple times inside the pattern. We fold 2 calls to cat into one.
  431. Matches:
  432. cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
  433. slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
  434. slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
  435. cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)
  436. Rewrite to:
  437. slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
  438. cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1)
  439. """
  440. first, *rest = cat_input
  441. if V.graph.sizevars.maybe_guard_leq(size, first.get_size()[dim]):
  442. # fold 2 cats into 1 cat
  443. return L[aten.cat](
  444. [
  445. first,
  446. *rest,
  447. L[aten.slice](first, dim, 0, size),
  448. ],
  449. dim,
  450. )
  451. else:
  452. # don't expect to hit this case, just fall back
  453. tmp = L[aten.cat](cat_input, dim)
  454. return L[aten.cat](
  455. [
  456. tmp,
  457. L[aten.slice](tmp, dim, 0, size),
  458. ],
  459. dim,
  460. )
  461. @register_replacement_pattern(
  462. CallFunction(
  463. aten.add,
  464. CallFunction(aten.mm, Arg(), Arg()),
  465. KeywordArg("added"),
  466. ),
  467. pass_number=2,
  468. )
  469. @register_replacement_pattern(
  470. CallFunction(
  471. aten.add,
  472. KeywordArg("added"),
  473. CallFunction(aten.mm, Arg(), Arg()),
  474. ),
  475. pass_number=2,
  476. )
  477. def addmm(mat1, mat2, added):
  478. return aten.addmm(added, mat1, mat2)
# This slows things down, so the bmm + add -> baddbmm rewrite below is kept
# disabled. It sits inside a bare string literal, which Python evaluates and
# discards, so nothing is registered.
# NOTE(review) if re-enabling:
# - pass_number=3 is out of range for the three-entry pass_patterns list
#   defined in this module.
# - KeywordArg("added") matches anything, including Python scalars.
"""
@register_replacement_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.bmm, Arg(), Arg()),
        KeywordArg("added"),
    ),
    pass_number=3
)
@register_replacement_pattern(
    CallFunction(
        aten.add,
        KeywordArg("added"),
        CallFunction(aten.bmm, Arg(), Arg()),
    ),
    pass_number=3
)
def baddbmm(mat1, mat2, added):
    return aten.baddbmm(added, mat1, mat2)
"""