prepare.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. from typing import Optional, List
  2. import torch
  3. from torch.backends._nnapi.serializer import _NnapiSerializer
  4. ANEURALNETWORKS_PREFER_LOW_POWER = 0
  5. ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
  6. ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
  7. class NnapiModule(torch.nn.Module):
  8. """Torch Module that wraps an NNAPI Compilation.
  9. This module handles preparing the weights, initializing the
  10. NNAPI TorchBind object, and adjusting the memory formats
  11. of all inputs and outputs.
  12. """
  13. # _nnapi.Compilation is defined
  14. comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
  15. weights: List[torch.Tensor]
  16. out_templates: List[torch.Tensor]
  17. def __init__(
  18. self,
  19. shape_compute_module: torch.nn.Module,
  20. ser_model: torch.Tensor,
  21. weights: List[torch.Tensor],
  22. inp_mem_fmts: List[int],
  23. out_mem_fmts: List[int],
  24. compilation_preference: int,
  25. relax_f32_to_f16: bool,
  26. ):
  27. super().__init__()
  28. self.shape_compute_module = shape_compute_module
  29. self.ser_model = ser_model
  30. self.weights = weights
  31. self.inp_mem_fmts = inp_mem_fmts
  32. self.out_mem_fmts = out_mem_fmts
  33. self.out_templates = []
  34. self.comp = None
  35. self.compilation_preference = compilation_preference
  36. self.relax_f32_to_f16 = relax_f32_to_f16
  37. @torch.jit.export
  38. def init(self, args: List[torch.Tensor]):
  39. assert self.comp is None
  40. self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
  41. self.weights = [w.contiguous() for w in self.weights]
  42. comp = torch.classes._nnapi.Compilation()
  43. comp.init2(self.ser_model, self.weights, self.compilation_preference, self.relax_f32_to_f16)
  44. self.comp = comp
  45. def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
  46. if self.comp is None:
  47. self.init(args)
  48. comp = self.comp
  49. assert comp is not None
  50. outs = [torch.empty_like(out) for out in self.out_templates]
  51. assert len(args) == len(self.inp_mem_fmts)
  52. fixed_args = []
  53. for idx in range(len(args)):
  54. fmt = self.inp_mem_fmts[idx]
  55. # These constants match the values in DimOrder in serializer.py
  56. # TODO: See if it's possible to use those directly.
  57. if fmt == 0:
  58. fixed_args.append(args[idx].contiguous())
  59. elif fmt == 1:
  60. fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
  61. else:
  62. raise Exception("Invalid mem_fmt")
  63. comp.run(fixed_args, outs)
  64. assert len(outs) == len(self.out_mem_fmts)
  65. for idx in range(len(self.out_templates)):
  66. fmt = self.out_mem_fmts[idx]
  67. # These constants match the values in DimOrder in serializer.py
  68. # TODO: See if it's possible to use those directly.
  69. if fmt in (0, 2):
  70. pass
  71. elif fmt == 1:
  72. outs[idx] = outs[idx].permute(0, 3, 1, 2)
  73. else:
  74. raise Exception("Invalid mem_fmt")
  75. return outs
  76. def convert_model_to_nnapi(
  77. model,
  78. inputs,
  79. serializer=None,
  80. return_shapes=None,
  81. use_int16_for_qint16=False,
  82. compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
  83. relax_f32_to_f16=False,
  84. ):
  85. (shape_compute_module, ser_model_tensor, used_weights, inp_mem_fmts, out_mem_fmts,
  86. retval_count) = process_for_nnapi(model, inputs, serializer, return_shapes, use_int16_for_qint16)
  87. nnapi_model = NnapiModule(
  88. shape_compute_module,
  89. ser_model_tensor,
  90. used_weights,
  91. inp_mem_fmts,
  92. out_mem_fmts,
  93. compilation_preference,
  94. relax_f32_to_f16
  95. )
  96. class NnapiInterfaceWrapper(torch.nn.Module):
  97. """NNAPI list-ifying and de-list-ifying wrapper.
  98. NNAPI always expects a list of inputs and provides a list of outputs.
  99. This module allows us to accept inputs as separate arguments.
  100. It returns results as either a single tensor or tuple,
  101. matching the original module.
  102. """
  103. def __init__(self, mod):
  104. super().__init__()
  105. self.mod = mod
  106. wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
  107. wrapper_model = torch.jit.script(wrapper_model_py)
  108. # TODO: Maybe make these names match the original.
  109. arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
  110. if retval_count < 0:
  111. ret_expr = "retvals[0]"
  112. else:
  113. ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
  114. wrapper_model.define(
  115. f"def forward(self, {arg_list}):\n"
  116. f" retvals = self.mod([{arg_list}])\n"
  117. f" return {ret_expr}\n"
  118. )
  119. return wrapper_model
  120. def process_for_nnapi(model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False):
  121. model = torch.jit.freeze(model)
  122. if isinstance(inputs, torch.Tensor):
  123. inputs = [inputs]
  124. serializer = serializer or _NnapiSerializer(config=None, use_int16_for_qint16=use_int16_for_qint16)
  125. (ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines,
  126. retval_count) = serializer.serialize_model(model, inputs, return_shapes)
  127. ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
  128. # We have to create a new class here every time this function is called
  129. # because module.define adds a method to the *class*, not the instance.
  130. class ShapeComputeModule(torch.nn.Module):
  131. """Code-gen-ed module for tensor shape computation
  132. module.prepare will mutate ser_model according to the computed operand
  133. shapes, based on the shapes of args. Returns a list of output templates.
  134. """
  135. pass
  136. shape_compute_module = torch.jit.script(ShapeComputeModule())
  137. real_shape_compute_lines = [
  138. "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
  139. ] + [
  140. f" {line}\n" for line in shape_compute_lines
  141. ]
  142. shape_compute_module.define("".join(real_shape_compute_lines))
  143. return (
  144. shape_compute_module,
  145. ser_model_tensor,
  146. used_weights,
  147. inp_mem_fmts,
  148. out_mem_fmts,
  149. retval_count,
  150. )