from typing import List, Optional

import torch
from torch.backends._nnapi.serializer import _NnapiSerializer

# Compilation preference values, matching the PreferenceCode constants in the
# Android NNAPI C API (NeuralNetworks.h).
ANEURALNETWORKS_PREFER_LOW_POWER = 0
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2


class NnapiModule(torch.nn.Module):
    """Torch Module that wraps an NNAPI Compilation.

    This module handles preparing the weights, initializing the
    NNAPI TorchBind object, and adjusting the memory formats
    of all inputs and outputs.
    """

    # torch.classes._nnapi.Compilation is a TorchBind class defined in C++,
    # so the type checker cannot see it.
    comp: Optional[torch.classes._nnapi.Compilation]  # type: ignore[name-defined]
    weights: List[torch.Tensor]
    out_templates: List[torch.Tensor]

    def __init__(
        self,
        shape_compute_module: torch.nn.Module,
        ser_model: torch.Tensor,
        weights: List[torch.Tensor],
        inp_mem_fmts: List[int],
        out_mem_fmts: List[int],
        compilation_preference: int,
        relax_f32_to_f16: bool,
    ):
        super().__init__()
        self.shape_compute_module = shape_compute_module
        self.ser_model = ser_model
        self.weights = weights
        self.inp_mem_fmts = inp_mem_fmts
        self.out_mem_fmts = out_mem_fmts
        self.out_templates = []
        self.comp = None
        self.compilation_preference = compilation_preference
        self.relax_f32_to_f16 = relax_f32_to_f16

    @torch.jit.export
    def init(self, args: List[torch.Tensor]):
        assert self.comp is None
        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args)  # type: ignore[operator]
        self.weights = [w.contiguous() for w in self.weights]
        comp = torch.classes._nnapi.Compilation()
        comp.init2(
            self.ser_model,
            self.weights,
            self.compilation_preference,
            self.relax_f32_to_f16,
        )
        self.comp = comp

    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
        if self.comp is None:
            self.init(args)
        comp = self.comp
        assert comp is not None
        outs = [torch.empty_like(out) for out in self.out_templates]

        assert len(args) == len(self.inp_mem_fmts)
        fixed_args = []
        for idx in range(len(args)):
            fmt = self.inp_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
            if fmt == 0:
                fixed_args.append(args[idx].contiguous())
            elif fmt == 1:
                fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
            else:
                raise ValueError("Invalid mem_fmt")
        comp.run(fixed_args, outs)

        assert len(outs) == len(self.out_mem_fmts)
        for idx in range(len(self.out_templates)):
            fmt = self.out_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
            if fmt in (0, 2):
                pass
            elif fmt == 1:
                outs[idx] = outs[idx].permute(0, 3, 1, 2)
            else:
                raise ValueError("Invalid mem_fmt")
        return outs
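

# A minimal sketch (not part of the NNAPI bindings) illustrating the memory
# format fixup that NnapiModule.forward performs: DimOrder 0 in serializer.py
# means presume-contiguous (NCHW for 4-D tensors) and DimOrder 1 means
# channels-last (NHWC), so 4-D inputs are permuted NCHW -> NHWC before
# Compilation.run and outputs NHWC -> NCHW afterwards. The function name is
# illustrative only.
def _example_channels_last_roundtrip() -> None:
    x = torch.randn(1, 3, 4, 5)  # NCHW input
    nhwc = x.permute(0, 2, 3, 1).contiguous()  # what forward() does for fmt == 1
    assert nhwc.shape == (1, 4, 5, 3)
    restored = nhwc.permute(0, 3, 1, 2)  # inverse permutation applied to outputs
    assert torch.equal(restored, x)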


def convert_model_to_nnapi(
    model,
    inputs,
    serializer=None,
    return_shapes=None,
    use_int16_for_qint16=False,
    compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
    relax_f32_to_f16=False,
):
    # Accept a single tensor as well as a list of tensors, mirroring
    # process_for_nnapi; otherwise len(inputs) below would read the first
    # dimension of the tensor instead of the argument count.
    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]

    (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    ) = process_for_nnapi(model, inputs, serializer, return_shapes, use_int16_for_qint16)

    nnapi_model = NnapiModule(
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        compilation_preference,
        relax_f32_to_f16,
    )

    class NnapiInterfaceWrapper(torch.nn.Module):
        """NNAPI list-ifying and de-list-ifying wrapper.

        NNAPI always expects a list of inputs and provides a list of outputs.
        This module allows us to accept inputs as separate arguments.
        It returns results as either a single tensor or tuple,
        matching the original module.
        """

        def __init__(self, mod):
            super().__init__()
            self.mod = mod

    wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
    wrapper_model = torch.jit.script(wrapper_model_py)
    # TODO: Maybe make these names match the original.
    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
    # A negative retval_count signals that the original module returned a
    # single tensor rather than a tuple.
    if retval_count < 0:
        ret_expr = "retvals[0]"
    else:
        ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
    wrapper_model.define(
        f"def forward(self, {arg_list}):\n"
        f"    retvals = self.mod([{arg_list}])\n"
        f"    return {ret_expr}\n"
    )
    return wrapper_model
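

# A minimal end-to-end usage sketch (the function and file names here are
# illustrative, not part of this module's API), assuming a model whose
# operators the NNAPI serializer supports:
def _example_convert(model: torch.nn.Module, example_input: torch.Tensor):
    traced = torch.jit.trace(model.eval(), example_input)
    nnapi_model = convert_model_to_nnapi(
        traced,
        example_input,
        compilation_preference=ANEURALNETWORKS_PREFER_LOW_POWER,
    )
    # convert_model_to_nnapi returns a ScriptModule, so it can be saved
    # and loaded with PyTorch Mobile on Android.
    nnapi_model.save("model_nnapi.pt")
    return nnapi_model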


def process_for_nnapi(model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False):
    model = torch.jit.freeze(model)

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]

    serializer = serializer or _NnapiSerializer(config=None, use_int16_for_qint16=use_int16_for_qint16)
    (
        ser_model,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        shape_compute_lines,
        retval_count,
    ) = serializer.serialize_model(model, inputs, return_shapes)
    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)

    # We have to create a new class here every time this function is called
    # because module.define adds a method to the *class*, not the instance.
    class ShapeComputeModule(torch.nn.Module):
        """Code-gen-ed module for tensor shape computation.

        module.prepare will mutate ser_model according to the computed operand
        shapes, based on the shapes of args. Returns a list of output templates.
        """

    shape_compute_module = torch.jit.script(ShapeComputeModule())
    real_shape_compute_lines = [
        "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
    ] + [
        f"    {line}\n" for line in shape_compute_lines
    ]
    shape_compute_module.define("".join(real_shape_compute_lines))

    return (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    )
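

# For reference, the prepare method defined above via module.define takes
# roughly this shape (a hand-written illustration; the real body comes from
# the serializer's shape_compute_lines and varies per model):
#
#     def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:
#         # ...overwrite the operand-dimension entries of ser_model in place...
#         return [torch.zeros(...), ...]  # one empty template per model output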