utils.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import threading
  2. __all__ = [
  3. "LinearBlockSparsePattern"
  4. ]
  5. def _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
  6. return (row_block_size == 1 and col_block_size == 4) or \
  7. (row_block_size == 8 and col_block_size == 1)
# This is a stop-gap measure, as the current flow does not allow a
# module-specific block sparse pattern. In fact, there is no way to convey
# the sparse pattern via the module config of the quantization flow, so a
# global context is used to convey the sparsity pattern instead.
# Once the flow supports it, this should be removed.
  14. class LinearBlockSparsePattern:
  15. rlock = threading.RLock()
  16. row_block_size = 1
  17. col_block_size = 4
  18. prev_row_block_size = 1
  19. prev_col_block_size = 4
  20. def __init__(self, row_block_size=1, col_block_size=4):
  21. assert(_is_valid_linear_block_sparse_pattern(row_block_size, col_block_size))
  22. LinearBlockSparsePattern.rlock.acquire()
  23. LinearBlockSparsePattern.prev_row_block_size = LinearBlockSparsePattern.row_block_size
  24. LinearBlockSparsePattern.prev_col_block_size = LinearBlockSparsePattern.col_block_size
  25. LinearBlockSparsePattern.row_block_size = row_block_size
  26. LinearBlockSparsePattern.col_block_size = col_block_size
  27. def __enter__(self):
  28. pass
  29. def __exit__(self, exc_type, exc_value, backtrace):
  30. LinearBlockSparsePattern.row_block_size = LinearBlockSparsePattern.prev_row_block_size
  31. LinearBlockSparsePattern.col_block_size = LinearBlockSparsePattern.prev_col_block_size
  32. LinearBlockSparsePattern.rlock.release()
  33. @staticmethod
  34. def block_size():
  35. return LinearBlockSparsePattern.row_block_size, LinearBlockSparsePattern.col_block_size