grad.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. """Gradient interface"""
  2. import torch
  3. from .modules.utils import _single, _pair, _triple
  4. def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
  5. r"""
  6. Computes the gradient of conv1d with respect to the input of the convolution.
  7. This is same as the 1D transposed convolution operator under the hood but requires
  8. the shape of the gradient w.r.t. input to be specified explicitly.
  9. Args:
  10. input_size : Shape of the input gradient tensor
  11. weight: weight tensor (out_channels x in_channels/groups x kW)
  12. grad_output : output gradient tensor (minibatch x out_channels x oW)
  13. stride (int or tuple, optional): Stride of the convolution. Default: 1
  14. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  15. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  16. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  17. Examples::
  18. >>> input = torch.randn(1, 1, 3, requires_grad=True)
  19. >>> weight = torch.randn(1, 1, 1, requires_grad=True)
  20. >>> output = F.conv1d(input, weight)
  21. >>> grad_output = torch.randn(output.shape)
  22. >>> grad_input = torch.autograd.grad(output, input, grad_output)
  23. >>> F.grad.conv1d_input(input.shape, weight, grad_output)
  24. """
  25. input = grad_output.new_empty(1).expand(input_size)
  26. return torch.ops.aten.convolution_backward(grad_output, input, weight, None,
  27. _single(stride), _single(padding), _single(dilation),
  28. False, [0], groups, (True, False, False))[0]
  29. def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
  30. r"""
  31. Computes the gradient of conv1d with respect to the weight of the convolution.
  32. Args:
  33. input: input tensor of shape (minibatch x in_channels x iW)
  34. weight_size : Shape of the weight gradient tensor
  35. grad_output : output gradient tensor (minibatch x out_channels x oW)
  36. stride (int or tuple, optional): Stride of the convolution. Default: 1
  37. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  38. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  39. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  40. Examples::
  41. >>> input = torch.randn(1, 1, 3, requires_grad=True)
  42. >>> weight = torch.randn(1, 1, 1, requires_grad=True)
  43. >>> output = F.conv1d(input, weight)
  44. >>> grad_output = torch.randn(output.shape)
  45. >>> # xdoctest: +SKIP
  46. >>> grad_weight = torch.autograd.grad(output, filter, grad_output)
  47. >>> F.grad.conv1d_weight(input, weight.shape, grad_output)
  48. """
  49. weight = grad_output.new_empty(1).expand(weight_size)
  50. return torch.ops.aten.convolution_backward(grad_output, input, weight, None,
  51. _single(stride), _single(padding), _single(dilation),
  52. False, [0], groups, (False, True, False))[1]
  53. def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
  54. r"""
  55. Computes the gradient of conv2d with respect to the input of the convolution.
  56. This is same as the 2D transposed convolution operator under the hood but requires
  57. the shape of the gradient w.r.t. input to be specified explicitly.
  58. Args:
  59. input_size : Shape of the input gradient tensor
  60. weight: weight tensor (out_channels x in_channels/groups x kH x kW)
  61. grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
  62. stride (int or tuple, optional): Stride of the convolution. Default: 1
  63. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  64. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  65. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  66. Examples::
  67. >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
  68. >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
  69. >>> output = F.conv2d(input, weight)
  70. >>> grad_output = torch.randn(output.shape)
  71. >>> grad_input = torch.autograd.grad(output, input, grad_output)
  72. >>> F.grad.conv2d_input(input.shape, weight, grad_output)
  73. """
  74. input = grad_output.new_empty(1).expand(input_size)
  75. return torch.ops.aten.convolution_backward(grad_output, input, weight, None,
  76. _pair(stride), _pair(padding), _pair(dilation),
  77. False, [0], groups, (True, False, False))[0]
  78. def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
  79. r"""
  80. Computes the gradient of conv2d with respect to the weight of the convolution.
  81. Args:
  82. input: input tensor of shape (minibatch x in_channels x iH x iW)
  83. weight_size : Shape of the weight gradient tensor
  84. grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
  85. stride (int or tuple, optional): Stride of the convolution. Default: 1
  86. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  87. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  88. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  89. Examples::
  90. >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
  91. >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
  92. >>> output = F.conv2d(input, weight)
  93. >>> grad_output = torch.randn(output.shape)
  94. >>> # xdoctest: +SKIP
  95. >>> grad_weight = torch.autograd.grad(output, filter, grad_output)
  96. >>> F.grad.conv2d_weight(input, weight.shape, grad_output)
  97. """
  98. weight = grad_output.new_empty(1).expand(weight_size)
  99. return torch.ops.aten.convolution_backward(grad_output, input, weight, None,
  100. _pair(stride), _pair(padding), _pair(dilation),
  101. False, [0], groups, (False, True, False))[1]
  102. def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
  103. r"""
  104. Computes the gradient of conv3d with respect to the input of the convolution.
  105. This is same as the 3D transposed convolution operator under the hood but requires
  106. the shape of the gradient w.r.t. input to be specified explicitly.
  107. Args:
  108. input_size : Shape of the input gradient tensor
  109. weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW)
  110. grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
  111. stride (int or tuple, optional): Stride of the convolution. Default: 1
  112. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  113. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  114. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  115. Examples::
  116. >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
  117. >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
  118. >>> output = F.conv3d(input, weight)
  119. >>> grad_output = torch.randn(output.shape)
  120. >>> grad_input = torch.autograd.grad(output, input, grad_output)
  121. >>> F.grad.conv3d_input(input.shape, weight, grad_output)
  122. """
  123. input = grad_output.new_empty(1).expand(input_size)
  124. return torch.ops.aten.convolution_backward(grad_output, input, weight, None,
  125. _triple(stride), _triple(padding), _triple(dilation),
  126. False, [0], groups, (True, False, False))[0]
  127. def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
  128. r"""
  129. Computes the gradient of conv3d with respect to the weight of the convolution.
  130. Args:
  131. input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
  132. weight_size : Shape of the weight gradient tensor
  133. grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
  134. stride (int or tuple, optional): Stride of the convolution. Default: 1
  135. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  136. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  137. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  138. Examples::
  139. >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
  140. >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
  141. >>> output = F.conv3d(input, weight)
  142. >>> grad_output = torch.randn(output.shape)
  143. >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
  144. >>> F.grad.conv3d_weight(input, weight.shape, grad_output)
  145. """
  146. weight = grad_output.new_empty(1).expand(weight_size)
  147. return torch.ops.aten.convolution_backward(grad_output, input, weight, None,
  148. _triple(stride), _triple(padding), _triple(dilation),
  149. False, [0], groups, (False, True, False))[1]