# quantization.py

import functools
from enum import Enum

import torch
import torch.distributed as dist

TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max


class DQuantType(Enum):
    """
    Different quantization methods for the auto_quantize API are identified here.

    auto_quantize currently supports the fp16 and bfp16 methods.
    """

    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        return self.value


def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()
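
# Why _fp32_to_fp16_with_clamp clamps before casting: fp32 values outside the fp16
# range would otherwise overflow to +/-inf when cast with .half(). Illustrative
# values (not from the original module):
#   torch.tensor([1e5, -1e5, 0.5]).half()                      -> [inf, -inf, 0.5]
#   _fp32_to_fp16_with_clamp(torch.tensor([1e5, -1e5, 0.5]))   -> [65504., -65504., 0.5]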


def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    elif qtype == DQuantType.BFP16:
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    return [_quantize_tensor(t, qtype) for t in tensor_list]


def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        elif quant_loss is None:
            return tensor.float()
        else:
            return tensor.float() / quant_loss
    elif qtype == DQuantType.BFP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        else:
            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")
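
# Note on the BFP16 branch above: the dtype checks imply that BFP16-quantized
# tensors are carried in a torch.float16-typed buffer whose bits are produced and
# consumed by torch.ops.quantization._FloatToBfloat16Quantized /
# _Bfloat16QuantizedToFloat. This is inferred from the checks here rather than from
# documented behavior of those ops, and the ops must be registered in the running
# build for the BFP16 path to work.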


def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    # Forward quant_loss so FP16 dequantization can rescale each tensor.
    return [_dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list]


def auto_quantize(func, qtype, quant_loss=None):
    """
    This is a prototype API that automatically quantizes the input tensors, passes
    the other necessary arguments to the collective, and then dequantizes the output.

    Currently it only supports:
        . FP16 and BFP16 quantization methods for the gloo and nccl backends
        . the all_gather, all_to_all, and all_to_all_single collective ops

    Note: BFP16 only supports 2D tensors.

    Args:
        func (Callable): A function representing collective operations.
        qtype (DQuantType): Quantization method.
        quant_loss (float, optional): If given, FP16-dequantized tensors are divided
            by this value, which can be used to improve accuracy.

    Returns:
        (Callable): the same collective as ``func``, but with automatic
        quantization/dequantization enabled.
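
    Example (illustrative; assumes an initialized process group and that every
    rank runs the same code)::

        quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
        out = [torch.empty(4) for _ in range(dist.get_world_size())]
        quantized_all_gather(out, torch.randn(4))
        # `out` now holds the gathered tensors, dequantized back to fp32.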
  88. """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get("group", None)
        async_op = kwargs.get("async_op", False)
        if async_op is True:
            raise RuntimeError("The async_op=True mode is not supported yet.")
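        # Dispatch on the wrapped collective: quantize the inputs, run the
        # collective on the quantized buffers, then dequantize the results back
        # into the caller's output tensors. async_op is rejected above because
        # dequantization has to wait for the collective to finish.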
        if func == dist.all_gather:
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t
        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t
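        # For all_to_all_single the output is a single tensor, so the dequantized
        # result is copied back slice by slice along dim 0 into the caller's output
        # tensor. The split sizes are read from kwargs and forwarded positionally as
        # output_split_sizes / input_split_sizes.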
        elif func == dist.all_to_all_single:
            tensors = args[0]
            out_splits = kwargs.get("out_splits", None)
            in_splits = kwargs.get("in_splits", None)
            # Quantizing the input/output tensor
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group)
            for i, t in enumerate(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t
        else:
            raise RuntimeError(f"The collective op {func} is not supported yet")

    return wrapper