import torch


def _extract_strides(shape):
    """Return the contiguous (row-major) strides for ``shape``.

    The innermost (last) axis gets stride 1 and every other axis gets the
    product of the sizes of all axes after it.

    Args:
        shape: Indexable sequence of axis sizes.

    Returns:
        A list of ``len(shape)`` strides; an empty list for an empty shape.
    """
    rank = len(shape)
    strides = [1] * rank
    # Walk from the innermost axis outward, accumulating the running product.
    for axis in range(rank - 1, 0, -1):
        strides[axis - 1] = strides[axis] * shape[axis]
    return strides


def _roundup(x, div):
    """Round ``x`` up to the nearest multiple of ``div``."""
    return (x + div - 1) // div * div


def _unpack(idx, order, shape):
    """Unpack a flat index into the three coordinates of a 3-D tensor.

    This is the inverse of flattening three axis indices into a single 1-D
    index.

    Args:
        idx: Flat index, either a Python number or a ``torch.Tensor`` of
            indices.
        order: Order of the tensor's axes, listed from the innermost
            dimension outward. Only ``order[0]`` and ``order[1]`` are read.
        shape: Shape of the 3-D tensor, indexed by the axis ids in ``order``.

    Returns:
        A tuple ``(i0, i1, i2)`` where ``i0`` is the coordinate along
        ``order[0]`` (innermost), ``i1`` along ``order[1]``, and ``i2`` is
        the remaining quotient (outermost coordinate).
    """
    inner_size = shape[order[0]]
    middle_size = shape[order[1]]

    if torch.is_tensor(idx):
        # Floor division pairs correctly with ``%`` (a floor-style remainder
        # for tensors) and matches the scalar ``//`` branch below. For
        # non-negative indices this is identical to truncating division.
        outer_part = torch.div(idx, inner_size, rounding_mode="floor")
        coord_inner = idx % inner_size
        coord_outer = torch.div(outer_part, middle_size, rounding_mode="floor")
        coord_middle = outer_part % middle_size
    else:
        outer_part = idx // inner_size
        coord_inner = idx % inner_size
        coord_outer = outer_part // middle_size
        coord_middle = outer_part % middle_size

    return coord_inner, coord_middle, coord_outer