mkldnn.py

import torch


class MkldnnLinear(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
        if dense_module.bias is not None:
            # Bias can be fp32 or bf16 for OneDNN bf16 path, but for good accuracy,
            # we use fp32 dtype.
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y

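# A minimal usage sketch (added for illustration, not part of the original
# file): wrap a dense nn.Linear in MkldnnLinear and run inference. forward()
# converts a dense input to MKL-DNN layout and back transparently; the layer
# sizes here are arbitrary assumptions.
def _example_mkldnn_linear():
    dense = torch.nn.Linear(128, 64).eval()
    mkldnn_fc = MkldnnLinear(dense, torch.float)
    with torch.no_grad():
        return mkldnn_fc(torch.randn(32, 128))  # dense in, dense out
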
class _MkldnnConvNd(torch.jit.ScriptModule):
    """Common base of MkldnnConv1d, MkldnnConv2d and MkldnnConv3d."""

    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super().__init__()

        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # Bias can be fp32 or bf16 for OneDNN bf16 path, but for good accuracy,
            # we use fp32 dtype.
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def forward(self, x):
        return torch.mkldnn_convolution(
            x,
            self.weight,
            self.bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups)

class MkldnnConv1d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

class MkldnnConv2d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

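# A usage sketch (added for illustration, not part of the original file):
# unlike MkldnnLinear, the conv modules do not convert their input, so the
# caller passes an MKL-DNN tensor and converts the result back to dense.
# Shapes are arbitrary assumptions.
def _example_mkldnn_conv2d():
    dense = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
    mkldnn_conv = MkldnnConv2d(dense, torch.float)
    with torch.no_grad():
        return mkldnn_conv(torch.randn(1, 3, 32, 32).to_mkldnn()).to_dense()
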
class MkldnnConv3d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

class MkldnnBatchNorm(torch.jit.ScriptModule):
    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super().__init__()

        assert not dense_module.training
        assert dense_module.track_running_stats
        assert dense_module.affine

        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )

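# A usage sketch (added for illustration, not part of the original file):
# MkldnnBatchNorm is inference-only (note the asserts in __init__), so the
# dense module must be in eval() mode with tracked running stats before
# conversion. Shapes are arbitrary assumptions.
def _example_mkldnn_batch_norm():
    dense = torch.nn.BatchNorm2d(8).eval()
    mkldnn_bn = MkldnnBatchNorm(dense)
    with torch.no_grad():
        return mkldnn_bn(torch.randn(1, 8, 16, 16).to_mkldnn()).to_dense()
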
class MkldnnPrelu(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch.prelu(x_mkldnn, self.weight)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y

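# A serialization sketch (added for illustration, not part of the original
# file): the __getstate__/__setstate__ pairs above exist so these scripted
# modules can be pickled; buffers are converted to dense for saving and back
# to MKL-DNN layout on load. The file name is an arbitrary assumption.
def _example_save_load(path='mkldnn_linear.pt'):
    m = MkldnnLinear(torch.nn.Linear(4, 4).eval(), torch.float)
    torch.jit.save(m, path)
    return torch.jit.load(path)
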
def to_mkldnn(module, dtype=torch.float):
    assert dtype in [torch.float, torch.bfloat16], \
        "MKLDNN only supports float and bfloat16 paths for now"

    def m_fn(m, d):
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m, d)
        elif isinstance(m, torch.nn.Conv1d):
            return MkldnnConv1d(m, d)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m, d)
        elif isinstance(m, torch.nn.Conv3d):
            return MkldnnConv3d(m, d)
        elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            # For the batchnorm bf16 path, OneDNN requires weight and bias to be
            # fp32, so MkldnnBatchNorm does not take a dtype argument.
            return MkldnnBatchNorm(m)
        elif isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        else:
            return m

    def m_fn_rec(m, d):
        new_m = m_fn(m, d)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m, d))
        return new_m

    return m_fn_rec(module, dtype)

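# End-to-end sketch (added for illustration, not part of the original file):
# convert an eval-mode model wholesale with to_mkldnn() and run inference
# under no_grad(), since MKL-DNN tensors do not support autograd. The model
# and input shape are arbitrary assumptions.
if __name__ == '__main__':
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.PReLU(),
    ).eval()
    mkldnn_model = to_mkldnn(model)
    with torch.no_grad():
        out = mkldnn_model(torch.randn(1, 3, 32, 32).to_mkldnn()).to_dense()
    print(out.shape)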