import functools
from enum import Enum

import torch
import torch.distributed as dist

TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max


class DQuantType(Enum):
    """
    Different quantization methods for the auto_quantize API are identified here.

    The auto_quantize API currently supports the fp16 and bfp16 methods.
    """

    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        return self.value
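

# Illustrative note (a minimal sketch, not executed on import): str() on a
# member yields the raw method name, e.g.
#
#     str(DQuantType.FP16)   # -> "fp16"
#     str(DQuantType.BFP16)  # -> "bfp16"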


def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    # Saturate to the representable float16 range before downcasting so that
    # out-of-range values clamp instead of overflowing to +/-inf.
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    elif qtype == DQuantType.BFP16:
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
    return quantized_tensor_list


def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        elif quant_loss is None:
            return tensor.float()
        else:
            return tensor.float() / quant_loss
    elif qtype == DQuantType.BFP16:
        # The BFP16-quantized representation is expected to have dtype
        # torch.float16, hence the same dtype check as the FP16 branch.
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        else:
            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    dequantized_tensor_list = [
        _dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list
    ]
    return dequantized_tensor_list
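

# Illustrative round-trip sketch (an assumption-laden example, not executed on
# import): quantizing and dequantizing a float32 CPU tensor with the fp16
# helpers above.  Values outside the float16 range are saturated to +/-65504 by
# _fp32_to_fp16_with_clamp before the cast:
#
#     x = torch.tensor([1.5, 70000.0, -1e9])
#     q = _quantize_tensor(x, DQuantType.FP16)    # dtype torch.float16, clamped
#     y = _dequantize_tensor(q, DQuantType.FP16)  # back to dtype torch.float32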


def auto_quantize(func, qtype, quant_loss=None):
    """
    This is a prototype API that automatically quantizes the input tensors, chooses
    the precision type, passes the other necessary arguments, and then dequantizes
    the output.

    Currently it only supports:
        . FP16 and BFP16 quantization methods for the gloo and nccl backends
        . all_gather, all_to_all, all_to_all_single collective ops

    Note: BFP16 only supports 2D tensors.

    Args:
        func (Callable): A function representing collective operations.
        qtype (DQuantType): Quantization method.
        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.

    Returns:
        (Callable): the same collective as func but enables automatic quantization/dequantization.

    A hypothetical usage sketch is shown at the end of this module.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get("group", None)
        async_op = kwargs.get("async_op", False)
        if async_op is True:
            raise RuntimeError("The async_op=True mode is not supported yet.")
        if func == dist.all_gather:
            # Quantize the input tensor and the output buffers, run the
            # collective in reduced precision, then dequantize back into the
            # caller's output list.
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t
        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t
        elif func == dist.all_to_all_single:
            tensors = args[0]
            out_splits = kwargs.get("out_splits", None)
            in_splits = kwargs.get("in_splits", None)
            # Quantizing the input/output tensor
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(
                out_tensors, input_tensors, out_splits, in_splits, group=group
            )
            # Copy the dequantized rows back into the caller's output tensor.
            for i, t in enumerate(
                _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t
        else:
            raise RuntimeError(f"The collective op {func} is not supported yet")

    return wrapper
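

# Example usage (an illustrative sketch, not executed on import): wrapping
# dist.all_gather so that tensors travel over the wire in fp16.  This assumes a
# process group has already been initialized elsewhere, e.g. via
# dist.init_process_group(backend="gloo"), and that every rank contributes a
# float32 tensor of the same shape.
#
#     quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
#     tensor = torch.randn(4)
#     gathered = [torch.zeros(4) for _ in range(dist.get_world_size())]
#     quantized_all_gather(gathered, tensor)
#     # `gathered` now holds the dequantized float32 tensors from every rank.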