12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- #include <c10/util/Exception.h>
- #include <utility>
- namespace at {
- /*
- [collapse dims] Updates sizes, and strides to reflect a "collapse" of
- the info, possibly excluding the optional excludeDim. A "collapsed" version
- of the info is the fewest dims that order the tensor's elements in the same
- way as the original info. If excludeDim is specified, the collapse is the
- fewest dims that order the tensor's elements as the original and preserve the
- excluded dimension, unless the tensor collapses to a point.
- This function returns a pair of values.
- 1) The (new) index of the preserved dimension if excludeDim is
- specified. 0 if the tensor is collapsed to a point. -1
- otherwise.
- 2) The new number of dimensions.
- */
- template <typename T>
- inline std::pair<int64_t, int64_t> collapse_dims(
- T* sizes,
- T* strides,
- int64_t dims,
- const int excludeDim = -1) {
- TORCH_CHECK(
- excludeDim >= -1 && excludeDim < dims,
- "expected excluded dim between -1 and dims - 1");
- int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
- int64_t newIndex = -1;
- int64_t oldIndex = 0;
- int64_t remappedExcludedDim = -1;
- while (oldIndex < dims) {
- // Finds a dimension to collapse into
- for (; oldIndex < stopDim; ++oldIndex) {
- if (sizes[oldIndex] == 1) {
- continue;
- }
- ++newIndex;
- sizes[newIndex] = sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- ++oldIndex;
- break;
- }
- // Collapses dims
- for (; oldIndex < stopDim; ++oldIndex) {
- if (sizes[oldIndex] == 1) {
- continue;
- }
- if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
- sizes[newIndex] *= sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- } else {
- ++newIndex;
- sizes[newIndex] = sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- }
- }
- // Handles excludeDim being set (oldIndex == excludeDim)
- if (oldIndex != dims) {
- // Preserves excluded dimension
- ++newIndex;
- sizes[newIndex] = sizes[oldIndex];
- strides[newIndex] = strides[oldIndex];
- remappedExcludedDim = newIndex;
- // Restarts iteration after excludeDim
- ++oldIndex;
- stopDim = dims;
- }
- }
- // Handles special case of all dims size 1
- if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
- dims = 1;
- sizes[0] = 1;
- strides[0] = 1;
- return std::pair<int64_t, int64_t>(0, 1);
- }
- dims = newIndex + 1;
- return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
- }
- } // namespace at
|