123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- // This file is part of Eigen, a lightweight C++ template library
- // for linear algebra.
- //
- // Copyright (C) 2015 Ke Yang <yangke@gmail.com>
- //
- // This Source Code Form is subject to the terms of the Mozilla
- // Public License v. 2.0. If a copy of the MPL was not distributed
- // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
- #include "main.h"
- #include <Eigen/CXX11/Tensor>
- using Eigen::Tensor;
- template<int DataLayout>
- static void test_simple_inflation()
- {
- Tensor<float, 4, DataLayout> tensor(2,3,5,7);
- tensor.setRandom();
- array<ptrdiff_t, 4> strides;
- strides[0] = 1;
- strides[1] = 1;
- strides[2] = 1;
- strides[3] = 1;
- Tensor<float, 4, DataLayout> no_stride;
- no_stride = tensor.inflate(strides);
- VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
- VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
- VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
- VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
- for (int i = 0; i < 2; ++i) {
- for (int j = 0; j < 3; ++j) {
- for (int k = 0; k < 5; ++k) {
- for (int l = 0; l < 7; ++l) {
- VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
- }
- }
- }
- }
- strides[0] = 2;
- strides[1] = 4;
- strides[2] = 2;
- strides[3] = 3;
- Tensor<float, 4, DataLayout> inflated;
- inflated = tensor.inflate(strides);
- VERIFY_IS_EQUAL(inflated.dimension(0), 3);
- VERIFY_IS_EQUAL(inflated.dimension(1), 9);
- VERIFY_IS_EQUAL(inflated.dimension(2), 9);
- VERIFY_IS_EQUAL(inflated.dimension(3), 19);
- for (int i = 0; i < 3; ++i) {
- for (int j = 0; j < 9; ++j) {
- for (int k = 0; k < 9; ++k) {
- for (int l = 0; l < 19; ++l) {
- if (i % 2 == 0 &&
- j % 4 == 0 &&
- k % 2 == 0 &&
- l % 3 == 0) {
- VERIFY_IS_EQUAL(inflated(i,j,k,l),
- tensor(i/2, j/4, k/2, l/3));
- } else {
- VERIFY_IS_EQUAL(0, inflated(i,j,k,l));
- }
- }
- }
- }
- }
- }
- EIGEN_DECLARE_TEST(cxx11_tensor_inflation)
- {
- CALL_SUBTEST(test_simple_inflation<ColMajor>());
- CALL_SUBTEST(test_simple_inflation<RowMajor>());
- }
|