utils.py 951 B

12345678910111213141516171819202122232425262728293031
  1. import torch
  2. def _extract_strides(shape):
  3. rank = len(shape)
  4. ret = [1] * rank
  5. for i in range(rank - 1, 0, -1):
  6. ret[i - 1] = ret[i] * shape[i]
  7. return ret
  8. def _roundup(x, div):
  9. return (x + div - 1) // div * div
  10. # unpack the given idx given the order of axis of the desired 3-dim tensor
  11. # You could view it as the reverse of flatten the idx of 3 axis in a tensor to 1-dim idx.
  12. # order is the order of axes in tensor, innermost dimension outward
  13. # shape is the 3D tensor's shape
  14. def _unpack(idx, order, shape):
  15. if torch.is_tensor(idx):
  16. _12 = torch.div(idx, shape[order[0]], rounding_mode="trunc")
  17. _0 = idx % shape[order[0]]
  18. _2 = torch.div(_12, shape[order[1]], rounding_mode="trunc")
  19. _1 = _12 % shape[order[1]]
  20. else:
  21. _12 = idx // shape[order[0]]
  22. _0 = idx % shape[order[0]]
  23. _2 = _12 // shape[order[1]]
  24. _1 = _12 % shape[order[1]]
  25. return _0, _1, _2