import torch


class MkldnnLinear(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
        if dense_module.bias is not None:
            # Bias can be fp32 or bf16 for the OneDNN bf16 path, but for good
            # accuracy we keep it in fp32.
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y
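

# Illustrative usage sketch (not part of the original module): convert a single
# nn.Linear with MkldnnLinear and run it on a dense input. The layer sizes and
# tensor shapes below are assumptions chosen only for the example; it requires
# a PyTorch build with MKL-DNN/oneDNN support (torch.backends.mkldnn.is_available()).
def _example_mkldnn_linear():
    dense = torch.nn.Linear(16, 4).eval()
    optimized = MkldnnLinear(dense, torch.float)
    x = torch.randn(2, 16)
    # forward() converts the dense input to the MKL-DNN layout internally and
    # returns a dense tensor because the input was dense.
    y = optimized(x)
    # The result should agree with the dense module up to float rounding.
    print((y - dense(x)).abs().max())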


class _MkldnnConvNd(torch.jit.ScriptModule):
    """Common base of MkldnnConv1d, MkldnnConv2d and MkldnnConv3d."""

    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super().__init__()
        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # Bias can be fp32 or bf16 for the OneDNN bf16 path, but for good
            # accuracy we keep it in fp32.
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def forward(self, x):
        return torch.mkldnn_convolution(
            x,
            self.weight,
            self.bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups)


class MkldnnConv1d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv2d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv3d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]
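

# Illustrative sketch (not part of the original module): wrap an nn.Conv2d in
# MkldnnConv2d. Unlike MkldnnLinear, the convolution forward expects the input
# to already be in the MKL-DNN layout, so the caller converts with
# Tensor.to_mkldnn() and converts the result back with Tensor.to_dense(). The
# shapes below are assumptions chosen only for the example.
def _example_mkldnn_conv2d():
    dense = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
    optimized = MkldnnConv2d(dense, torch.float)  # weight is reordered once, up front
    x = torch.randn(1, 3, 32, 32)
    y = optimized(x.to_mkldnn()).to_dense()
    # The result should agree with the dense module up to float rounding.
    print((y - dense(x)).abs().max())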


class MkldnnBatchNorm(torch.jit.ScriptModule):
    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super().__init__()

        assert not dense_module.training
        assert dense_module.track_running_stats
        assert dense_module.affine

        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )
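

# Illustrative sketch (not part of the original module): MkldnnBatchNorm only
# supports frozen inference, so the source module must be in eval() mode with
# affine parameters and tracked running statistics (see the asserts above).
# The shapes below are assumptions chosen only for the example.
def _example_mkldnn_batch_norm():
    dense = torch.nn.BatchNorm2d(8).eval()
    optimized = MkldnnBatchNorm(dense)
    x = torch.randn(1, 8, 16, 16)
    y = optimized(x.to_mkldnn()).to_dense()
    print((y - dense(x)).abs().max())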


class MkldnnPrelu(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch.prelu(x_mkldnn, self.weight)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


def to_mkldnn(module, dtype=torch.float):
    assert dtype in [torch.float, torch.bfloat16], \
        "MKLDNN only supports the float and bfloat16 paths for now"

    def m_fn(m, d):
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m, d)
        elif isinstance(m, torch.nn.Conv1d):
            return MkldnnConv1d(m, d)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m, d)
        elif isinstance(m, torch.nn.Conv3d):
            return MkldnnConv3d(m, d)
        elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            # For the batchnorm bf16 path, OneDNN requires weight and bias in
            # fp32, so no dtype argument is needed here.
            return MkldnnBatchNorm(m)
        elif isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        else:
            return m

    def m_fn_rec(m, d):
        new_m = m_fn(m, d)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m, d))
        return new_m

    return m_fn_rec(module, dtype)
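

# Illustrative end-to-end sketch (not part of the original module): convert a
# small eval-mode model with to_mkldnn() and run inference. The toy model and
# input shape are assumptions chosen only for the example.
if __name__ == "__main__":
    if torch.backends.mkldnn.is_available():
        model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(8),
            torch.nn.PReLU(),
        ).eval()
        optimized = to_mkldnn(model)
        x = torch.randn(1, 3, 64, 64)
        # The input is moved to the MKL-DNN layout once at the boundary and the
        # output is converted back to a regular (strided) tensor.
        y = optimized(x.to_mkldnn()).to_dense()
        print(y.shape)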