native.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from typing import List, Optional, Sequence, Union
  2. from torchgen import local
  3. from torchgen.api import cpp
  4. from torchgen.api.types import (
  5. ArgName,
  6. BaseCType,
  7. Binding,
  8. boolT,
  9. ConstRefCType,
  10. CType,
  11. deviceT,
  12. layoutT,
  13. ListCType,
  14. MutRefCType,
  15. NamedCType,
  16. OptionalCType,
  17. scalarT,
  18. scalarTypeT,
  19. tensorT,
  20. )
  21. from torchgen.model import (
  22. Argument,
  23. FunctionSchema,
  24. Return,
  25. SelfArgument,
  26. TensorOptionsArguments,
  27. Type,
  28. )
  29. from torchgen.utils import assert_never
  30. # This file describes the translation of JIT schema to the native functions API.
  31. # This looks a lot like the C++ API (which makes historical sense, because the
  32. # idea was you wrote native functions to implement functions in the C++ API),
  33. # but over time we have evolved the C++ API without actually changing our
  34. # native:: kernels. The intention is to make native API and dispatcher API
  35. # line up as closely as possible, since this results in the least overhead
  36. # (no translation is needed from dispatcher API to native API).
  37. #
  38. # NB: this is symint aware, you will get the non-SymInt variant for some
  39. # dispatch entries and SymInt for others.
  40. def name(func: FunctionSchema) -> str:
  41. name = str(func.name.name)
  42. # TODO: delete this!
  43. if func.is_out_fn():
  44. name += "_out"
  45. if func.name.overload_name:
  46. name += f"_{func.name.overload_name}"
  47. return name
  48. def argumenttype_type(
  49. t: Type, *, mutable: bool, binds: ArgName, symint: bool
  50. ) -> NamedCType:
  51. if str(t) == "Tensor?":
  52. tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
  53. if mutable and not local.use_const_ref_for_mutable_tensors():
  54. return NamedCType(binds, MutRefCType(tensor_type))
  55. else:
  56. return NamedCType(binds, ConstRefCType(tensor_type))
  57. elif str(t) == "Tensor?[]":
  58. return NamedCType(
  59. binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
  60. )
  61. elif str(t) == "Scalar":
  62. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  63. elif str(t) == "Scalar?":
  64. return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
  65. return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
  66. def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
  67. return cpp.returns_type(rs, symint=symint)
  68. def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
  69. return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
  70. def argument(
  71. a: Union[Argument, SelfArgument, TensorOptionsArguments],
  72. *,
  73. is_out: bool,
  74. symint: bool,
  75. ) -> List[Binding]:
  76. # Ideally, we NEVER default native functions. However, there are a number
  77. # of functions that call native:: directly and rely on the defaulting
  78. # existing. So for BC, we generate defaults for non-out variants (but not
  79. # for out variants, where it is impossible to generate an appropriate
  80. # default)
  81. should_default = not is_out
  82. if isinstance(a, Argument):
  83. default: Optional[str] = None
  84. if should_default and a.default is not None:
  85. default = cpp.default_expr(a.default, a.type, symint=symint)
  86. return [
  87. Binding(
  88. nctype=argument_type(a, binds=a.name, symint=symint),
  89. name=a.name,
  90. default=default,
  91. argument=a,
  92. )
  93. ]
  94. elif isinstance(a, SelfArgument):
  95. # Erase SelfArgument from the distinction
  96. return argument(a.argument, is_out=is_out, symint=symint)
  97. elif isinstance(a, TensorOptionsArguments):
  98. default = None
  99. if should_default:
  100. default = "{}"
  101. # TODO: Not sure why the arguments assigned here are for
  102. # TensorOptionsArguments and not the constituent pieces. It seems
  103. # to matter
  104. return [
  105. Binding(
  106. nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
  107. name="dtype",
  108. default=default,
  109. argument=a,
  110. ),
  111. Binding(
  112. nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
  113. name="layout",
  114. default=default,
  115. argument=a,
  116. ),
  117. Binding(
  118. nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
  119. name="device",
  120. default=default,
  121. argument=a,
  122. ),
  123. Binding(
  124. nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
  125. name="pin_memory",
  126. default=default,
  127. argument=a,
  128. ),
  129. ]
  130. else:
  131. assert_never(a)
  132. def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
  133. args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
  134. args.extend(func.arguments.non_out)
  135. args.extend(func.arguments.out)
  136. return [
  137. r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
  138. ]