generate_numpy_api.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import os
  2. import genapi
  3. from genapi import \
  4. TypeApi, GlobalVarApi, FunctionApi, BoolValuesApi
  5. import numpy_api
# Template for the generated C-API header (the ``__multiarray_api.h`` target
# built by ``generate_api``). ``do_generate_api`` fills it with ``%``
# formatting using a 2-tuple:
#   1st ``%s`` -- the '\n'-joined ``internal_define()`` strings, emitted in
#                 the branch taken when ``_MULTIARRAYMODULE`` (numpy itself)
#                 or ``WITH_CPYCHECKER_STEALS_REFERENCE_TO_ARG_ATTRIBUTE``
#                 (cpychecker) is defined -- i.e. the annotated API is used
#                 when running under cpychecker;
#   2nd ``%s`` -- the '\n'-joined ``define_from_array_api_string()`` strings,
#                 emitted in the ``#else`` branch for other extension modules
#                 (presumably macros that index into ``PyArray_API`` -- see
#                 genapi to confirm).
# Because the template goes through ``%`` formatting, every ``%`` intended
# for the C code is written as ``%%`` (e.g. ``0x%%x`` in ``PyErr_Format``).
#
# In the ``#else`` branch, ``PyArray_API`` is renamed to
# ``PY_ARRAY_UNIQUE_SYMBOL`` when that macro is defined. It is declared
# ``extern`` under ``NO_IMPORT``/``NO_IMPORT_ARRAY``; otherwise it is defined
# here (``static`` and NULL-initialised when no unique symbol is requested).
# ``_import_array()`` imports ``numpy.core._multiarray_umath``, fetches the
# ``_ARRAY_API`` capsule into ``PyArray_API``, then checks the ABI version
# (must be equal), the feature/API version (must not be newer than the
# runtime's) and the endianness. The ``import_array*`` macros print the
# pending error, set an ``ImportError`` and return ``NULL`` / ``ret`` on
# failure.
h_template = r"""
#if defined(_MULTIARRAYMODULE) || defined(WITH_CPYCHECKER_STEALS_REFERENCE_TO_ARG_ATTRIBUTE)
typedef struct {
PyObject_HEAD
npy_bool obval;
} PyBoolScalarObject;
extern NPY_NO_EXPORT PyTypeObject PyArrayMapIter_Type;
extern NPY_NO_EXPORT PyTypeObject PyArrayNeighborhoodIter_Type;
extern NPY_NO_EXPORT PyBoolScalarObject _PyArrayScalar_BoolValues[2];
%s
#else
#if defined(PY_ARRAY_UNIQUE_SYMBOL)
#define PyArray_API PY_ARRAY_UNIQUE_SYMBOL
#endif
#if defined(NO_IMPORT) || defined(NO_IMPORT_ARRAY)
extern void **PyArray_API;
#else
#if defined(PY_ARRAY_UNIQUE_SYMBOL)
void **PyArray_API;
#else
static void **PyArray_API=NULL;
#endif
#endif
%s
#if !defined(NO_IMPORT_ARRAY) && !defined(NO_IMPORT)
static int
_import_array(void)
{
int st;
PyObject *numpy = PyImport_ImportModule("numpy.core._multiarray_umath");
PyObject *c_api = NULL;
if (numpy == NULL) {
return -1;
}
c_api = PyObject_GetAttrString(numpy, "_ARRAY_API");
Py_DECREF(numpy);
if (c_api == NULL) {
PyErr_SetString(PyExc_AttributeError, "_ARRAY_API not found");
return -1;
}
if (!PyCapsule_CheckExact(c_api)) {
PyErr_SetString(PyExc_RuntimeError, "_ARRAY_API is not PyCapsule object");
Py_DECREF(c_api);
return -1;
}
PyArray_API = (void **)PyCapsule_GetPointer(c_api, NULL);
Py_DECREF(c_api);
if (PyArray_API == NULL) {
PyErr_SetString(PyExc_RuntimeError, "_ARRAY_API is NULL pointer");
return -1;
}
/* Perform runtime check of C API version */
if (NPY_VERSION != PyArray_GetNDArrayCVersion()) {
PyErr_Format(PyExc_RuntimeError, "module compiled against "\
"ABI version 0x%%x but this version of numpy is 0x%%x", \
(int) NPY_VERSION, (int) PyArray_GetNDArrayCVersion());
return -1;
}
if (NPY_FEATURE_VERSION > PyArray_GetNDArrayCFeatureVersion()) {
PyErr_Format(PyExc_RuntimeError, "module compiled against "\
"API version 0x%%x but this version of numpy is 0x%%x . "\
"Check the section C-API incompatibility at the "\
"Troubleshooting ImportError section at "\
"https://numpy.org/devdocs/user/troubleshooting-importerror.html"\
"#c-api-incompatibility "\
"for indications on how to solve this problem .", \
(int) NPY_FEATURE_VERSION, (int) PyArray_GetNDArrayCFeatureVersion());
return -1;
}
/*
* Perform runtime check of endianness and check it matches the one set by
* the headers (npy_endian.h) as a safeguard
*/
st = PyArray_GetEndianness();
if (st == NPY_CPU_UNKNOWN_ENDIAN) {
PyErr_SetString(PyExc_RuntimeError,
"FATAL: module compiled as unknown endian");
return -1;
}
#if NPY_BYTE_ORDER == NPY_BIG_ENDIAN
if (st != NPY_CPU_BIG) {
PyErr_SetString(PyExc_RuntimeError,
"FATAL: module compiled as big endian, but "
"detected different endianness at runtime");
return -1;
}
#elif NPY_BYTE_ORDER == NPY_LITTLE_ENDIAN
if (st != NPY_CPU_LITTLE) {
PyErr_SetString(PyExc_RuntimeError,
"FATAL: module compiled as little endian, but "
"detected different endianness at runtime");
return -1;
}
#endif
return 0;
}
#define import_array() {if (_import_array() < 0) {PyErr_Print(); PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import"); return NULL; } }
#define import_array1(ret) {if (_import_array() < 0) {PyErr_Print(); PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import"); return ret; } }
#define import_array2(msg, ret) {if (_import_array() < 0) {PyErr_Print(); PyErr_SetString(PyExc_ImportError, msg); return ret; } }
#endif
#endif
"""
  109. c_template = r"""
  110. /* These pointers will be stored in the C-object for use in other
  111. extension modules
  112. */
  113. void *PyArray_API[] = {
  114. %s
  115. };
  116. """
  117. c_api_header = """
  118. ===========
  119. NumPy C-API
  120. ===========
  121. """
  122. def generate_api(output_dir, force=False):
  123. basename = 'multiarray_api'
  124. h_file = os.path.join(output_dir, '__%s.h' % basename)
  125. c_file = os.path.join(output_dir, '__%s.c' % basename)
  126. d_file = os.path.join(output_dir, '%s.txt' % basename)
  127. targets = (h_file, c_file, d_file)
  128. sources = numpy_api.multiarray_api
  129. if (not force and not genapi.should_rebuild(targets, [numpy_api.__file__, __file__])):
  130. return targets
  131. else:
  132. do_generate_api(targets, sources)
  133. return targets
  134. def do_generate_api(targets, sources):
  135. header_file = targets[0]
  136. c_file = targets[1]
  137. doc_file = targets[2]
  138. global_vars = sources[0]
  139. scalar_bool_values = sources[1]
  140. types_api = sources[2]
  141. multiarray_funcs = sources[3]
  142. multiarray_api = sources[:]
  143. module_list = []
  144. extension_list = []
  145. init_list = []
  146. # Check multiarray api indexes
  147. multiarray_api_index = genapi.merge_api_dicts(multiarray_api)
  148. genapi.check_api_dict(multiarray_api_index)
  149. numpyapi_list = genapi.get_api_functions('NUMPY_API',
  150. multiarray_funcs)
  151. # Create dict name -> *Api instance
  152. api_name = 'PyArray_API'
  153. multiarray_api_dict = {}
  154. for f in numpyapi_list:
  155. name = f.name
  156. index = multiarray_funcs[name][0]
  157. annotations = multiarray_funcs[name][1:]
  158. multiarray_api_dict[f.name] = FunctionApi(f.name, index, annotations,
  159. f.return_type,
  160. f.args, api_name)
  161. for name, val in global_vars.items():
  162. index, type = val
  163. multiarray_api_dict[name] = GlobalVarApi(name, index, type, api_name)
  164. for name, val in scalar_bool_values.items():
  165. index = val[0]
  166. multiarray_api_dict[name] = BoolValuesApi(name, index, api_name)
  167. for name, val in types_api.items():
  168. index = val[0]
  169. internal_type = None if len(val) == 1 else val[1]
  170. multiarray_api_dict[name] = TypeApi(
  171. name, index, 'PyTypeObject', api_name, internal_type)
  172. if len(multiarray_api_dict) != len(multiarray_api_index):
  173. keys_dict = set(multiarray_api_dict.keys())
  174. keys_index = set(multiarray_api_index.keys())
  175. raise AssertionError(
  176. "Multiarray API size mismatch - "
  177. "index has extra keys {}, dict has extra keys {}"
  178. .format(keys_index - keys_dict, keys_dict - keys_index)
  179. )
  180. extension_list = []
  181. for name, index in genapi.order_dict(multiarray_api_index):
  182. api_item = multiarray_api_dict[name]
  183. extension_list.append(api_item.define_from_array_api_string())
  184. init_list.append(api_item.array_api_define())
  185. module_list.append(api_item.internal_define())
  186. # Write to header
  187. s = h_template % ('\n'.join(module_list), '\n'.join(extension_list))
  188. genapi.write_file(header_file, s)
  189. # Write to c-code
  190. s = c_template % ',\n'.join(init_list)
  191. genapi.write_file(c_file, s)
  192. # write to documentation
  193. s = c_api_header
  194. for func in numpyapi_list:
  195. s += func.to_ReST()
  196. s += '\n\n'
  197. genapi.write_file(doc_file, s)
  198. return targets