dispatcher.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import itertools
  2. from typing import List, Sequence, Union
  3. from torchgen.api import cpp
  4. from torchgen.api.types import ArgName, Binding, CType, NamedCType
  5. from torchgen.model import (
  6. Argument,
  7. FunctionSchema,
  8. Return,
  9. SelfArgument,
  10. TensorOptionsArguments,
  11. Type,
  12. )
  13. from torchgen.utils import assert_never, concatMap
  14. # This file describes the translation of JIT schema to the dispatcher
  15. # API, the *unboxed* calling convention by which invocations through
  16. # the dispatcher are made. Historically, the dispatcher API matched
  17. # the C++ API, but with the establishment of the boxed API, we've
  18. # made changes to the dispatcher API so that the unboxed API
  19. # better aligns with the boxed API. The dispatcher API hooks heavily
  20. # into our template-based boxing/unboxing machinery, so changes
  21. # to this convention will usually need template updates too.
  22. #
  23. # Prominent characteristics of the dispatcher API:
  24. #
  25. # - dtype, layout, device and pin_memory are represented as separate
  26. # arguments.
  27. #
  28. def name(func: FunctionSchema) -> str:
  29. return cpp.name(func)
  30. def argumenttype_type(
  31. t: Type,
  32. *,
  33. mutable: bool,
  34. binds: ArgName,
  35. remove_non_owning_ref_types: bool = False,
  36. symint: bool = True,
  37. ) -> NamedCType:
  38. # This is a faux amis. If it makes sense in the future to add
  39. # more special cases here, or invert things so cpp.argument_type
  40. # calls this, or just completely inline the function, please do
  41. # it.
  42. return cpp.argumenttype_type(
  43. t,
  44. mutable=mutable,
  45. binds=binds,
  46. symint=symint,
  47. remove_non_owning_ref_types=remove_non_owning_ref_types,
  48. )
  49. def argument_type(
  50. a: Argument,
  51. *,
  52. binds: ArgName,
  53. remove_non_owning_ref_types: bool = False,
  54. symint: bool = True,
  55. ) -> NamedCType:
  56. return argumenttype_type(
  57. a.type,
  58. mutable=a.is_write,
  59. binds=binds,
  60. remove_non_owning_ref_types=remove_non_owning_ref_types,
  61. symint=symint,
  62. )
  63. def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
  64. # At present, there is no difference. But there could be!
  65. return cpp.returns_type(rs, symint=symint)
  66. def jit_arguments(func: FunctionSchema) -> List[Argument]:
  67. def to_argument(
  68. a: Union[Argument, TensorOptionsArguments, SelfArgument]
  69. ) -> List[Argument]:
  70. if isinstance(a, Argument):
  71. return [a]
  72. elif isinstance(a, SelfArgument):
  73. return [a.argument]
  74. elif isinstance(a, TensorOptionsArguments):
  75. return [a.dtype, a.layout, a.device, a.pin_memory]
  76. else:
  77. assert_never(a)
  78. return list(
  79. concatMap(
  80. to_argument,
  81. itertools.chain(
  82. func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
  83. ),
  84. )
  85. )
  86. def argument(
  87. a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
  88. ) -> Binding:
  89. return Binding(
  90. nctype=argument_type(
  91. a,
  92. binds=a.name,
  93. remove_non_owning_ref_types=remove_non_owning_ref_types,
  94. symint=symint,
  95. ),
  96. name=a.name,
  97. argument=a,
  98. )
  99. def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
  100. return [argument(a, symint=symint) for a in jit_arguments(func)]