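# Graph-rewrite passes for CPU inference with mkldnn (oneDNN): fuse Conv2d /
# Linear / ConvTranspose2d modules with pointwise unary or binary ops, prepack
# weights, and convert fused out-of-place convolutions to in-place variants.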
import copy
import itertools
import operator
from functools import reduce
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._dynamo.utils import fake_mode_from_tensors
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from torch.fx.experimental.symbolic_shapes import guard_int
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn.modules.utils import _pair

from . import config
from .fx_utils import matches_module_function_pattern
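# Describes a unary post-op as the (op_name, scalars, algorithm) triple consumed
# by the fused mkldnn kernels; ReLU6 is lowered to an equivalent Hardtanh(0, 6).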
class UnaryAttr:
    def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
        self.op_name = op_name
        self.scalars_attr = scalars_attr if scalars_attr else []
        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
        super().__init__()

    def __call__(self, unary_module: nn.Module):
        if type(unary_module) is nn.ReLU6:
            unary_module = nn.Hardtanh(min_val=0, max_val=6)
        assert all(hasattr(unary_module, item) for item in self.scalars_attr)
        scalars = [getattr(unary_module, item) for item in self.scalars_attr]

        algorithm = ""
        if self.algorithm_attr:
            assert hasattr(unary_module, self.algorithm_attr)
            algorithm = getattr(unary_module, self.algorithm_attr)

        return self.op_name, scalars, algorithm
def is_bfloat16_module(m):
    weight_is_bf16 = m.weight.dtype == torch.bfloat16
    bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16
    return weight_is_bf16 and bias_is_bf16


def is_group_depthwise_conv_transpose(m):
    return (
        type(m) in [nn.ConvTranspose2d] and m.groups > 1 and m.groups == m.in_channels
    )
def check_node_kind(current_node, modules, node_kind):
    if not isinstance(current_node, torch.fx.Node):
        return False
    if current_node.op != "call_module":
        return False
    if not isinstance(current_node.target, str):
        return False
    if current_node.target not in modules:
        return False
    if type(modules[current_node.target]) is not node_kind:
        return False
    return True
def check_node_is_binary(node):
    return (
        (node.op == "call_function" and node.target in [torch.add, torch.sub])
        or (
            node.op == "call_function"
            and node.target
            in [operator.add, operator.iadd, operator.sub, operator.isub]
        )
        or (node.op == "call_method" and node.target in ["add", "add_", "sub", "sub_"])
    )
def check_binary_op_kwargs_is_default(node):
    # For a binary op, we expect the kwargs to hold their default values:
    # torch.sub(add)(input, other, *, alpha=1, out=None).
    if len(node.args) > 2:
        return False
    if len(node.kwargs) > 0:
        if "out" in node.kwargs and node.kwargs["out"] is not None:
            return False
        if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0:
            return False
    return True
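# Conv2d with a fused unary post-op: the weight is prepacked into mkldnn blocked
# format and forward dispatches to torch.ops.mkldnn._convolution_pointwise.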
class ConvUnary2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        unary: Optional[nn.Module],
        input_size: list,
    ):
        super().__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, unary, input_size)

    def _update_module_params(self, conv, unary, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.attr = "none"
        self.scalars = []
        self.algorithm = ""
        if unary is not None:
            self.attr, self.scalars, self.algorithm = unary_modules_map[
                unary.__class__
            ](unary)
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                tuple(guard_int(x) for x in input_size),
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _conv_forward(self, input, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.attr,
                self.scalars,
                self.algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.attr,
            self.scalars,
            self.algorithm,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight, self.bias)
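# Conv2d fused with a binary op (add/sub) on a second input `other`, optionally
# followed by a unary op set later via _update_unary_params.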
class ConvBinary2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        binary_op_name: str,
        input_size: list,
    ):
        super().__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, binary_op_name, input_size)

    def _update_module_params(self, conv, binary_op_name, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.binary_attr = binary_op_name
        self.binary_alpha = None
        self.unary_attr = None
        self.unary_scalars = []
        self.unary_algorithm = None
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                tuple(guard_int(x) for x in input_size),
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _update_unary_params(self, unary):
        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
            unary.__class__
        ](unary)

    def _conv_forward(self, input, other, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                other,
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.binary_attr,
                self.binary_alpha,
                self.unary_attr,
                self.unary_scalars,
                self.unary_algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            other,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.binary_attr,
            self.binary_alpha,
            self.unary_attr,
            self.unary_scalars,
            self.unary_algorithm,
        )

    def forward(self, input, other):
        return self._conv_forward(input, other, self.weight, self.bias)
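# Linear with an MKL-prepacked weight; forward calls torch.ops.mkl._mkl_linear
# using the batch size derived from the observed input shape.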
class PackedLinear(nn.Linear):
    def __init__(self, linear: nn.Module, input_size: list):
        super().__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, input_size)

    def _update_module_params(self, linear, input_size):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.batch_size = reduce(lambda x, y: x * y, input_size[:-1])
        self.packed_weight = torch.nn.Parameter(
            torch.ops.mkl._mkl_reorder_linear_weight(
                self.weight.to_mkldnn(), self.batch_size
            ),
            requires_grad=self.weight.requires_grad,
        )

    def forward(self, input):
        y = torch.ops.mkl._mkl_linear(
            input, self.packed_weight, self.weight, self.bias, self.batch_size
        )
        return y
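# Linear with a fused unary post-op, lowered to torch.ops.mkldnn._linear_pointwise.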
class LinearUnary(nn.Linear):
    def __init__(
        self,
        linear: nn.Module,
        unary: nn.Module,
    ):
        super().__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, unary)

    def _update_module_params(self, linear, unary):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
            unary
        )

    def forward(self, input):
        y = torch.ops.mkldnn._linear_pointwise(
            input, self.weight, self.bias, self.attr, self.scalars, self.algorithm
        )
        return y
class LinearBinary(nn.Linear):
    def __init__(self, linear: nn.Module, binary_op_name: str):
        super().__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, binary_op_name)

    def _update_module_params(self, linear, binary_op_name):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.attr = binary_op_name

    def forward(self, input, other):
        y = torch.ops.mkldnn._linear_pointwise(
            input, other, self.weight, self.bias, self.attr
        )
        return y
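# ConvTranspose2d with a fused unary post-op; the weight is prepacked via
# _reorder_convolution_transpose_weight and forward dispatches to
# torch.ops.mkldnn._convolution_transpose_pointwise.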
class ConvTransposeUnary2d(nn.ConvTranspose2d):
    def __init__(
        self,
        conv_transpose: nn.Module,
        unary: Optional[nn.Module],
        input_size: list,
    ):
        super().__init__(
            conv_transpose.in_channels,
            conv_transpose.out_channels,
            conv_transpose.kernel_size,
            conv_transpose.stride,
            conv_transpose.padding,
            conv_transpose.output_padding,
            conv_transpose.groups,
            conv_transpose.bias is not None,
            conv_transpose.dilation,
            conv_transpose.padding_mode,
            conv_transpose.weight.device,
            conv_transpose.weight.dtype,
        )
        self._update_module_params(conv_transpose, unary, input_size)

    def _update_module_params(self, conv_transpose, unary, input_size):
        self.__dict__ = copy.deepcopy(conv_transpose.__dict__)
        self.attr, self.scalars, self.algorithm = (
            unary_modules_map[unary.__class__](unary) if unary else ("none", [], "")
        )
        packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
            self.weight.to_mkldnn(),
            self.padding,
            self.output_padding,
            self.stride,
            self.dilation,
            self.groups,
            input_size,
        )
        self.weight = torch.nn.Parameter(
            packed_weight,
            requires_grad=self.weight.requires_grad,
        )

    def _conv_transpose_forward(self, input, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_transpose_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                _pair(0),
                self.output_padding,
                self.stride,
                self.dilation,
                self.groups,
                self.attr,
                self.scalars,
                self.algorithm,
            )
        return torch.ops.mkldnn._convolution_transpose_pointwise(
            input,
            weight,
            bias,
            self.padding,
            self.output_padding,
            self.stride,
            self.dilation,
            self.groups,
            self.attr,
            self.scalars,
            self.algorithm,
        )

    def forward(self, input):
        return self._conv_transpose_forward(input, self.weight, self.bias)
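# Module-rewriting helpers: each takes an eval-mode module (fusion is inference
# only) and returns the corresponding packed/fused replacement module.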
def packed_conv_eval(conv: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        None,
        input_size,
    )


def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: list):
    assert not (conv_transpose.training), "Fusion only for eval!"
    return ConvTransposeUnary2d(
        conv_transpose,
        None,
        input_size,
    )


def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        unary,
        input_size,
    )


def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvBinary2d(
        conv,
        binary_op_name,
        input_size,
    )
def fused_conv_binary_unary_eval(
    conv_binary: nn.Module, unary: nn.Module, input_size: list
):
    assert not (conv_binary.training), "Fusion only for eval!"
    # Reuse the original conv module and just update its unary attributes.
    conv_binary._update_unary_params(unary)
    return conv_binary
def packed_linear_eval(linear: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return PackedLinear(linear, input_size)


def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return LinearUnary(
        linear,
        unary,
    )


def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    linear_binary = LinearBinary(
        linear,
        attr,
    )
    return linear_binary


def fused_conv_transpose_unary_eval(
    conv_transpose: nn.Module, unary: nn.Module, input_size: list
):
    assert not (conv_transpose.training), "Fusion only for eval!"
    return ConvTransposeUnary2d(
        conv_transpose,
        unary,
        input_size,
    )
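# Entry point: on CPU, in inference mode, propagate shapes, run unary fusion,
# binary fusion, then unary fusion again (to catch conv+binary+unary chains),
# and optionally prepack weights.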
def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
    is_cpu = all(
        example_input.device == torch.device("cpu")
        for example_input in example_inputs
        if isinstance(example_input, torch.Tensor)
    )

    # make sure autograd is disabled.
    if torch.is_grad_enabled():
        return gm
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm
    if not is_cpu:
        return gm
    # For binary fusion, we need to check the inputs' info to make sure
    # the binary inputs have the same tensor info (device, dtype, and layout).
    fake_mode = fake_mode_from_tensors(example_inputs)
    ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
    gm = fuse_unary(gm)
    gm = fuse_binary(gm)
    # Why re-run fuse_unary? We want to enable conv+binary+unary fusion,
    # such as conv+add+relu for vision models.
    gm = fuse_unary(gm)
    if config.cpp.weight_prepack:
        gm = pack_module(gm)
    return gm
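# Map a functional/method unary call (e.g. F.relu, torch.tanh, Tensor.sigmoid)
# back to an equivalent nn.Module so the fusion code can reuse unary_modules_map.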
def create_unary_module(node: torch.fx.Node):
    assert (
        node.op == "call_function" or node.op == "call_method"
    ), "The current node should be a function/method node"
    unary_map = {
        F.relu: nn.ReLU,
        F.sigmoid: nn.Sigmoid,
        F.tanh: nn.Tanh,
        F.hardswish: nn.Hardswish,
        F.leaky_relu: nn.LeakyReLU,
        F.hardtanh: nn.Hardtanh,
        F.gelu: nn.GELU,
        F.relu6: nn.ReLU6,
        F.silu: nn.SiLU,
        F.hardsigmoid: nn.Hardsigmoid,
        torch.relu: nn.ReLU,
        torch.sigmoid: nn.Sigmoid,
        torch.tanh: nn.Tanh,
        "relu": nn.ReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }
    return unary_map[node.target](*(node.args[1:]), **(node.kwargs))
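# Fuse (computation module, unary op) patterns such as Conv2d+ReLU or
# Linear+GELU into the fused modules defined above.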
def fuse_unary(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())

    for unary_op, (
        computation_module,
        fuse_func,
    ) in itertools.product(unary_ops, computation_op_unary_op_fusion_map.items()):
        pattern = (computation_module, unary_op)
        for node in gm.graph.nodes:
            if matches_module_pattern(
                pattern, node, modules
            ) or matches_module_function_pattern(pattern, node, modules):
                if (
                    len(node.args[0].users) > 1
                ):  # Output of computation_node is used by other nodes
                    continue
                computation_node = modules[node.args[0].target]
                if node.op == "call_function" or node.op == "call_method":
                    # Make sure the unary function has only one fx.Node input
                    # (the other inputs should be constant values).
                    if any(isinstance(v, torch.fx.Node) for v in node.args[1:]) or any(
                        isinstance(v, torch.fx.Node) for _, v in node.kwargs.items()
                    ):
                        continue
                    unary_node = create_unary_module(node)
                    unary_node.eval()
                else:
                    unary_node = modules[node.target]
                eval_mode = all(not n.training for n in [computation_node, unary_node])
                if not eval_mode:
                    continue
                # TODO: support padding str input("valid", "same").
                if type(computation_node) in [nn.Conv2d] and isinstance(
                    computation_node.padding, str
                ):
                    continue
                # TODO: support more conv+binary+unary fusion.
                if type(computation_node) in [ConvBinary2d] and type(
                    unary_node
                ) not in [nn.ReLU]:
                    continue
                # only fuse for linear when the dtype is bf16
                if type(computation_node) in [nn.Linear] and not is_bfloat16_module(
                    computation_node
                ):
                    continue
                # TODO: remove this when group depthwise ConvTranspose is supported
                if is_group_depthwise_conv_transpose(computation_node):
                    continue
                computation_node_input_size = (
                    node.args[0].args[0].meta.get("tensor_meta").shape
                )
                fused_module = fuse_func(
                    computation_node, unary_node, computation_node_input_size
                )
                replace_node_module(node.args[0], modules, fused_module)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
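# Binary-fusion helpers: swap the computation module for its fused counterpart
# and append the pointwise operand to the computation node's args.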
def replace_and_fuse_for_binary(
    computation_node, node, fuse_func, attr, modules, index_node, index_pointwise
):
    computation_node_input_size = (
        node.args[index_node].args[0].meta.get("tensor_meta").shape
    )
    fused_module = fuse_func(computation_node, attr, computation_node_input_size)
    replace_node_module(node.args[index_node], modules, fused_module)
    node.args[index_node].args = node.args[index_node].args + (
        node.args[index_pointwise],
    )
    node.replace_all_uses_with(node.args[index_node])
def binary_inputs_meta_is_same(binary_node):
    tensor0_meta = binary_node.args[0].meta.get("tensor_meta")
    tensor1_meta = binary_node.args[1].meta.get("tensor_meta")
    if not tensor0_meta or not tensor1_meta:
        return False
    if (
        tensor0_meta.shape != tensor1_meta.shape
        or tensor0_meta.stride != tensor1_meta.stride
        or tensor0_meta.dtype != tensor1_meta.dtype
    ):
        return False
    return True
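# Fuse add/sub between a computation module's output and another tensor with
# identical shape/stride/dtype into ConvBinary2d / LinearBinary.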
def fuse_binary(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node):
            for node_kind, fuse_func in computation_op_binary_op_fusion_map.items():
                if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
                    node.args[1], torch.fx.Node
                ):
                    continue
                if not binary_inputs_meta_is_same(node):
                    continue
                attr = binary_attr[node.target]
                index_list = supported_index_list[attr]
                for index_dict in index_list:
                    index_node = index_dict["index_computation"]
                    index_pointwise = index_dict["index_pointwise"]
                    if check_node_kind(node.args[index_node], modules, node_kind):
                        if len(node.args[index_node].users) > 1:
                            continue
                        computation_node = modules[node.args[index_node].target]
                        if computation_node.training:
                            continue
                        # TODO: support padding str input("valid", "same").
                        if type(computation_node) in [nn.Conv2d] and isinstance(
                            computation_node.padding, str
                        ):
                            continue
                        # only fuse for linear when the dtype is bf16
                        if type(computation_node) in [
                            nn.Linear
                        ] and not is_bfloat16_module(computation_node):
                            continue
                        replace_and_fuse_for_binary(
                            computation_node,
                            node,
                            fuse_func,
                            attr if attr != "iadd" else "add",
                            modules,
                            index_node,
                            index_pointwise,
                        )
                        # Make sure the fused node comes after its input nodes.
                        node.append(node.args[index_node])
                        gm.graph.erase_node(node)
                        break
    gm.graph.lint()
    gm.recompile()
    return gm
def convert_outplace_to_inplace(gm: torch.fx.GraphModule):
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm
    # This pass replaces out-of-place fused convolutions with their in-place
    # variants for better performance (external call); it runs after AOTAutograd.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in [
            torch.ops.mkldnn._convolution_pointwise.binary
        ]:
            # args[0] and args[1] are _convolution_pointwise.binary's inputs;
            # we need to check whether args[1] can be written to.
            if node.args[1].op in ["placeholder", "output"]:
                continue
            # TODO: handle the case where node.args[1] has more than one user
            # but is never used after the current node.
            if len(node.args[1].users) > 1:
                continue
            if node.args[1] == node.args[0]:
                continue
            binary_attr = node.args[8]
            unary_attr = node.args[10]
            if binary_attr != "add" or unary_attr not in ["", "relu"]:
                continue
            node.target = torch.ops.mkldnn._convolution_pointwise_.binary
    gm.graph.lint()
    gm.recompile()
    return gm
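# Prepack weights for the remaining plain fp32 eval-mode Conv2d / Linear /
# ConvTranspose2d modules (exact type match, so already-fused modules are skipped).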
def pack_module(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if node.op == "call_module":
            assert isinstance(node.target, str)
            cur_module = modules[node.target]
            if type(cur_module) in computation_op_packed_map:
                if cur_module.training:
                    continue
                computation_node_input_meta = node.args[0].meta.get("tensor_meta")
                if computation_node_input_meta.dtype != torch.float32:
                    continue
                if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl:
                    continue
                computation_node_input_size = computation_node_input_meta.shape
                if (
                    type(cur_module) in [torch.nn.Linear]
                    and len(computation_node_input_size) < 2
                ):
                    continue
                if type(cur_module) in [nn.Conv2d] and isinstance(
                    cur_module.padding, str
                ):
                    continue
                # TODO: remove this when group depthwise ConvTranspose is supported
                if is_group_depthwise_conv_transpose(cur_module):
                    continue
                new_module = computation_op_packed_map[type(cur_module)](
                    cur_module, computation_node_input_size
                )
                assert isinstance(new_module, nn.Module)
                replace_node_module(node, modules, new_module)
    gm.graph.lint()
    gm.recompile()
    return gm
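# Lookup tables driving the passes above: fusion/packing constructors per module
# type, the unary-module -> mkldnn attr mapping, supported unary/binary ops, and
# the operand positions allowed for each binary attr.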
computation_op_unary_op_fusion_map = {
    nn.Conv2d: fused_conv_unary_eval,
    nn.Linear: fused_linear_unary_eval,
    ConvBinary2d: fused_conv_binary_unary_eval,
    nn.ConvTranspose2d: fused_conv_transpose_unary_eval,
}


unary_modules_map = {
    nn.ReLU: UnaryAttr("relu"),
    nn.Sigmoid: UnaryAttr("sigmoid"),
    nn.Tanh: UnaryAttr("tanh"),
    nn.Hardswish: UnaryAttr("hardswish"),
    nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]),
    nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"),
    nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.SiLU: UnaryAttr("swish"),
    nn.Hardsigmoid: UnaryAttr("hardsigmoid"),
}

unary_ops = [
    # modules
    nn.ReLU,
    nn.Sigmoid,
    nn.Tanh,
    nn.Hardswish,
    nn.LeakyReLU,
    nn.Hardtanh,
    nn.GELU,
    nn.ReLU6,
    nn.SiLU,
    nn.Hardsigmoid,
    # functional
    F.relu,
    F.sigmoid,
    F.tanh,
    F.hardswish,
    F.leaky_relu,
    F.hardtanh,
    F.gelu,
    F.relu6,
    F.silu,
    F.hardsigmoid,
    torch.relu,
    torch.sigmoid,
    torch.tanh,
    # methods (torch.Tensor.xxx)
    "relu",
    "sigmoid",
    "tanh",
]

binary_attr = {
    torch.add: "add",  # node.op == "call_function"
    "add": "add",  # node.op == "call_method"
    "add_": "iadd",  # node.op == "call_method"
    operator.add: "add",  # node.op == "call_function"
    operator.iadd: "iadd",  # node.op == "call_function"
    torch.sub: "sub",  # node.op == "call_function"
    "sub": "sub",  # node.op == "call_method"
    "sub_": "sub",  # node.op == "call_method"
    operator.sub: "sub",  # node.op == "call_function"
    operator.isub: "sub",  # node.op == "call_function"
}

computation_op_binary_op_fusion_map = {
    nn.Conv2d: fused_conv_binary_eval,
    nn.Linear: fused_linear_binary_eval,
}

computation_op_packed_map = {
    nn.Linear: packed_linear_eval,
    nn.Conv2d: packed_conv_eval,
    nn.ConvTranspose2d: packed_conv_transpose_eval,
}
# For add we support both conv/linear + other and other + conv/linear.
# For sub/add_/sub_ we only support conv/linear - other,
# i.e. conv/linear +=(-=) other.
supported_index_list = {
    "add": [
        {"index_computation": 0, "index_pointwise": 1},
        {"index_computation": 1, "index_pointwise": 0},
    ],
    "iadd": [{"index_computation": 0, "index_pointwise": 1}],
    "sub": [{"index_computation": 0, "index_pointwise": 1}],
}