_ni_support.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (C) 2003-2005 Peter J. Verveer
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. #
  7. # 1. Redistributions of source code must retain the above copyright
  8. # notice, this list of conditions and the following disclaimer.
  9. #
  10. # 2. Redistributions in binary form must reproduce the above
  11. # copyright notice, this list of conditions and the following
  12. # disclaimer in the documentation and/or other materials provided
  13. # with the distribution.
  14. #
  15. # 3. The name of the author may not be used to endorse or promote
  16. # products derived from this software without specific prior
  17. # written permission.
  18. #
  19. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
  20. # OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  21. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  22. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
  23. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  24. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
  25. # GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  26. # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
  27. # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  28. # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  29. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. from collections.abc import Iterable
  31. import warnings
  32. import numpy
  33. def _extend_mode_to_code(mode):
  34. """Convert an extension mode to the corresponding integer code.
  35. """
  36. if mode == 'nearest':
  37. return 0
  38. elif mode == 'wrap':
  39. return 1
  40. elif mode in ['reflect', 'grid-mirror']:
  41. return 2
  42. elif mode == 'mirror':
  43. return 3
  44. elif mode == 'constant':
  45. return 4
  46. elif mode == 'grid-wrap':
  47. return 5
  48. elif mode == 'grid-constant':
  49. return 6
  50. else:
  51. raise RuntimeError('boundary mode not supported')
  52. def _normalize_sequence(input, rank):
  53. """If input is a scalar, create a sequence of length equal to the
  54. rank by duplicating the input. If input is a sequence,
  55. check if its length is equal to the length of array.
  56. """
  57. is_str = isinstance(input, str)
  58. if not is_str and isinstance(input, Iterable):
  59. normalized = list(input)
  60. if len(normalized) != rank:
  61. err = "sequence argument must have length equal to input rank"
  62. raise RuntimeError(err)
  63. else:
  64. normalized = [input] * rank
  65. return normalized
  66. def _get_output(output, input, shape=None, complex_output=False):
  67. if shape is None:
  68. shape = input.shape
  69. if output is None:
  70. if not complex_output:
  71. output = numpy.zeros(shape, dtype=input.dtype.name)
  72. else:
  73. complex_type = numpy.promote_types(input.dtype, numpy.complex64)
  74. output = numpy.zeros(shape, dtype=complex_type)
  75. elif isinstance(output, (type, numpy.dtype)):
  76. # Classes (like `np.float32`) and dtypes are interpreted as dtype
  77. if complex_output and numpy.dtype(output).kind != 'c':
  78. warnings.warn("promoting specified output dtype to complex")
  79. output = numpy.promote_types(output, numpy.complex64)
  80. output = numpy.zeros(shape, dtype=output)
  81. elif isinstance(output, str):
  82. output = numpy.sctypeDict[output]
  83. if complex_output and numpy.dtype(output).kind != 'c':
  84. raise RuntimeError("output must have complex dtype")
  85. output = numpy.zeros(shape, dtype=output)
  86. elif output.shape != shape:
  87. raise RuntimeError("output shape not correct")
  88. elif complex_output and output.dtype.kind != 'c':
  89. raise RuntimeError("output must have complex dtype")
  90. return output