# mkldnn.py -- MKLDNN fusion and weight-prepacking passes for torch.fx graphs.
  1. import copy
  2. import itertools
  3. import operator
  4. from functools import reduce
  5. from typing import Optional
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from torch._dynamo.utils import fake_mode_from_tensors
  10. from torch.fx.experimental.optimization import (
  11. matches_module_pattern,
  12. replace_node_module,
  13. )
  14. from torch.fx.experimental.symbolic_shapes import guard_int
  15. from torch.fx.passes.shape_prop import ShapeProp
  16. from torch.nn.modules.utils import _pair
  17. from . import config
  18. from .fx_utils import matches_module_function_pattern
  19. class UnaryAttr:
  20. def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
  21. self.op_name = op_name
  22. self.scalars_attr = scalars_attr if scalars_attr else []
  23. self.algorithm_attr = algorithm_attr if algorithm_attr else ""
  24. super().__init__()
  25. def __call__(self, unary_module: nn.Module):
  26. if type(unary_module) is nn.ReLU6:
  27. unary_module = nn.Hardtanh(min_val=0, max_val=6)
  28. assert all(hasattr(unary_module, item) for item in self.scalars_attr)
  29. scalars = [getattr(unary_module, item) for item in self.scalars_attr]
  30. algorithm = ""
  31. if self.algorithm_attr:
  32. assert hasattr(unary_module, self.algorithm_attr)
  33. algorithm = getattr(unary_module, self.algorithm_attr)
  34. return self.op_name, scalars, algorithm
  35. def is_bfloat16_module(m):
  36. weight_is_bf16 = m.weight.dtype == torch.bfloat16
  37. bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16
  38. return weight_is_bf16 and bias_is_bf16
  39. def is_group_depthwise_conv_transpose(m):
  40. return (
  41. type(m) in [nn.ConvTranspose2d] and m.groups > 1 and m.groups == m.in_channels
  42. )
  43. def check_node_kind(current_node, modules, node_kind):
  44. if not isinstance(current_node, torch.fx.Node):
  45. return False
  46. if current_node.op != "call_module":
  47. return False
  48. if not isinstance(current_node.target, str):
  49. return False
  50. if current_node.target not in modules:
  51. return False
  52. if type(modules[current_node.target]) is not node_kind:
  53. return False
  54. return True
  55. def check_node_is_binary(node):
  56. return (
  57. (node.op == "call_function" and node.target in [torch.add, torch.sub])
  58. or (
  59. node.op == "call_function"
  60. and node.target
  61. in [operator.add, operator.iadd, operator.sub, operator.isub]
  62. )
  63. or (node.op == "call_method" and node.target in ["add", "add_", "sub", "sub_"])
  64. )
  65. def check_binary_op_kwargs_is_default(node):
  66. # For binary op, we hope the kwargs values are the default value:
  67. # torch.sub(add)(input, other, *, alpha=1, out=None).
  68. if len(node.args) > 2:
  69. return False
  70. if len(node.kwargs) > 0:
  71. if "out" in node.kwargs and node.kwargs["out"] is not None:
  72. return False
  73. if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0:
  74. return False
  75. return True
  76. class ConvUnary2d(nn.Conv2d):
  77. def __init__(
  78. self,
  79. conv: nn.Module,
  80. unary: Optional[nn.Module],
  81. input_size: list,
  82. ):
  83. super().__init__(
  84. conv.in_channels,
  85. conv.out_channels,
  86. conv.kernel_size,
  87. conv.stride,
  88. conv.padding,
  89. conv.dilation,
  90. conv.groups,
  91. conv.bias is not None,
  92. conv.padding_mode,
  93. conv.weight.device,
  94. conv.weight.dtype,
  95. )
  96. self._update_module_params(conv, unary, input_size)
  97. def _update_module_params(self, conv, unary, input_size):
  98. self.__dict__ = copy.deepcopy(conv.__dict__)
  99. self.attr = "none"
  100. self.scalars = []
  101. self.algorithm = ""
  102. if unary is not None:
  103. self.attr, self.scalars, self.algorithm = unary_modules_map[
  104. unary.__class__
  105. ](unary)
  106. self.weight = torch.nn.Parameter(
  107. torch._C._nn.mkldnn_reorder_conv2d_weight(
  108. self.weight.to_mkldnn(),
  109. self.padding,
  110. self.stride,
  111. self.dilation,
  112. self.groups,
  113. tuple(guard_int(x) for x in input_size),
  114. ),
  115. requires_grad=self.weight.requires_grad,
  116. )
  117. def _conv_forward(self, input, weight, bias):
  118. if self.padding_mode != "zeros":
  119. return torch.ops.mkldnn._convolution_pointwise(
  120. F.pad(
  121. input, self._reversed_padding_repeated_twice, mode=self.padding_mode
  122. ),
  123. weight,
  124. bias,
  125. _pair(0),
  126. self.stride,
  127. self.dilation,
  128. self.groups,
  129. self.attr,
  130. self.scalars,
  131. self.algorithm,
  132. )
  133. return torch.ops.mkldnn._convolution_pointwise(
  134. input,
  135. weight,
  136. bias,
  137. self.padding,
  138. self.stride,
  139. self.dilation,
  140. self.groups,
  141. self.attr,
  142. self.scalars,
  143. self.algorithm,
  144. )
  145. def forward(self, input):
  146. return self._conv_forward(input, self.weight, self.bias)
  147. class ConvBinary2d(nn.Conv2d):
  148. def __init__(
  149. self,
  150. conv: nn.Module,
  151. binary_op_name: str,
  152. input_size: list,
  153. ):
  154. super().__init__(
  155. conv.in_channels,
  156. conv.out_channels,
  157. conv.kernel_size,
  158. conv.stride,
  159. conv.padding,
  160. conv.dilation,
  161. conv.groups,
  162. conv.bias is not None,
  163. conv.padding_mode,
  164. conv.weight.device,
  165. conv.weight.dtype,
  166. )
  167. self._update_module_params(conv, binary_op_name, input_size)
  168. def _update_module_params(self, conv, binary_op_name, input_size):
  169. self.__dict__ = copy.deepcopy(conv.__dict__)
  170. self.binary_attr = binary_op_name
  171. self.binary_alpha = None
  172. self.unary_attr = None
  173. self.unary_scalars = []
  174. self.unary_algorithm = None
  175. self.weight = torch.nn.Parameter(
  176. torch._C._nn.mkldnn_reorder_conv2d_weight(
  177. self.weight.to_mkldnn(),
  178. self.padding,
  179. self.stride,
  180. self.dilation,
  181. self.groups,
  182. tuple(guard_int(x) for x in input_size),
  183. ),
  184. requires_grad=self.weight.requires_grad,
  185. )
  186. def _update_unary_params(self, unary):
  187. self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
  188. unary.__class__
  189. ](unary)
  190. def _conv_forward(self, input, other, weight, bias):
  191. if self.padding_mode != "zeros":
  192. return torch.ops.mkldnn._convolution_pointwise(
  193. F.pad(
  194. input, self._reversed_padding_repeated_twice, mode=self.padding_mode
  195. ),
  196. other,
  197. weight,
  198. bias,
  199. _pair(0),
  200. self.stride,
  201. self.dilation,
  202. self.groups,
  203. self.binary_attr,
  204. self.binary_alpha,
  205. self.unary_attr,
  206. self.unary_scalars,
  207. self.unary_algorithm,
  208. )
  209. return torch.ops.mkldnn._convolution_pointwise(
  210. input,
  211. other,
  212. weight,
  213. bias,
  214. self.padding,
  215. self.stride,
  216. self.dilation,
  217. self.groups,
  218. self.binary_attr,
  219. self.binary_alpha,
  220. self.unary_attr,
  221. self.unary_scalars,
  222. self.unary_algorithm,
  223. )
  224. def forward(self, input, other):
  225. return self._conv_forward(input, other, self.weight, self.bias)
  226. class PackedLinear(nn.Linear):
  227. def __init__(self, linear: nn.Module, input_size: list):
  228. super().__init__(
  229. linear.in_features,
  230. linear.out_features,
  231. linear.bias is not None,
  232. linear.weight.device,
  233. linear.weight.dtype,
  234. )
  235. self._update_module_params(linear, input_size)
  236. def _update_module_params(self, linear, input_size):
  237. self.__dict__ = copy.deepcopy(linear.__dict__)
  238. self.batch_size = reduce(lambda x, y: x * y, input_size[:-1])
  239. self.packed_weight = torch.nn.Parameter(
  240. torch.ops.mkl._mkl_reorder_linear_weight(
  241. self.weight.to_mkldnn(), self.batch_size
  242. ),
  243. requires_grad=self.weight.requires_grad,
  244. )
  245. def forward(self, input):
  246. y = torch.ops.mkl._mkl_linear(
  247. input, self.packed_weight, self.weight, self.bias, self.batch_size
  248. )
  249. return y
  250. class LinearUnary(nn.Linear):
  251. def __init__(
  252. self,
  253. linear: nn.Module,
  254. unary: nn.Module,
  255. ):
  256. super().__init__(
  257. linear.in_features,
  258. linear.out_features,
  259. linear.bias is not None,
  260. linear.weight.device,
  261. linear.weight.dtype,
  262. )
  263. self._update_module_params(linear, unary)
  264. def _update_module_params(self, linear, unary):
  265. self.__dict__ = copy.deepcopy(linear.__dict__)
  266. self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
  267. unary
  268. )
  269. def forward(self, input):
  270. y = torch.ops.mkldnn._linear_pointwise(
  271. input, self.weight, self.bias, self.attr, self.scalars, self.algorithm
  272. )
  273. return y
  274. class LinearBinary(nn.Linear):
  275. def __init__(self, linear: nn.Module, binary_op_name: str):
  276. super().__init__(
  277. linear.in_features,
  278. linear.out_features,
  279. linear.bias is not None,
  280. linear.weight.device,
  281. linear.weight.dtype,
  282. )
  283. self._update_module_params(linear, binary_op_name)
  284. def _update_module_params(self, linear, binary_op_name):
  285. self.__dict__ = copy.deepcopy(linear.__dict__)
  286. self.attr = binary_op_name
  287. def forward(self, input, other):
  288. y = torch.ops.mkldnn._linear_pointwise(
  289. input, other, self.weight, self.bias, self.attr
  290. )
  291. return y
  292. class ConvTransposeUnary2d(nn.ConvTranspose2d):
  293. def __init__(
  294. self,
  295. conv_transpose: nn.Module,
  296. unary: Optional[nn.Module],
  297. input_size: list,
  298. ):
  299. super().__init__(
  300. conv_transpose.in_channels,
  301. conv_transpose.out_channels,
  302. conv_transpose.kernel_size,
  303. conv_transpose.stride,
  304. conv_transpose.padding,
  305. conv_transpose.output_padding,
  306. conv_transpose.groups,
  307. conv_transpose.bias is not None,
  308. conv_transpose.dilation,
  309. conv_transpose.padding_mode,
  310. conv_transpose.weight.device,
  311. conv_transpose.weight.dtype,
  312. )
  313. self._update_module_params(conv_transpose, unary, input_size)
  314. def _update_module_params(self, conv_transpose, unary, input_size):
  315. self.__dict__ = copy.deepcopy(conv_transpose.__dict__)
  316. self.attr, self.scalars, self.algorithm = (
  317. unary_modules_map[unary.__class__](unary) if unary else ("none", [], "")
  318. )
  319. packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
  320. self.weight.to_mkldnn(),
  321. self.padding,
  322. self.output_padding,
  323. self.stride,
  324. self.dilation,
  325. self.groups,
  326. input_size,
  327. )
  328. self.weight = torch.nn.Parameter(
  329. packed_weight,
  330. requires_grad=self.weight.requires_grad,
  331. )
  332. def _conv_transpose_forward(self, input, weight, bias):
  333. if self.padding_mode != "zeros":
  334. return torch.ops.mkldnn._convolution_transpose_pointwise(
  335. F.pad(
  336. input, self._reversed_padding_repeated_twice, mode=self.padding_mode
  337. ),
  338. weight,
  339. bias,
  340. _pair(0),
  341. self.output_padding,
  342. self.stride,
  343. self.dilation,
  344. self.groups,
  345. self.attr,
  346. self.scalars,
  347. self.algorithm,
  348. )
  349. return torch.ops.mkldnn._convolution_transpose_pointwise(
  350. input,
  351. weight,
  352. bias,
  353. self.padding,
  354. self.output_padding,
  355. self.stride,
  356. self.dilation,
  357. self.groups,
  358. self.attr,
  359. self.scalars,
  360. self.algorithm,
  361. )
  362. def forward(self, input):
  363. return self._conv_transpose_forward(input, self.weight, self.bias)
  364. def packed_conv_eval(conv: nn.Module, input_size: list):
  365. assert not (conv.training), "Fusion only for eval!"
  366. return ConvUnary2d(
  367. conv,
  368. None,
  369. input_size,
  370. )
  371. def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: list):
  372. assert not (conv_transpose.training), "Fusion only for eval!"
  373. return ConvTransposeUnary2d(
  374. conv_transpose,
  375. None,
  376. input_size,
  377. )
  378. def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list):
  379. assert not (conv.training), "Fusion only for eval!"
  380. return ConvUnary2d(
  381. conv,
  382. unary,
  383. input_size,
  384. )
  385. def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list):
  386. assert not (conv.training), "Fusion only for eval!"
  387. return ConvBinary2d(
  388. conv,
  389. binary_op_name,
  390. input_size,
  391. )
  392. def fused_conv_binary_unary_eval(
  393. conv_binary: nn.Module, unary: nn.Module, input_size: list
  394. ):
  395. assert not (conv_binary.training), "Fusion only for eval!"
  396. # reuse origin conv module, and just update its' unary attr.
  397. conv_binary._update_unary_params(unary)
  398. return conv_binary
  399. def packed_linear_eval(linear: nn.Module, input_size: list):
  400. assert not (linear.training), "Fusion only for eval!"
  401. return PackedLinear(linear, input_size)
  402. def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module, input_size: list):
  403. assert not (linear.training), "Fusion only for eval!"
  404. return LinearUnary(
  405. linear,
  406. unary,
  407. )
  408. def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
  409. assert not (linear.training), "Fusion only for eval!"
  410. linear_binary = LinearBinary(
  411. linear,
  412. attr,
  413. )
  414. return linear_binary
  415. def fused_conv_transpose_unary_eval(
  416. conv_transpose: nn.Module, unary: nn.Module, input_size: list
  417. ):
  418. assert not (conv_transpose.training), "Fusion only for eval!"
  419. return ConvTransposeUnary2d(
  420. conv_transpose,
  421. unary,
  422. input_size,
  423. )
  424. def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
  425. is_cpu = all(
  426. example_input.device == torch.device("cpu")
  427. for example_input in example_inputs
  428. if isinstance(example_input, torch.Tensor)
  429. )
  430. # make sure the autograd is disabled.
  431. if torch.is_grad_enabled():
  432. return gm
  433. if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
  434. return gm
  435. if not is_cpu:
  436. return gm
  437. # For binary fusion, we need to check inputs info to make sure
  438. # the binary inputs have same tensor info(device, dtype, and layout).
  439. fake_mode = fake_mode_from_tensors(example_inputs)
  440. ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
  441. gm = fuse_unary(gm)
  442. gm = fuse_binary(gm)
  443. # why re-run fuse_unary? we want to enable conv+binary+unary fusion,
  444. # such as conv+add+relu for vision model.
  445. gm = fuse_unary(gm)
  446. if config.cpp.weight_prepack:
  447. gm = pack_module(gm)
  448. return gm
  449. def create_unary_module(node: torch.fx.node):
  450. assert (
  451. node.op == "call_function" or node.op == "call_method"
  452. ), "The current node should be a function/method node"
  453. unary_map = {
  454. F.relu: nn.ReLU,
  455. F.sigmoid: nn.Sigmoid,
  456. F.tanh: nn.Tanh,
  457. F.hardswish: nn.Hardswish,
  458. F.leaky_relu: nn.LeakyReLU,
  459. F.hardtanh: nn.Hardtanh,
  460. F.gelu: nn.GELU,
  461. F.relu6: nn.ReLU6,
  462. F.silu: nn.SiLU,
  463. F.hardsigmoid: nn.Hardsigmoid,
  464. torch.relu: nn.ReLU,
  465. torch.sigmoid: nn.Sigmoid,
  466. torch.tanh: nn.Tanh,
  467. "relu": nn.ReLU,
  468. "sigmoid": nn.Sigmoid,
  469. "tanh": nn.Tanh,
  470. }
  471. return unary_map[node.target](*(node.args[1:]), **(node.kwargs))
  472. def fuse_unary(gm: torch.fx.GraphModule):
  473. modules = dict(gm.named_modules())
  474. for unary_op, (
  475. computation_module,
  476. fuse_func,
  477. ) in itertools.product(unary_ops, computation_op_unary_op_fusion_map.items()):
  478. pattern = (computation_module, unary_op)
  479. for node in gm.graph.nodes:
  480. if matches_module_pattern(
  481. pattern, node, modules
  482. ) or matches_module_function_pattern(pattern, node, modules):
  483. if (
  484. len(node.args[0].users) > 1
  485. ): # Output of computation_node is used by other nodes
  486. continue
  487. computation_node = modules[node.args[0].target]
  488. if node.op == "call_function" or node.op == "call_method":
  489. # make sure unary function's inputs only one fx.node(others should be constant value).
  490. if any(isinstance(v, torch.fx.Node) for v in node.args[1:]) or any(
  491. isinstance(v, torch.fx.Node) for _, v in node.kwargs.items()
  492. ):
  493. continue
  494. unary_node = create_unary_module(node)
  495. unary_node.eval()
  496. else:
  497. unary_node = modules[node.target]
  498. eval_mode = all(not n.training for n in [computation_node, unary_node])
  499. if not eval_mode:
  500. continue
  501. # TODO: support padding str input("valid", "same").
  502. if type(computation_node) in [nn.Conv2d] and isinstance(
  503. computation_node.padding, str
  504. ):
  505. continue
  506. # TODO: support more conv+binary+unary fusion.
  507. if type(computation_node) in [ConvBinary2d] and type(
  508. unary_node
  509. ) not in [nn.ReLU]:
  510. continue
  511. # only fuse for linear when the dtype is bf16
  512. if type(computation_node) in [nn.Linear] and not is_bfloat16_module(
  513. computation_node
  514. ):
  515. continue
  516. # TODO: remove this when group depthwise ConvTranspose is supported
  517. if is_group_depthwise_conv_transpose(computation_node):
  518. continue
  519. computation_node_input_size = (
  520. node.args[0].args[0].meta.get("tensor_meta").shape
  521. )
  522. fused_module = fuse_func(
  523. computation_node, unary_node, computation_node_input_size
  524. )
  525. replace_node_module(node.args[0], modules, fused_module)
  526. node.replace_all_uses_with(node.args[0])
  527. gm.graph.erase_node(node)
  528. gm.graph.lint()
  529. gm.recompile()
  530. return gm
  531. def replace_and_fuse_for_binary(
  532. computation_node, node, fuse_func, attr, modules, index_node, index_pointwise
  533. ):
  534. computation_node_input_size = (
  535. node.args[index_node].args[0].meta.get("tensor_meta").shape
  536. )
  537. fused_module = fuse_func(computation_node, attr, computation_node_input_size)
  538. replace_node_module(node.args[index_node], modules, fused_module)
  539. node.args[index_node].args = node.args[index_node].args + (
  540. node.args[index_pointwise],
  541. )
  542. node.replace_all_uses_with(node.args[index_node])
  543. def binary_inputs_meta_is_same(binary_node):
  544. tensor0_meta = binary_node.args[0].meta.get("tensor_meta")
  545. tensor1_meta = binary_node.args[1].meta.get("tensor_meta")
  546. if not tensor0_meta or not tensor1_meta:
  547. return False
  548. if (
  549. tensor0_meta.shape != tensor1_meta.shape
  550. or tensor0_meta.stride != tensor1_meta.stride
  551. or tensor0_meta.dtype != tensor1_meta.dtype
  552. ):
  553. return False
  554. return True
  555. def fuse_binary(gm: torch.fx.GraphModule):
  556. modules = dict(gm.named_modules())
  557. for node in gm.graph.nodes:
  558. if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node):
  559. for node_kind, fuse_func in computation_op_binary_op_fusion_map.items():
  560. if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
  561. node.args[1], torch.fx.Node
  562. ):
  563. continue
  564. if not binary_inputs_meta_is_same(node):
  565. continue
  566. attr = binary_attr[node.target]
  567. index_list = supported_index_list[attr]
  568. for index_dict in index_list:
  569. index_node = index_dict["index_computation"]
  570. index_pointwise = index_dict["index_pointwise"]
  571. if check_node_kind(node.args[index_node], modules, node_kind):
  572. if len(node.args[index_node].users) > 1:
  573. continue
  574. computation_node = modules[node.args[index_node].target]
  575. if computation_node.training:
  576. continue
  577. # TODO: support padding str input("valid", "same").
  578. if type(computation_node) in [nn.Conv2d] and isinstance(
  579. computation_node.padding, str
  580. ):
  581. continue
  582. # only fuse for linear when the dtype is bf16
  583. if type(computation_node) in [
  584. nn.Linear
  585. ] and not is_bfloat16_module(computation_node):
  586. continue
  587. replace_and_fuse_for_binary(
  588. computation_node,
  589. node,
  590. fuse_func,
  591. attr if attr != "iadd" else "add",
  592. modules,
  593. index_node,
  594. index_pointwise,
  595. )
  596. # Make sure the fused node is post node of node's inputs nodes.
  597. node.append(node.args[index_node])
  598. gm.graph.erase_node(node)
  599. break
  600. gm.graph.lint()
  601. gm.recompile()
  602. return gm
  603. def convert_outplace_to_inplace(gm: torch.fx.GraphModule):
  604. if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
  605. return gm
  606. # This function is about replace outplace with inplace for better performance(external call),
  607. # which happen after AOTAutograd.
  608. for node in gm.graph.nodes:
  609. if node.op == "call_function" and node.target in [
  610. torch.ops.mkldnn._convolution_pointwise.binary
  611. ]:
  612. # args[0] and args[1] is _convolution_pointwise.binary's input,
  613. # need to check whether args[1] can be written or not.
  614. if node.args[1].op in ["placeholder", "output"]:
  615. continue
  616. # TODO: node.args[1].users > 1, but node.args[1] never be used after current node.
  617. if len(node.args[1].users) > 1:
  618. continue
  619. if node.args[1] == node.args[0]:
  620. continue
  621. binary_attr = node.args[8]
  622. unary_attr = node.args[10]
  623. if binary_attr != "add" or unary_attr not in ["", "relu"]:
  624. continue
  625. node.target = torch.ops.mkldnn._convolution_pointwise_.binary
  626. gm.graph.lint()
  627. gm.recompile()
  628. return gm
  629. def pack_module(gm: torch.fx.GraphModule):
  630. modules = dict(gm.named_modules())
  631. for node in gm.graph.nodes:
  632. if node.op == "call_module":
  633. assert isinstance(node.target, str)
  634. cur_module = modules[node.target]
  635. if type(cur_module) in computation_op_packed_map:
  636. if cur_module.training:
  637. continue
  638. computation_node_input_meta = node.args[0].meta.get("tensor_meta")
  639. if computation_node_input_meta.dtype != torch.float32:
  640. continue
  641. if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl:
  642. continue
  643. computation_node_input_size = computation_node_input_meta.shape
  644. if (
  645. type(cur_module) in [torch.nn.Linear]
  646. and len(computation_node_input_size) < 2
  647. ):
  648. continue
  649. if type(cur_module) in [nn.Conv2d] and isinstance(
  650. cur_module.padding, str
  651. ):
  652. continue
  653. # TODO: remove this when group depthwise ConvTranspose is supported
  654. if is_group_depthwise_conv_transpose(cur_module):
  655. continue
  656. new_module = computation_op_packed_map[type(cur_module)](
  657. cur_module, computation_node_input_size
  658. )
  659. assert isinstance(new_module, nn.Module)
  660. replace_node_module(node, modules, new_module)
  661. gm.graph.lint()
  662. gm.recompile()
  663. return gm
  664. computation_op_unary_op_fusion_map = {
  665. nn.Conv2d: fused_conv_unary_eval,
  666. nn.Linear: fused_linear_unary_eval,
  667. ConvBinary2d: fused_conv_binary_unary_eval,
  668. nn.ConvTranspose2d: fused_conv_transpose_unary_eval,
  669. }
  670. unary_modules_map = {
  671. nn.ReLU: UnaryAttr("relu"),
  672. nn.Sigmoid: UnaryAttr("sigmoid"),
  673. nn.Tanh: UnaryAttr("tanh"),
  674. nn.Hardswish: UnaryAttr("hardswish"),
  675. nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]),
  676. nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
  677. nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"),
  678. nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
  679. nn.SiLU: UnaryAttr("swish"),
  680. nn.Hardsigmoid: UnaryAttr("hardsigmoid"),
  681. }
  682. unary_ops = [
  683. # modules
  684. nn.ReLU,
  685. nn.Sigmoid,
  686. nn.Tanh,
  687. nn.Hardswish,
  688. nn.LeakyReLU,
  689. nn.Hardtanh,
  690. nn.GELU,
  691. nn.ReLU6,
  692. nn.SiLU,
  693. nn.Hardsigmoid,
  694. # functional
  695. F.relu,
  696. F.sigmoid,
  697. F.tanh,
  698. F.hardswish,
  699. F.leaky_relu,
  700. F.hardtanh,
  701. F.gelu,
  702. F.relu6,
  703. F.silu,
  704. F.hardsigmoid,
  705. torch.relu,
  706. torch.sigmoid,
  707. torch.tanh,
  708. # methods (torch.Tensor.xxx)
  709. "relu",
  710. "sigmoid",
  711. "tanh",
  712. ]
  713. binary_attr = {
  714. torch.add: "add", # node.op == "call_function"
  715. "add": "add", # node.op == "call_method"
  716. "add_": "iadd", # node.op == "call_method"
  717. operator.add: "add", # node.op == "call_function"
  718. operator.iadd: "iadd", # node.op == "call_function"
  719. torch.sub: "sub", # node.op == "call_function"
  720. "sub": "sub", # node.op == "call_method"
  721. "sub_": "sub", # node.op == "call_method"
  722. operator.sub: "sub", # node.op == "call_function"
  723. operator.isub: "sub", # node.op == "call_function"
  724. }
  725. computation_op_binary_op_fusion_map = {
  726. nn.Conv2d: fused_conv_binary_eval,
  727. nn.Linear: fused_linear_binary_eval,
  728. }
  729. computation_op_packed_map = {
  730. nn.Linear: packed_linear_eval,
  731. nn.Conv2d: packed_conv_eval,
  732. nn.ConvTranspose2d: packed_conv_transpose_eval,
  733. }
  734. # For add: we support conv/linear + other and other + conv
  735. # For sub/add_/sub_, we only support conv/linear - other
  736. # or conv/linear +(-)= other
  737. supported_index_list = {
  738. "add": [
  739. {"index_computation": 0, "index_pointwise": 1},
  740. {"index_computation": 1, "index_pointwise": 0},
  741. ],
  742. "iadd": [{"index_computation": 0, "index_pointwise": 1}],
  743. "sub": [{"index_computation": 0, "index_pointwise": 1}],
  744. }