# memory_format.py

import torch


def convert_conv2d_weight_memory_format(module, memory_format):
  3. r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``
  4. The conversion recursively applies to nested ``nn.Module``, including ``module``.
  5. Note that it only changes the memory_format, but not the semantics of each dimensions.
  6. This function is used to facilitate the computation to adopt NHWC kernels, which
  7. provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
    .. note::
        Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive
        than the utility function ``convert_conv2d_weight_memory_format``. Any
        layer with a 4d weight will be affected by ``model.to``, and such a layer
        does not necessarily benefit from conversion to the specified
        ``memory_format``. One case we are confident about is the NHWC
        (channels_last) conversion for convolution in cuDNN, as it is beneficial
        to run convolution in NHWC even when a permutation has to be applied to
        the input tensors. (A runnable sketch contrasting the two approaches is
        included after this function.)

        Hence our strategy here is to convert only the weight of convolution to
        channels_last. This ensures that:
        1. Fast convolution kernels will be used, the benefit of which could
           outweigh the overhead of permutation (if the input is not in the same format).
        2. No unnecessary permutations are applied to layers that do not benefit
           from the memory_format conversion.

        In the optimal case, the layers between convolution layers are
        channels_last compatible. The input tensor is permuted to channels last
        when it encounters the first convolution layer and stays in that memory
        format, so subsequent convolutions do not need to permute their input
        tensors.

        If a channels_last-incompatible layer sits between convolution layers,
        we need to permute the input tensor back to contiguous format for that
        layer. The input tensor then goes through the remaining layers in
        contiguous format and is permuted to channels last when it encounters
        another convolution layer. There is no point in propagating that
        permutation to an earlier layer, as most layers are quite agnostic to
        ``memory_format``.

        This claim might change when PyTorch supports fusion of permutations, as
        there may be a better spot to fuse the permutation than immediately
        before a convolution.
    Args:
        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
                            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last`` or ``torch.contiguous_format``

    Returns:
        The original module with updated ``nn.Conv2d``
    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
        >>> model = nn.Sequential(
        >>>     nn.Conv2d(8, 4, 3)).cuda().half()
        >>> # This is identical to:
        >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> out = model(input)
    """
    # TODO: expand this to `_ConvNd` when channels_last support is extended
    # beyond only 4d tensors.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        weight_data = (
            module.weight.detach().clone().contiguous(memory_format=memory_format)
        )
        # resize_ with an explicit memory_format forces the desired strides;
        # .contiguous(memory_format=...) alone can be a no-op for shapes where
        # the two layouts are indistinguishable (e.g. tensors with size-1 dims).
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=memory_format
        )
    for child in module.children():
        convert_conv2d_weight_memory_format(child, memory_format)
    return module
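

# Illustrative sketch (not part of the original file): a minimal, CPU-only check
# that the utility above leaves the convolution weight in the requested layout.
# The speedup discussed in the docstring applies on CUDA, but the layout
# conversion itself is device-agnostic. The helper name
# ``_demo_convert_and_verify`` is an assumption for illustration only.
def _demo_convert_and_verify():
    model = torch.nn.Sequential(torch.nn.Conv2d(8, 4, 3))
    convert_conv2d_weight_memory_format(model, torch.channels_last)
    # The weight now has channels_last (NHWC) strides while keeping NCHW semantics.
    assert model[0].weight.is_contiguous(memory_format=torch.channels_last)
    # Inputs are left untouched; the convolution handles the layout mismatch.
    return model(torch.randn(2, 8, 4, 4))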
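

# Illustrative sketch (not part of the original file): contrasts this utility
# with ``model.to(memory_format=torch.channels_last)``, which, as the docstring
# note says, converts *any* 4d parameter. ``_ModuleWith4dParam`` and
# ``_demo_compare_with_module_to`` are hypothetical names used only here.
class _ModuleWith4dParam(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain 4d parameter that is not a convolution weight.
        self.weight = torch.nn.Parameter(torch.randn(2, 3, 4, 5))


def _demo_compare_with_module_to():
    m1 = torch.nn.Sequential(torch.nn.Conv2d(8, 4, 3), _ModuleWith4dParam())
    convert_conv2d_weight_memory_format(m1, torch.channels_last)
    # Only the convolution weight is converted; the plain 4d parameter keeps
    # its default contiguous strides.
    assert m1[0].weight.is_contiguous(memory_format=torch.channels_last)
    assert not m1[1].weight.is_contiguous(memory_format=torch.channels_last)

    m2 = torch.nn.Sequential(torch.nn.Conv2d(8, 4, 3), _ModuleWith4dParam())
    m2 = m2.to(memory_format=torch.channels_last)
    # ``model.to`` converts every 4d parameter, whether or not it benefits.
    assert m2[1].weight.is_contiguous(memory_format=torch.channels_last)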