# common_extended_utils.py

import os
from collections import defaultdict
from numbers import Number
from typing import Any, List

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torchvision.models._api import Weights

aten = torch.ops.aten
quantized = torch.ops.quantized


def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    elif hasattr(i, "weight"):
        return i.weight().shape
    else:
        raise ValueError(f"Unknown type {type(i)}")


def prod(x):
    res = 1
    for i in x:
        res *= i
    return res


def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop
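
# Worked example (hypothetical shapes, not from the original file): multiplying an
# (8, 64) matrix by a (64, 32) matrix performs 8 * 64 * 32 = 16384 multiply-accumulates,
# so matmul_flop([torch.empty(8, 64), torch.empty(64, 32)], []) would return 16384.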


def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3: [bias, mat1, mat2]; the bias is ignored.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops
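
# Worked example (hypothetical shapes, not from the original file): for addmm(bias, mat1, mat2)
# with mat1 of shape (4, 128) and mat2 of shape (128, 10), the count is
# 4 * 128 * 10 = 5120 multiply-accumulates; the bias addition is not counted.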


def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the two batched tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop
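
# Worked example (hypothetical shapes, not from the original file): a 3x3 convolution from
# 3 to 8 channels on a (1, 3, 32, 32) input (stride 1, no padding) produces a (1, 8, 30, 30)
# output, so conv_flop_count([1, 3, 32, 32], [8, 3, 3, 3], [1, 8, 30, 30]) counts
# 1 * prod([8, 3, 3, 3]) * prod([30, 30]) = 216 * 900 = 194400 multiply-accumulates.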


def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    # aten.convolution passes the `transposed` flag as its 7th positional argument.
    transposed = inputs[6]
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for quantized convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=False)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = get_shape(outputs[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count


def scaled_dot_product_flash_attention_flop(inputs: List[Any], outputs: List[Any]):
    # FIXME: this needs to count the flops of this kernel
    # https://github.com/pytorch/pytorch/blob/207b06d099def9d9476176a1842e88636c1f714f/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L52-L267
    return 0


flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    quantized.conv2d: quant_conv_flop,
    quantized.conv2d_relu: quant_conv_flop,
    aten._scaled_dot_product_flash_attention: scaled_dot_product_flash_attention_flop,
}

unmapped_ops = set()


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


class FlopCounterMode(TorchDispatchMode):
    def __init__(self, model=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ["Global"]
        # global mod
        if model is not None:
            for name, module in dict(model.named_children()).items():
                module.register_forward_pre_hook(self.enter_module(name))
                module.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        def f(module, inputs, outputs):
            assert self.parents[-1] == name
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            return self.create_backwards_push(name)(*outputs)

        return f

    def create_backwards_push(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        # print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
        # for mod in self.flop_counts.keys():
        #     print(f"Module: ", mod)
        #     for k, v in self.flop_counts[mod].items():
        #         print(f"{k}: {v / 1e9} GFLOPS")
        #     print()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in flop_mapping:
            flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        else:
            unmapped_ops.add(func_packet)

        return out

    def get_flops(self):
        return sum(self.flop_counts["Global"].values()) / 1e9
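
# Typical usage (as in get_ops below): construct FlopCounterMode with the model so the
# per-module hooks are registered, run a single forward pass inside the context manager,
# then read the aggregated count in GFLOPs:
#
#     counter = FlopCounterMode(model)
#     with counter:
#         model(sample_input)
#     gflops = counter.get_flops()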


def get_dims(module_name, height, width):
    # detection models have curated input sizes
    if module_name == "detection":
        # we can feed a batch of 1 for detection model instead of a list of 1 image
        dims = (3, height, width)
    elif module_name == "video":
        # hard-coding the time dimension to size 16
        dims = (1, 16, 3, height, width)
    else:
        dims = (1, 3, height, width)
    return dims
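
# For example, get_dims("video", 224, 224) returns (1, 16, 3, 224, 224), while any module
# name other than "detection" or "video" falls through to the (1, 3, height, width) default.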


def get_ops(model: torch.nn.Module, weight: Weights, height=512, width=512):
    module_name = model.__module__.split(".")[-2]
    dims = get_dims(module_name=module_name, height=height, width=width)

    input_tensor = torch.randn(dims)

    # try:
    preprocess = weight.transforms()
    if module_name == "optical_flow":
        inp = preprocess(input_tensor, input_tensor)
    else:
        # hack to enable mod(*inp) for optical_flow models
        inp = [preprocess(input_tensor)]

    model.eval()

    flop_counter = FlopCounterMode(model)
    with flop_counter:
        # detection models expect a list of 3d tensors as inputs
        if module_name == "detection":
            model(inp)
        else:
            model(*inp)

    flops = flop_counter.get_flops()

    return round(flops, 3)


def get_file_size_mb(weight):
    weights_path = os.path.join(os.getenv("HOME"), ".cache/torch/hub/checkpoints", weight.url.split("/")[-1])
    weights_size_mb = os.path.getsize(weights_path) / 1024 / 1024
    return round(weights_size_mb, 3)
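

if __name__ == "__main__":
    # Minimal self-check sketch (not part of the original utilities): the model below and its
    # layer sizes are arbitrary assumptions chosen only to illustrate FlopCounterMode; real
    # usage goes through get_ops with a torchvision model and its Weights entry.
    demo_model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3),  # (1, 3, 32, 32) -> (1, 8, 30, 30)
        torch.nn.Flatten(),                    # -> (1, 7200)
        torch.nn.Linear(8 * 30 * 30, 10),      # -> (1, 10)
    )
    demo_model.eval()

    demo_counter = FlopCounterMode(demo_model)
    with demo_counter:
        demo_model(torch.randn(1, 3, 32, 32))

    # Expect roughly (1 * 216 * 900 + 1 * 7200 * 10) / 1e9 GFLOPs from the conv and linear layers.
    print(f"Demo GFLOPs: {demo_counter.get_flops():.6f}")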