fusion.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import copy
  2. import torch
  3. def fuse_conv_bn_eval(conv, bn, transpose=False):
  4. assert(not (conv.training or bn.training)), "Fusion only for eval!"
  5. fused_conv = copy.deepcopy(conv)
  6. fused_conv.weight, fused_conv.bias = \
  7. fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
  8. bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)
  9. return fused_conv
  10. def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
  11. if conv_b is None:
  12. conv_b = torch.zeros_like(bn_rm)
  13. if bn_w is None:
  14. bn_w = torch.ones_like(bn_rm)
  15. if bn_b is None:
  16. bn_b = torch.zeros_like(bn_rm)
  17. bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
  18. if transpose:
  19. shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
  20. else:
  21. shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
  22. fused_conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
  23. fused_conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
  24. return torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad)
  25. def fuse_linear_bn_eval(linear, bn):
  26. assert(not (linear.training or bn.training)), "Fusion only for eval!"
  27. fused_linear = copy.deepcopy(linear)
  28. fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
  29. fused_linear.weight, fused_linear.bias,
  30. bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
  31. return fused_linear
  32. def fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
  33. if linear_b is None:
  34. linear_b = torch.zeros_like(bn_rm)
  35. bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
  36. fused_w = linear_w * bn_scale.unsqueeze(-1)
  37. fused_b = (linear_b - bn_rm) * bn_scale + bn_b
  38. return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(fused_b, linear_b.requires_grad)