import sys
import enum
import struct
import array
import logging
import functools
from typing import (
    Tuple,
    NamedTuple,
    List,
    Optional,
)

import torch

# TODO: Add type annotations
# TODO: Check tensor types for ops

LOG = logging.getLogger("nnapi_serialize")


class NNAPI_OperandCode:
    FLOAT32 = 0
    INT32 = 1
    UINT32 = 2
    TENSOR_FLOAT32 = 3
    TENSOR_INT32 = 4
    TENSOR_QUANT8_ASYMM = 5
    BOOL = 6
    TENSOR_QUANT16_SYMM = 7
    TENSOR_FLOAT16 = 8
    TENSOR_BOOL8 = 9
    FLOAT16 = 10
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
    TENSOR_QUANT16_ASYMM = 12

class NNAPI_OperationCode:
    ADD = 0
    AVERAGE_POOL_2D = 1
    CONCATENATION = 2
    CONV_2D = 3
    DEPTHWISE_CONV_2D = 4
    DEPTH_TO_SPACE = 5
    DEQUANTIZE = 6
    EMBEDDING_LOOKUP = 7
    FLOOR = 8
    FULLY_CONNECTED = 9
    HASHTABLE_LOOKUP = 10
    L2_NORMALIZATION = 11
    L2_POOL_2D = 12
    LOCAL_RESPONSE_NORMALIZATION = 13
    LOGISTIC = 14
    LSH_PROJECTION = 15
    LSTM = 16
    MAX_POOL_2D = 17
    MUL = 18
    RELU = 19
    RELU1 = 20
    RELU6 = 21
    RESHAPE = 22
    RESIZE_BILINEAR = 23
    RNN = 24
    SOFTMAX = 25
    SPACE_TO_DEPTH = 26
    SVDF = 27
    TANH = 28
    BATCH_TO_SPACE_ND = 29
    DIV = 30
    MEAN = 31
    PAD = 32
    SPACE_TO_BATCH_ND = 33
    SQUEEZE = 34
    STRIDED_SLICE = 35
    SUB = 36
    TRANSPOSE = 37
    ABS = 38
    ARGMAX = 39
    ARGMIN = 40
    AXIS_ALIGNED_BBOX_TRANSFORM = 41
    BIDIRECTIONAL_SEQUENCE_LSTM = 42
    BIDIRECTIONAL_SEQUENCE_RNN = 43
    BOX_WITH_NMS_LIMIT = 44
    CAST = 45
    CHANNEL_SHUFFLE = 46
    DETECTION_POSTPROCESSING = 47
    EQUAL = 48
    EXP = 49
    EXPAND_DIMS = 50
    GATHER = 51
    GENERATE_PROPOSALS = 52
    GREATER = 53
    GREATER_EQUAL = 54
    GROUPED_CONV_2D = 55
    HEATMAP_MAX_KEYPOINT = 56
    INSTANCE_NORMALIZATION = 57
    LESS = 58
    LESS_EQUAL = 59
    LOG = 60
    LOGICAL_AND = 61
    LOGICAL_NOT = 62
    LOGICAL_OR = 63
    LOG_SOFTMAX = 64
    MAXIMUM = 65
    MINIMUM = 66
    NEG = 67
    NOT_EQUAL = 68
    PAD_V2 = 69
    POW = 70
    PRELU = 71
    QUANTIZE = 72
    QUANTIZED_16BIT_LSTM = 73
    RANDOM_MULTINOMIAL = 74
    REDUCE_ALL = 75
    REDUCE_ANY = 76
    REDUCE_MAX = 77
    REDUCE_MIN = 78
    REDUCE_PROD = 79
    REDUCE_SUM = 80
    ROI_ALIGN = 81
    ROI_POOLING = 82
    RSQRT = 83
    SELECT = 84
    SIN = 85
    SLICE = 86
    SPLIT = 87
    SQRT = 88
    TILE = 89
    TOPK_V2 = 90
    TRANSPOSE_CONV_2D = 91
    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
    UNIDIRECTIONAL_SEQUENCE_RNN = 93
    RESIZE_NEAREST_NEIGHBOR = 94

class NNAPI_FuseCode:
    FUSED_NONE = 0
    FUSED_RELU = 1
    FUSED_RELU1 = 2
    FUSED_RELU6 = 3


class OperandValueSourceType:
    IMMEDIATE = 0
    NUMBERED_BUFFER = 2
    NUMBERED_MEMORY = 3


# Scalar types that appear explicitly in models.
# These must be kept in sync with
# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
# TODO: Expose these directly to Python to avoid maintaining this list.
class TorchScalarTypes(enum.Enum):
    QUINT8 = 13


def approx_equal(lhs, rhs, tolerance=1e-6):
    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)
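
# For example: with the default tolerance, approx_equal(1.0, 1.0000001) is True
# (difference 1e-7 <= 1e-6 * 1.0) while approx_equal(1.0, 1.01) is False, since
# the tolerance is interpreted relative to the smaller operand.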

def tensor_size(op_type, dims):
    ITEM_SIZES = {
        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
        NNAPI_OperandCode.TENSOR_INT32: 4,
        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
    }
    size = ITEM_SIZES[op_type]
    for d in dims:
        size *= d
    return size
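
# For example: a TENSOR_FLOAT32 operand with dims (1, 3, 2, 2) occupies
# 4 * 1 * 3 * 2 * 2 = 48 bytes.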

def change_element(tup, index, value):
    ls = list(tup)
    ls[index] = value
    return tuple(ls)
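
# For example: change_element((1, 3, 224, 224), 1, -1) returns (1, -1, 224, 224).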

class ConvPoolArgs2d(NamedTuple):
    """Configuration arguments for a 2d convolution or pooling operation."""

    kernel_h: int
    kernel_w: int
    stride_h: int
    stride_w: int
    pad_t: int
    pad_b: int
    pad_l: int
    pad_r: int
    dilation_h: int
    dilation_w: int
    group: int


class DimOrder(enum.Enum):
    PRESUMED_CONTIGUOUS = 0
    CHANNELS_LAST = 1
    SCALAR_OR_VECTOR = 2
    UNKNOWN_CONSTANT = 999

class Operand(NamedTuple):
    """Representation of an NNAPI operand."""

    # NNAPI operand type.  One of NNAPI_OperandCode.
    # TODO: Make this an enum.
    op_type: int

    # This is always the PyTorch shape, which is NCHW for feature maps.
    # The actual NNAPI operand might have a transposed shape.
    # We use 0 for load-time dynamic shapes and -1 for runtime dynamic shapes.
    shape: Tuple[int, ...]

    # Specifies how the shape of the operand that we define in NNAPI
    # relates to the shape we track above.
    # - PRESUMED_CONTIGUOUS: The physical NNAPI operand will exactly match
    #   the shape of the PyTorch tensor.
    # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
    #   the NNAPI operand will be represented explicitly as NHWC.
    dim_order: DimOrder

    # Quantization params
    scale: float
    zero_point: int

    def use_nchw(self):
        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
            return True
        if self.dim_order is DimOrder.CHANNELS_LAST:
            return False
        raise Exception("Unknown dim order")
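
# For example, a float32 NCHW feature map of shape (1, 3, 224, 224) that should be
# laid out as NHWC on the NNAPI side could be described as:
#   Operand(op_type=NNAPI_OperandCode.TENSOR_FLOAT32, shape=(1, 3, 224, 224),
#           dim_order=DimOrder.CHANNELS_LAST, scale=0.0, zero_point=0)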

def broadcast_shapes(shape1, shape2):
    assert len(shape1) > 0
    assert len(shape2) > 0
    s1 = list(shape1)
    s2 = list(shape2)
    # TODO: Support non-equal-rank broadcast where semantics match.
    # This can be tricky for NHWC tensors because dimension orders
    # don't match between PT and NNAPI, even though semantics match.
    if len(s1) > len(s2):
        # s2 = [1] * (len(s1) - len(s2)) + s2
        raise Exception("Non-equal-rank broadcast is not supported yet.")
    if len(s2) > len(s1):
        # s1 = [1] * (len(s2) - len(s1)) + s1
        raise Exception("Non-equal-rank broadcast is not supported yet.")
    ret = []
    for d1, d2 in zip(s1, s2):
        if d1 == 1:
            ret.append(d2)
        elif d2 == 1:
            ret.append(d1)
        elif d1 == d2:
            ret.append(d1)
        else:
            raise Exception("Cannot broadcast shapes: {} and {}".format(shape1, shape2))
    return tuple(ret)
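
# For example: broadcast_shapes((1, 4, 1, 1), (1, 4, 3, 3)) returns (1, 4, 3, 3),
# while mixing ranks such as (4,) and (1, 4) raises, since non-equal-rank
# broadcast is not supported yet.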

def get_conv_pool_shape(image_shape, args, out_ch, transpose):
    batch, in_c, in_h, in_w = image_shape

    # TODO: Handle dilation
    if args.dilation_h != 1 or args.dilation_w != 1:
        raise Exception("Dilation not supported yet.")

    if transpose:
        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
    else:
        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1

    # Handle variable-sized tensors.
    if in_h == 0:
        out_h = 0
    if in_w == 0:
        out_w = 0

    out_shape = (batch, out_ch, out_h, out_w)
    return out_shape
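
# Worked example: a 1x3x224x224 input with a 3x3 kernel, stride 2, and padding 1
# on every side (non-transposed) gives
#   out_h = (224 - 3 + 1 + 1) // 2 + 1 = 112
# so the returned shape is (1, out_ch, 112, 112).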

def fix_shape(shape, dim_order):
    # Return the actual shape that an operand should have in NNAPI,
    # given a PyTorch shape and dimension order.  This is where we
    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
        return shape
    if dim_order is DimOrder.CHANNELS_LAST:
        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
    if dim_order is DimOrder.SCALAR_OR_VECTOR:
        assert len(shape) == 0 or len(shape) == 1
        return shape
    if dim_order is DimOrder.UNKNOWN_CONSTANT:
        # XXX think this through
        return shape
    raise Exception(f"Bad dim_order: {dim_order!r}.")
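
# For example: fix_shape((1, 3, 224, 224), DimOrder.CHANNELS_LAST) returns
# (1, 224, 224, 3), while PRESUMED_CONTIGUOUS leaves the shape unchanged.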

def reverse_map_dim(dim_order, d):
    # Return the original PyTorch dimension position for a given dimension.
    # d should be the dimension that NNAPI will see.
    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
    # reverse_map_dim(CHANNELS_LAST, 3) == 1
    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
        return d
    assert dim_order is DimOrder.CHANNELS_LAST
    return [0, 2, 3, 1][d]


def flex_name(op_id, dim):
    # Return the local variable name for the computed flexible size
    # for a given op and dimension.
    return f"s_{op_id}_{dim}"

class _NnapiSerializer:
    def __init__(self, config, use_int16_for_qint16=False):
        self.operands = []
        self.values = []
        self.operations = []
        self.value_data = []
        self.operation_args = []
        self.inputs = []
        self.outputs = []
        self.flexible_shape_computation_lines = []

        self.modules = {}
        self.constants = {}
        self.tensor_sequences = {}
        self.jitval_operand_map = {}
        self.cached_immediates = {}
        self.used_weights = []
        self.weight_offset = 0
        self.use_int16_for_qint16 = use_int16_for_qint16

        if config is None:
            config = {}

    def get_next_operand_id(self):
        return len(self.operands)

    # Add a tensor operand corresponding to a JIT Value.
    # Returns the NNAPI operand ID.  Can be looked up later with
    # get_tensor_operand_by_jitval.
    def add_tensor_operand(self, jitval, oper):
        assert isinstance(oper, Operand)
        if jitval in self.jitval_operand_map:
            raise Exception("Duplicate tensor: %r" % jitval)

        operand_id = self.get_next_operand_id()
        self.operands.append(oper)
        self.jitval_operand_map[jitval] = operand_id
        return operand_id

    # Add a tensor operand that does not correspond to a JIT Value.
    # Useful for cases where multiple NNAPI operands are required
    # to implement one JIT IR node.  Returns the NNAPI operand ID.
    def add_anonymous_tensor_operand(self, oper):
        assert isinstance(oper, Operand)
        operand_id = self.get_next_operand_id()
        self.operands.append(oper)
        return operand_id

    def torch_tensor_to_operand(self, tensor, dim_order):
        dtype = str(tensor.dtype).replace("torch.", "")
        scale = 0.0
        zero_point = 0
        if dtype == "float32":
            op_type = NNAPI_OperandCode.TENSOR_FLOAT32
        elif dtype == "int32":
            op_type = NNAPI_OperandCode.TENSOR_INT32
        elif dtype == "quint8":
            op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
            scale = tensor.q_scale()
            zero_point = tensor.q_zero_point()
        elif dtype == "qint32":
            op_type = NNAPI_OperandCode.TENSOR_INT32
            scale = tensor.q_scale()
            zero_point = tensor.q_zero_point()
            assert zero_point == 0
        elif dtype == "int16":
            if self.use_int16_for_qint16:
                nnapi_dtype = getattr(tensor, "nnapi_dtype", None)
                op_codes = (NNAPI_OperandCode.TENSOR_QUANT16_SYMM, NNAPI_OperandCode.TENSOR_QUANT16_ASYMM)
                if nnapi_dtype in op_codes:
                    op_type = nnapi_dtype
                    scale = tensor.nnapi_scale
                    zero_point = tensor.nnapi_zero_point
                else:
                    raise Exception(f"`nnapi_dtype` needs to be one of {op_codes} for `int16`")
            else:
                raise Exception(
                    "`int16` isn't supported. If you're trying to represent NNAPI"
                    " qint16 with PyTorch int16, set `use_int16_for_qint16 = True`")
        else:
            raise Exception(f"Can't handle input with dtype '{tensor.dtype}'")
        return Operand(
            shape=tuple(tensor.shape),
            op_type=op_type,
            dim_order=dim_order,
            scale=scale,
            zero_point=zero_point,
        )

    def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
        dim_order = (
            DimOrder.CHANNELS_LAST if getattr(tensor, "nnapi_nhwc", False)
            else DimOrder.PRESUMED_CONTIGUOUS)
        toper = self.torch_tensor_to_operand(tensor, dim_order)
        operand_id = self.add_tensor_operand(jitval, toper)
        self.inputs.append(operand_id)
        for dim, size in enumerate(tensor.shape):
            if size == 0:
                self.compute_operand_shape(operand_id, dim, f"args[{arg_idx}].shape[{dim}]")
        return operand_id

    def add_tensor_operand_for_weight(self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT):
        toper = self.torch_tensor_to_operand(tensor, dim_order)
        operand_id = len(self.operands)
        self.operands.append(toper)
        tsize = tensor_size(toper.op_type, toper.shape)
        psize = ((tsize - 1) | 0x3) + 1
        self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
        buf_num = len(self.used_weights)
        offset = 0
        self.value_data.append(struct.pack(
            "iii",
            buf_num,
            offset,
            tsize))
        # For an NHWC NNAPI op, lay out the data in the same dim order by permuting the torch tensor.
        if dim_order == DimOrder.CHANNELS_LAST:
            tensor = tensor.permute(0, 2, 3, 1)
        self.used_weights.append(tensor)
        return operand_id

    def add_immediate_operand(self, code, value, dims):
        assert isinstance(dims, tuple)
        cache_key = (code, value)
        if cache_key not in self.cached_immediates:
            operand_id = len(self.operands)
            self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
            self.values.append((operand_id, OperandValueSourceType.IMMEDIATE))
            self.value_data.append(value)
            self.cached_immediates[cache_key] = operand_id
        return self.cached_immediates[cache_key]

    def add_immediate_int_scalar(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.INT32,
            struct.pack("i", value),
            ())

    def add_immediate_float_scalar(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.FLOAT32,
            struct.pack("f", value),
            ())

    def add_immediate_bool_scalar(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.BOOL,
            b"\x01" if value else b"\x00",
            ())

    def add_immediate_int_vector(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.TENSOR_INT32,
            array.array("i", value).tobytes(),
            (len(value),))
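
    # Note: immediates are cached by (operand code, packed bytes), so calling
    # add_immediate_int_scalar(1) twice returns the same operand ID, and
    # add_immediate_int_vector([0, 2, 3, 1]) stores the 16 packed bytes of that
    # permutation exactly once.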

    def has_operand_for_jitval(self, jitval):
        return jitval in self.jitval_operand_map

    def get_tensor_operand_by_jitval(self, jitval):
        operand_id = self.jitval_operand_map[jitval]
        return (operand_id, self.operands[operand_id])

    def get_tensor_operand_by_jitval_fixed_size(self, jitval):
        op_id, oper = self.get_tensor_operand_by_jitval(jitval)
        for s in oper.shape:
            if s == 0:
                # TODO: Improve this error message, possibly after converting
                # many callsites to support flexible size.
                raise Exception("Flexible size is not supported for this operand.")
            if s < 0:
                # Runtime flexible shape
                LOG.warning(f"Operand {oper} has runtime flex shape")
        return op_id, oper

    def get_tensor_operand_or_constant(self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS):
        operand_id = self.jitval_operand_map.get(jitval)
        if operand_id is None:
            _, value = self.get_constant_value(jitval, "TensorType")
            operand_id = self.add_tensor_operand_for_weight(value, dim_order)
        return (operand_id, self.operands[operand_id])

    def get_tensor_operand_for_weight(self, jitval):
        _, value = self.get_constant_value(jitval, "TensorType")
        operand_id = self.add_tensor_operand_for_weight(value)
        return (operand_id, self.operands[operand_id])

    def add_operation(self, opcode, inputs, outputs):
        self.operations.append((opcode, len(inputs), len(outputs)))
        self.operation_args.extend(inputs + outputs)

    def add_tensor_sequence(self, jitval, values):
        assert jitval not in self.tensor_sequences
        self.tensor_sequences[jitval] = values

    def add_constant_value(self, jitval, ctype, value):
        assert jitval not in self.constants
        self.constants[jitval] = (ctype, value)

    def get_constant_value(self, jitval, typekind=None):
        record = self.constants.get(jitval)
        if record is None:
            raise Exception(f"Could not find constant value for '{jitval!r}'.")
        ctype, _ = record
        if typekind is not None and ctype.kind() != typekind:
            raise Exception(
                f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'")
        return record

    def operand_to_template_torchscript(self, op_id, oper, shape=None):
        """Return a TorchScript expression to build a template for a given operand."""
        if shape is None:
            shape = oper.shape
        else:
            assert len(shape) == len(oper.shape)

        shape_parts = ["("]
        for d, s in enumerate(shape):
            if s > 0:
                # Fixed shape dimension: just add the value.
                shape_parts.append(str(s))
            elif s == 0:
                # Load-time flexible shape dimension: it should have been computed in a variable.
                shape_parts.append(flex_name(op_id, d))
            elif s == -1:
                # Runtime flexible shape
                shape_parts.append('0')
            else:
                raise Exception("Unknown dim value, dimensions should be >= -1")
            shape_parts.append(",")
        shape_parts.append(")")
        shape_code = "".join(shape_parts)

        if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
            return f"torch.zeros({shape_code}, dtype=torch.float32)"
        elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32:
            return f"torch.zeros({shape_code}, dtype=torch.int32)"
        elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
            return (
                f"torch.quantize_per_tensor("
                f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
                f".expand({shape_code}).contiguous()"
            )
        elif oper.op_type in (NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, NNAPI_OperandCode.TENSOR_QUANT16_SYMM):
            if self.use_int16_for_qint16:
                return f"torch.zeros({shape_code}, dtype=torch.int16)"
            else:
                raise Exception(
                    "`int16` isn't supported. If you're trying to represent NNAPI"
                    " qint16 with PyTorch int16, set `use_int16_for_qint16 = True`")

        raise Exception(f"Unsupported output operand type: {oper.op_type}")
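
    # For example, a TENSOR_FLOAT32 operand with a fixed shape of (1, 1000)
    # produces the TorchScript expression
    #   "torch.zeros((1,1000,), dtype=torch.float32)"
    # while flexible dimensions are substituted with the flex_name() variables.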

    def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
        self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim))

    def compute_operand_shape(self, op_id, dim, expr):
        self.flexible_shape_computation_lines.append(f"{flex_name(op_id, dim)} = {expr}")

    def transpose_to_nhwc(self, in_id, oper):
        if oper.shape[2:] != (1, 1):
            raise Exception("Automatic transpose only supported for H,W == 1,1")

        out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])

        outputs = [None] * 1
        outputs[0] = self.add_anonymous_tensor_operand(out_oper)

        self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)

        return outputs[0], out_oper

    # Transpose inputs as necessary to allow broadcasting.
    def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
        if in0_oper.dim_order == in1_oper.dim_order:
            return in0_id, in0_oper, in1_id, in1_oper

        # Assume NHWC is preferred if there is a mismatch.
        orders = (in0_oper.dim_order, in1_oper.dim_order)
        if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
            return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
        if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
            return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)

        raise Exception(
            "Automatic transpose not supported for dim_orders: %r, %r" %
            (in0_oper.dim_order, in1_oper.dim_order))

    def get_size_arg(self, jitval):
        ctype, value = self.get_constant_value(jitval)
        if ctype.kind() == "ListType":
            assert ctype.getElementType().kind() == "IntType"
            return value
        raise Exception(f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'")

    def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
        pc = [i.item() for i in packed_config]
        assert pc[0] == 2
        strides = [pc[1], pc[2]]
        paddings = [pc[3], pc[4]]
        dilations = [pc[5], pc[6]]
        output_padding = [pc[7], pc[8]]
        group_num = pc[9]
        assert len(pc) == 11
        assert output_padding == [0, 0]
        return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num)

    def get_conv_pool_args_2d_from_jit(self, kernel_size, stride, padding, dilation=None, group=None):
        strides = self.get_size_arg(stride)
        paddings = self.get_size_arg(padding)
        if dilation is None:
            dilations = [1, 1]
        else:
            dilations = self.get_size_arg(dilation)
        if group is not None:
            _, group_num = self.get_constant_value(group, "IntType")
        else:
            group_num = None
        return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num)

    def get_conv_pool_args_2d_common(self, kernel_size, strides, paddings, dilations, group_num):
        kernels = list(kernel_size)

        assert len(kernels) == 2
        assert len(strides) == 2
        assert len(paddings) == 2
        assert len(dilations) == 2

        # NNAPI uses 4 values for padding.
        ph, pw = paddings
        real_paddings = [ph, ph, pw, pw]

        return ConvPoolArgs2d(*(kernels + strides + real_paddings + dilations + [group_num]))
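
    # For example, kernel_size=[3, 3], strides=[2, 2], paddings=[1, 1],
    # dilations=[1, 1], group_num=1: the symmetric PyTorch padding [1, 1] expands
    # to NNAPI's four values [pad_t, pad_b, pad_l, pad_r] = [1, 1, 1, 1], giving
    # ConvPoolArgs2d(3, 3, 2, 2, 1, 1, 1, 1, 1, 1, 1).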

    def serialize_model(self, model, inputs, return_shapes=None):
        self.add_immediate_bool_scalar(False)
        self.add_immediate_bool_scalar(True)

        inp_dim_orders = []
        out_dim_orders = []

        self_jitval = next(model.graph.inputs())
        self.add_constant_value(self_jitval, self_jitval.type(), model)

        for arg_idx, (input_value, input_tensor) in enumerate(zip(list(model.graph.inputs())[1:], inputs)):
            op_id = self.add_tensor_operand_for_input(arg_idx, input_value, input_tensor)
            inp_dim_orders.append(self.operands[op_id].dim_order.value)

        for idx, node in enumerate(model.graph.nodes()):
            LOG.debug("Processing node #%d: %r", idx, node)
            self.add_node(node)

        retn = model.graph.return_node()
        assert retn.inputsSize() == 1
        assert retn.outputsSize() == 0
        retn_input = retn.inputsAt(0)
        template_return_lines = ["return ["]
        if retn_input.type().kind() == "TensorType":
            return_values = [retn_input]
            retval_count = -1
        elif retn_input.type().kind() == "TupleType":
            return_values = self.tensor_sequences[retn_input]
            retval_count = len(return_values)
        else:
            raise Exception(f"Unsupported return type: {retn_input.type()}")

        if return_shapes is not None:
            assert len(return_shapes) == len(return_values)
        for i, v in enumerate(return_values):
            op_id = self.jitval_operand_map[v]
            self.outputs.append(op_id)
            out_dim_orders.append(self.operands[op_id].dim_order.value)

            shape = return_shapes[i] if return_shapes else None
            template_return_lines.append(
                self.operand_to_template_torchscript(
                    op_id, self.operands[op_id], shape) + ","
            )
        template_return_lines.append("]")

        model = []

        version = 1
        header = struct.pack(
            "iiiiii",
            version,
            len(self.operands),
            len(self.values),
            len(self.operations),
            len(self.inputs),
            len(self.outputs),
        )
        model.append(header)

        serialized_values, serialized_value_data = self.serialize_values()

        # One record per operand: type, rank, scale, zero point.
        model.extend(struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands)
        model.extend(serialized_values)
        model.extend(struct.pack("iii", *x) for x in self.operations)

        # Compact the model so far so we can get its length.
        model = [b"".join(model)]
        model_offset = len(model[0])
        # model_offset is the index into the model (in 32-bit words, not bytes)
        # of the next dimension we're about to serialize.  If a dimension is 0
        # (load-time flexible), generate code to mutate it before passing to NNAPI.
        assert model_offset % 4 == 0
        model_offset //= 4

        for (op_id, (_, dims, dim_order, _, _)) in enumerate(self.operands):
            shape = fix_shape(dims, dim_order)
            for d, s in enumerate(shape):
                if s == 0:
                    pt_d = reverse_map_dim(dim_order, d)
                    self.flexible_shape_computation_lines.append(
                        f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}")
                model_offset += 1
            # Convert runtime flexible shapes from -1 to 0.
            shape = tuple(d if d != -1 else 0 for d in shape)
            model.append(self.serialize_ints(shape))

        model.extend(serialized_value_data)
        model.append(self.serialize_ints(self.operation_args))
        model.append(self.serialize_ints(self.inputs))
        model.append(self.serialize_ints(self.outputs))

        self.flexible_shape_computation_lines.extend(template_return_lines)

        return (
            array.array("i", b"".join(model)),
            self.used_weights,
            inp_dim_orders,
            out_dim_orders,
            self.flexible_shape_computation_lines,
            retval_count,
        )

    def serialize_values(self):
        serialized_values = []
        serialized_value_data = []
        assert len(self.values) == len(self.value_data)
        for ((op_index, source_type), data) in zip(self.values, self.value_data):
            source_length = len(data)

            # Pad with 0 bytes out to a multiple of 4 for alignment.
            physical_length = ((source_length - 1) | 0x3) + 1
            padded_data = data + (b"\0" * (physical_length - source_length))

            serialized_values.append(struct.pack("iii", op_index, source_type, source_length))
            serialized_value_data.append(padded_data)

        return serialized_values, serialized_value_data
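
    # Note: ((n - 1) | 0x3) + 1 rounds n up to the next multiple of 4, so a
    # 5-byte immediate is padded to 8 bytes while a 4-byte immediate stays at 4.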

    @staticmethod
    def serialize_ints(ints):
        return array.array("i", ints).tobytes()

    ADDER_MAP = {
        "prim::GetAttr": lambda self, node:
            self.add_getattr(node),
        "prim::Constant": lambda self, node:
            self.add_constant_node(node),
        "prim::ListConstruct": lambda self, node:
            self.add_list_construct(node),
        "prim::TupleConstruct": lambda self, node:
            self.add_tuple_construct(node),
        "aten::unsqueeze": lambda self, node:
            self.add_unsqueeze(node),
        "aten::to": lambda self, node:
            self.add_to(node),
        "aten::detach": lambda self, node:
            self._identity(node),
        "aten::reshape": lambda self, node:
            self.add_reshape(node),
        "aten::flatten": lambda self, node:
            self.add_flatten(node),
        "aten::slice": lambda self, node:
            self.add_slice(node),
        "aten::size": lambda self, node:
            self.add_size(node),
        "aten::cat": lambda self, node:
            self.add_cat(node),
        "aten::mean": lambda self, node:
            self.add_mean(node),
        "aten::quantize_per_tensor": lambda self, node:
            self.add_quantize(node),
        "aten::dequantize": lambda self, node:
            self.add_dequantize(node),
        "aten::add": lambda self, node:
            self.add_add_sub_op(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE),
        "aten::sub": lambda self, node:
            self.add_add_sub_op(node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE),
        "aten::mul": lambda self, node:
            self.add_pointwise_simple_binary_broadcast_op(node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE),
        "aten::div": lambda self, node:
            self.add_pointwise_simple_binary_broadcast_op(node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE),
        "aten::relu": lambda self, node:
            self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.RELU),
        "aten::sigmoid": lambda self, node:
            self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.LOGISTIC),
        "aten::softmax": lambda self, node:
            self.add_softmax(node),
        "aten::hardtanh": lambda self, node:
            self.add_hardtanh(node),
        "aten::avg_pool2d": lambda self, node:
            self.add_avg_pool2d(node),
        "aten::max_pool2d": lambda self, node:
            self.add_pool2d_node(node, NNAPI_OperationCode.MAX_POOL_2D),
        "aten::adaptive_avg_pool2d": lambda self, node:
            self.add_adaptive_avg_pool2d(node),
        "aten::upsample_nearest2d": lambda self, node:
            self.add_upsample_nearest2d(node),
        "aten::prelu": lambda self, node:
            self.add_prelu_op(node),
        "aten::addmm": lambda self, node:
            self.add_addmm(node),
        "aten::linear": lambda self, node:
            self.add_linear(node),
        "aten::_convolution": lambda self, node:
            self.add_conv_underscore(node),
        "aten::conv2d": lambda self, node:
            self.add_conv2d(node),
        "aten::log_softmax": lambda self, node:
            self.add_log_softmax(node),
        "quantized::linear": lambda self, node:
            self.add_qlinear(node),
        "quantized::conv2d": lambda self, node:
            self.add_qconv2d(node, NNAPI_FuseCode.FUSED_NONE),
        "quantized::conv2d_relu": lambda self, node:
            self.add_qconv2d(node, NNAPI_FuseCode.FUSED_RELU),
        "quantized::conv_transpose2d": lambda self, node:
            self.add_qconv2d(node, NNAPI_FuseCode.FUSED_NONE, transpose=True),
        "quantized::add": lambda self, node:
            self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE),
        "quantized::add_relu": lambda self, node:
            self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU),
        "quantized::mul": lambda self, node:
            self.add_qadd(node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE),
    }
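
    # Note: add_node() below looks up the node kind (e.g. "aten::relu") in
    # ADDER_MAP and calls the matching adder, so supporting a new op means adding
    # an entry here plus the corresponding add_* method.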

    def add_node(self, node):
        adder = self.ADDER_MAP.get(node.kind())
        if not adder:
            raise Exception("Unsupported node kind (%r) in node %r" % (node.kind(), node))
        adder(self, node)

    def _identity(self, node):
        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
        jitval = node.outputsAt(0)
        self.jitval_operand_map[jitval] = in_id

    def add_getattr(self, node):
        assert node.inputsSize() == 1
        assert node.outputsSize() == 1
        obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
        assert str(obj_ctype).startswith("__torch__.")
        name = node.s("name")
        value = getattr(obj, name)
        output = node.outputsAt(0)
        ctype = output.type()
        self.add_constant_value(output, ctype, value)

    def add_constant_node(self, node):
        assert node.inputsSize() == 0
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        ctype = output.type()
        value = output.toIValue()
        self.add_constant_value(output, ctype, value)

    def add_list_construct(self, node):
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        ctype = output.type()
        const_vals: Optional[List] = []
        tensors: Optional[List] = []
        for inp in node.inputs():
            if const_vals is not None and inp in self.constants:
                _, val = self.get_constant_value(inp)
                const_vals.append(val)
            else:
                const_vals = None
            if tensors is not None and inp.type().kind() == "TensorType":
                tensors.append(inp)
            else:
                tensors = None

        if const_vals is not None:
            # NOTE: Now that TorchScript supports list constants,
            # this code path might not be used anymore.
            self.add_constant_value(output, ctype, const_vals)
        if tensors is not None:
            self.add_tensor_sequence(output, tensors)
        if const_vals is None and tensors is None:
            raise Exception(
                "Unable to handle ListConstruct node."
                " Neither all constants nor all tensors. %r" % node)

    def add_tuple_construct(self, node):
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        values = []
        for inp in node.inputs():
            values.append(inp)
        self.add_tensor_sequence(output, values)

    def add_unsqueeze(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))

        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
        assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS

        real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
        out_shape_list = list(in_oper.shape)
        out_shape_list.insert(real_dim, 1)
        out_shape = tuple(out_shape_list)
        out_oper = in_oper._replace(shape=out_shape)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_scalar(dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
        self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)

    def add_to(self, node):
        # Handle to("cpu") / to("gpu") case
        self._identity(node)

    def add_reshape(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))

        shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
        assert shape_ctype.kind() == "ListType"
        assert shape_ctype.getElementType().kind() == "IntType"
        is_trivial_reshape = len(shape) == 2 and shape[1] == -1

        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape:
            raise Exception(
                "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1].")

        # Bit of a hack here.  Use a real tensor to infer the output shape.
        out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
        out_oper = in_oper._replace(shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(shape)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)

    def add_flatten(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))

        start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
        end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")

        # Channels-last with channels == 1, or height and width both 1.
        is_trivial_flatten = len(in_oper.shape) == 4 and (
            in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1))
        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten:
            raise Exception(
                "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1")

        if start_dim < 0:
            start_dim += len(in_oper.shape)
        if end_dim < 0:
            end_dim += len(in_oper.shape)

        out_shape = (
            in_oper.shape[: start_dim] +
            (functools.reduce(
                lambda x, y: x * y, in_oper.shape[start_dim: end_dim + 1]),) +
            in_oper.shape[end_dim + 1:]
        )

        if any(dim == 0 for dim in in_oper.shape[start_dim: end_dim + 1]):
            raise Exception("Flattening flexible dims is not supported yet")
        non_flattened_dims = in_oper.shape[: start_dim] + in_oper.shape[end_dim + 1:]
        if non_flattened_dims.count(0) > 1:
            raise Exception("Only 1 dim can be flexible")

        out_oper = in_oper._replace(shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS)
        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
        for idx, dim in enumerate(out_shape):
            if dim == 0:
                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))

        inputs_1 = tuple(
            dim if dim != 0 else -1
            for dim in out_shape
        )
        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(inputs_1)

        outputs = [None] * 1
        outputs[0] = out_id
        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)

    def add_slice(self, node):
        assert node.inputsSize() == 5
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
        _, dim_value = self.get_constant_value(node.inputsAt(1))
        _, start_value = self.get_constant_value(node.inputsAt(2))
        _, stop_value = self.get_constant_value(node.inputsAt(3))
        _, step_value = self.get_constant_value(node.inputsAt(4))

        if start_value is None:
            start_value = 0
        if stop_value is None:
            stop_value = sys.maxsize

        if start_value < 0:
            start_value += in_oper.shape[dim_value]
        elif start_value == sys.maxsize:
            start_value = 0

        if start_value == 0 and stop_value == sys.maxsize:
            self._identity(node)
            return

        if in_oper.shape[dim_value] == 0:
            raise Exception("Unable to slice with flexible shape")

        if stop_value < 0:
            stop_value += in_oper.shape[dim_value]
        elif stop_value == sys.maxsize:
            stop_value = in_oper.shape[dim_value]

        if start_value >= stop_value:
            raise Exception("Slice start value should be less than stop value")

        out_len = (stop_value - start_value) // step_value
        out_shape = tuple(out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape))
        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper._replace(shape=out_shape))

        # Flexible inputs
        end_mask = 0
        for idx, dim in enumerate(out_shape):
            if dim == 0:
                self.forward_operand_shape(out_id, idx, in_id, idx)
                end_mask |= (1 << idx)

        inputs = [None] * 7
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(
            [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))])
        inputs[2] = self.add_immediate_int_vector(
            [stop_value if i == dim_value else dim for i, dim in enumerate(in_oper.shape)])
        inputs[3] = self.add_immediate_int_vector(
            [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))])
        inputs[4] = self.add_immediate_int_scalar(0)  # begin mask
        inputs[5] = self.add_immediate_int_scalar(end_mask)
        inputs[6] = self.add_immediate_int_scalar(0)  # shrink axis mask

        outputs = [None] * 1
        outputs[0] = out_id
        self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs)

    def add_size(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        _, value = self.constants[node.inputsAt(1)]
        res = in_oper.shape[value]
        output = node.outputsAt(0)
        self.add_constant_value(output, output.type(), res)

    def add_cat(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        tensors = self.tensor_sequences[node.inputsAt(0)]
        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")

        assert len(tensors) > 0
        in_ids = []
        out_oper = None
        out_dim_size = 0
        for inp in tensors:
            in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
            if out_oper is None:
                out_shape = change_element(in_oper.shape, dim, -1)
                out_oper = in_oper._replace(shape=out_shape)
            assert in_oper.op_type == out_oper.op_type
            assert in_oper.dim_order == out_oper.dim_order
            assert change_element(in_oper.shape, dim, -1) == change_element(out_oper.shape, dim, -1)
            # TODO: Possibly check scale and zero point.
            in_ids.append(in_id)
            # TODO: Possibly support variable-sized inputs.
            out_dim_size += in_oper.shape[dim]

        assert out_oper is not None
        out_oper = out_oper._replace(shape=change_element(out_oper.shape, dim, out_dim_size))

        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
            assert len(out_oper.shape) == 4
            nnapi_dim = [0, 3, 1, 2][dim]
        else:
            nnapi_dim = dim

        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
        for idx, d in enumerate(out_oper.shape):
            if d == 0:
                if idx == dim:
                    shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
                    self.compute_operand_shape(out_id, idx, shape)
                else:
                    self.forward_operand_shape(out_id, idx, in_ids[0], idx)

        inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]

        outputs = [None] * 1
        outputs[0] = out_id
        self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)

    def add_mean(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
        assert dim_ctype.kind() == "ListType"
        assert dim_ctype.getElementType().kind() == "IntType"
        _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
        # Expect None for dtype
        self.get_constant_value(node.inputsAt(3), "NoneType")

        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
            assert len(in_oper.shape) == 4
            nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
        else:
            nnapi_dim = dim

        collapsed_dims = set()
        for d in dim:
            if d < 0:
                d += len(in_oper.shape)
            collapsed_dims.add(d)

        if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
            assert collapsed_dims.issuperset({2, 3})
            out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
        else:
            out_dim_order = in_oper.dim_order

        out_shape = []
        for i, s in enumerate(in_oper.shape):
            if i not in collapsed_dims:
                out_shape.append(s)
            elif keep_dim:
                out_shape.append(1)

        out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)

        inputs = [None] * 3
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(nnapi_dim)
        inputs[2] = self.add_immediate_int_scalar(keep_dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
        self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)

    def add_quantize(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        if in_oper.dim_order != DimOrder.CHANNELS_LAST:
            raise Exception(
                "Most hardware backends prefer NHWC quantized tensors.  "
                "Try setting `t.nnapi_nhwc = True` on your tensor inputs.")
        _, scale = self.get_constant_value(node.inputsAt(1), "FloatType")
        _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType")
        _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")
        if scalar_type != TorchScalarTypes.QUINT8.value:
            raise Exception(
                "PyTorch NNAPI export only supports quantized tensors "
                "with the quint8 dtype.")
        op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM

        out_oper = in_oper._replace(
            op_type=op_type,
            scale=scale,
            zero_point=zero_point,
        )

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs)

    def add_dequantize(self, node):
        assert node.inputsSize() == 1
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        out_oper = in_oper._replace(
            op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
            scale=0.0,
            zero_point=0,
        )

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs)

    def add_pointwise_simple_unary_op(self, node, opcode):
        assert node.inputsSize() == 1
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))

        out_oper = in_oper
        if opcode == NNAPI_OperationCode.LOGISTIC:
            # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
            # must be 1.f / 256 and the zeroPoint must be 0.
            # https://fburl.com/h52stoog
            if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
                out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)

        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)

        for idx, dim in enumerate(in_oper.shape):
            if dim == 0:
                self.forward_operand_shape(out_id, idx, in_id, idx)

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(opcode, inputs, outputs)

    def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):
        """Helper for pointwise binary broadcast ops with superfluous extra args."""
        assert node.outputsSize() == 1

        assert node.inputsAt(0).type().kind() == "TensorType"
        assert node.inputsAt(1).type().kind() == "TensorType"

        if self.has_operand_for_jitval(node.inputsAt(0)):
            in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
            in1_id, in1_oper = self.get_tensor_operand_or_constant(node.inputsAt(1), in0_oper.dim_order)
        elif self.has_operand_for_jitval(node.inputsAt(1)):
            in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
            in0_id, in0_oper = self.get_tensor_operand_or_constant(node.inputsAt(0), in1_oper.dim_order)
        else:
            raise Exception(f"Can't do a NNAPI binary op: {opcode} on two constants")

        assert in0_oper.op_type == in1_oper.op_type
        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
            in0_id, in0_oper, in1_id, in1_oper)
        # NOTE: PyTorch and NNAPI have the same broadcast semantics.
        out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
        out_oper = in0_oper._replace(shape=out_shape)
        if qparams is not None:
            scale, zp = qparams
            out_oper = out_oper._replace(scale=scale, zero_point=zp)

        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
        for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)):
            if d0 == 1 and d1 == 0:
                self.forward_operand_shape(out_id, idx, in1_id, idx)
            elif d0 == 0 and d1 == 1:
                self.forward_operand_shape(out_id, idx, in0_id, idx)
            elif d0 == 0 and d1 == 0:
                self.flexible_shape_computation_lines.append(
                    f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}"
                )
                self.forward_operand_shape(out_id, idx, in0_id, idx)

        inputs = [None] * 3
        inputs[0] = in0_id
        inputs[1] = in1_id
        inputs[2] = self.add_immediate_int_scalar(fuse_code)

        outputs = [None] * 1
        outputs[0] = out_id
        self.add_operation(opcode, inputs, outputs)

    def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
        assert node.inputsSize() == 2
        self._do_add_binary(node, opcode, fuse_code)

    def add_add_sub_op(self, node, opcode, fuse_code):
        assert node.inputsSize() == 3

        _, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
        if alpha != 1:
            raise Exception("NNAPI does not support add/sub with alpha.")

        self._do_add_binary(node, opcode, fuse_code)

    def add_qadd(self, node, opcode, fuse_code):
        assert node.inputsSize() == 4

        _, scale = self.get_constant_value(node.inputsAt(2), "FloatType")
        _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType")

        self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point))
  1141. def add_softmax(self, node):
  1142. assert node.inputsSize() == 3
  1143. in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
  1144. _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType")
  1145. out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
  1146. for dim, size in enumerate(in_oper.shape):
  1147. if size == 0:
  1148. self.forward_operand_shape(out_id, dim, in_id, dim)
  1149. inputs = [None] * 3
  1150. inputs[0] = in_id
  1151. inputs[1] = self.add_immediate_float_scalar(1.0) # positive scaling factor of exponent, beta
  1152. inputs[2] = self.add_immediate_int_scalar(softmax_dim)
  1153. outputs = [None] * 1
  1154. outputs[0] = out_id
  1155. self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs)

    def add_hardtanh(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType")
        _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType")

        op_map = {
            (-1, 1): NNAPI_OperationCode.RELU1,
            ( 0, 6): NNAPI_OperationCode.RELU6,  # noqa: E201
        }

        opcode = op_map.get((min_val, max_val))
        if opcode is None:
            raise Exception("NNAPI only supports hardtanh with args (-1, 1) or (0, 6).")

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper)

        self.add_operation(opcode, inputs, outputs)
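
    # PReLU maps directly to NNAPI PRELU. The slope must be a 1-D weight
    # tensor, and per-channel slopes currently require channels_last input.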

    def add_prelu_op(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        assert node.inputsAt(0).type().kind() == "TensorType"
        assert node.inputsAt(1).type().kind() == "TensorType"

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
        w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
        assert len(w_oper.shape) == 1
        assert w_oper.shape[0] > 0
        if w_oper.shape[0] > 1:
            if in_oper.use_nchw():
                # TODO: Support this by adding trailing 1 dims.
                raise Exception("Per-channel PReLU only supports channels_last right now.")

        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
        for dim, size in enumerate(in_oper.shape):
            if size > 0:
                pass
            elif dim <= 1:
                raise Exception("PReLU requires fixed size for dim 0 and dim 1.")
            else:
                self.forward_operand_shape(out_id, dim, in_id, dim)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = w_id

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs)
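
    # Generic pool lowering (used for ops such as max_pool2d): builds the
    # explicit-padding NNAPI argument list in the order pad_l, pad_r, pad_t,
    # pad_b, stride_w, stride_h, kernel_w, kernel_h, fuse code, NCHW flag.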

    def add_pool2d_node(self, node, opcode):
        assert node.inputsSize() == 6
        assert node.outputsSize() == 1
        image, kernel, stride, padding, dilation, ceil_mode = node.inputs()

        stride = stride or kernel

        # TODO: Validate ceil_mode semantics.

        args = self.get_conv_pool_args_2d_from_jit(self.get_size_arg(kernel), stride, padding, dilation)
        if args.dilation_h != 1 or args.dilation_w != 1:
            raise Exception("NNAPI does not support dilated pooling.")

        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
        assert len(image_oper.shape) == 4

        out_shape = get_conv_pool_shape(image_oper.shape, args, image_oper.shape[1], False)
        use_nchw = image_oper.use_nchw()

        inputs = [None] * 11
        inputs[0] = image_id
        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
        inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape))

        self.add_operation(opcode, inputs, outputs)
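
    # avg_pool2d also lowers to AVERAGE_POOL_2D, but unlike the generic pool
    # adder it tolerates flexible input dims via _handle_conv_pool_flexible_input
    # and rejects count_include_pad=False and divisor_override.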

    def add_avg_pool2d(self, node):
        assert node.inputsSize() == 7
        assert node.outputsSize() == 1
        image, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override = node.inputs()

        _, count_include_pad_value = self.get_constant_value(count_include_pad)
        _, divisor_override_value = self.get_constant_value(divisor_override)
        if not count_include_pad_value or divisor_override_value:
            raise Exception("NNAPI doesn't support count_include_pad=False or divisor_override")

        args = self.get_conv_pool_args_2d_from_jit(self.get_size_arg(kernel), stride, padding)

        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
        assert len(image_oper.shape) == 4

        out_shape = get_conv_pool_shape(image_oper.shape, args, image_oper.shape[1], False)
        use_nchw = image_oper.use_nchw()

        inputs = [None] * 11
        inputs[0] = image_id
        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
        inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        out_id = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape))
        self._handle_conv_pool_flexible_input(out_id, image, args, False)

        outputs[0] = out_id
        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
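
    # adaptive_avg_pool2d with output size (1, 1) is lowered to a plain
    # AVERAGE_POOL_2D whose kernel covers the entire spatial extent.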

    def add_adaptive_avg_pool2d(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        assert len(image_oper.shape) == 4

        size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
        assert size_ctype.kind() == "ListType"
        assert size_ctype.getElementType().kind() == "IntType"
        if size_arg != [1, 1]:
            raise Exception("NNAPI only supports adaptive_avg_pool2d with output size (1, 1).")

        out_shape = image_oper.shape[0:2] + tuple(size_arg)
        use_nchw = image_oper.use_nchw()

        inputs = [None] * 11
        inputs[0] = image_id
        inputs[1] = self.add_immediate_int_scalar(0)
        inputs[2] = self.add_immediate_int_scalar(0)
        inputs[3] = self.add_immediate_int_scalar(0)
        inputs[4] = self.add_immediate_int_scalar(0)
        inputs[5] = self.add_immediate_int_scalar(1)
        inputs[6] = self.add_immediate_int_scalar(1)
        inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3])
        inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2])
        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
        inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape))

        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
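
    # upsample_nearest2d lowers to RESIZE_NEAREST_NEIGHBOR. Exactly one of
    # `size` and `scale` must be provided; flexible (size-0) H/W dims are
    # recomputed at run time from the recorded shape expression.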

    def add_upsample_nearest2d(self, node):
        assert node.inputsSize() == 3 or node.inputsSize() == 4
        assert node.outputsSize() == 1
        if node.inputsSize() == 3:
            image, size_jit, scale_jit = node.inputs()
        else:
            image, size_jit, scale_h_jit, scale_w_jit = node.inputs()
        size_ctype, size_arg = self.get_constant_value(size_jit)

        if node.inputsSize() == 3:
            scale_ctype, scale_arg = self.get_constant_value(scale_jit)
        else:
            scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)
            scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)

            # The only way for the 4-argument overload of upsample_nearest2d to
            # have been added to the graph without error is if the scale_h and
            # scale_w arguments are None.
            assert scale_h_ctype.kind() == "NoneType"
            assert scale_w_ctype.kind() == "NoneType"

            scale_ctype = scale_h_ctype
            scale_arg = scale_h_arg

        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
        assert len(image_oper.shape) == 4

        if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType":
            raise Exception("Size and scale cannot both be non-None.")
        elif size_ctype.kind() != "NoneType":
            assert size_ctype.kind() == "ListType"
            assert size_ctype.getElementType().kind() == "IntType"
            assert scale_ctype.kind() == "NoneType"
            assert scale_arg is None
            assert isinstance(size_arg, list)
            assert size_arg
            assert all(isinstance(val, int) for val in size_arg)
            if len(size_arg) == 1:
                size_arg = size_arg * 2
            assert len(size_arg) == 2
            out_h = size_arg[0]
            out_w = size_arg[1]
            arg_h = self.add_immediate_int_scalar(out_h)
            arg_w = self.add_immediate_int_scalar(out_w)
        elif scale_ctype.kind() != "NoneType":
            assert scale_ctype.kind() == "ListType"
            assert scale_ctype.getElementType().kind() == "FloatType"
            assert size_ctype.kind() == "NoneType"
            assert size_arg is None
            assert isinstance(scale_arg, list)
            assert scale_arg
            assert all(isinstance(val, float) for val in scale_arg)
            if len(scale_arg) == 1:
                scale_arg = scale_arg * 2
            assert len(scale_arg) == 2
            out_h = int(scale_arg[0] * image_oper.shape[2])
            out_w = int(scale_arg[1] * image_oper.shape[3])
            arg_h = self.add_immediate_float_scalar(scale_arg[0])
            arg_w = self.add_immediate_float_scalar(scale_arg[1])
        else:
            raise Exception("Size and scale cannot both be None.")

        out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
        use_nchw = image_oper.use_nchw()

        out_id = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape))

        if image_oper.shape[0] == 0 or image_oper.shape[1] == 0:
            raise Exception("Flexible batch or channels not supported")

        # Handle variable input size
        for dim in (2, 3):  # h, w indices
            if image_oper.shape[dim] == 0:
                if size_ctype.kind() != "NoneType":
                    self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
                elif scale_ctype.kind() != "NoneType":
                    self.compute_operand_shape(out_id, dim, f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})")
                else:
                    raise Exception("Size and scale cannot both be None.")

        inputs = [None] * 4
        inputs[0] = image_id
        inputs[1] = arg_w
        inputs[2] = arg_h
        inputs[3] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs)
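
    # addmm and linear both lower to FULLY_CONNECTED; addmm additionally
    # requires alpha == beta == 1 and needs its weight transposed.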

    def add_addmm(self, node):
        assert node.inputsSize() == 5
        assert node.outputsSize() == 1
        jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()

        for jitval in (jit_beta, jit_alpha):
            scale_ctype, scale_value = self.get_constant_value(jitval)
            assert scale_ctype.kind() in ("IntType", "FloatType")
            if scale_value != 1:
                raise Exception("NNAPI Fully-Connected does not support alpha and beta.")

        self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)

    def add_linear(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        jit_input, jit_weight, jit_bias = node.inputs()
        self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias)

    def add_addmm_or_linear(self, node, transpose_weight, jit_input, jit_weight, jit_bias):
        input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
        bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)

        assert len(input_oper.shape) == 2
        assert len(bias_oper.shape) == 1

        # TODO: Transform at load time to share weights with CPU model.
        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        assert len(weight_tensor.shape) == 2
        if transpose_weight:
            nnapi_weight_tensor = weight_tensor.t().contiguous()
        else:
            nnapi_weight_tensor = weight_tensor.contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        out_shape = (input_oper.shape[0], weight_oper.shape[0])
        out_id = self.add_tensor_operand(node.outputsAt(0), input_oper._replace(shape=out_shape))

        if input_oper.shape[0] == 0:
            self.forward_operand_shape(out_id, 0, input_id, 0)

        inputs = [None] * 4
        inputs[0] = input_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
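
    # Quantized linear: unpack LinearPackedParamsBase, convert qint8 weights to
    # the unsigned representation NNAPI expects, and requantize the bias to
    # int32 with scale = input_scale * weight_scale.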

    def add_qlinear(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1
        (
            jit_input,
            jit_packed_weight,
            jit_scale,
            jit_zero_point,
        ) = node.inputs()

        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
        # TODO: Support automatic reshape
        assert len(input_oper.shape) == 2

        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
        assert weight_ctype.name() == "LinearPackedParamsBase"
        raw_weight, raw_bias = packed_weight.__getstate__()[0]
        assert raw_bias is not None

        assert len(raw_weight.shape) == 2
        assert len(raw_bias.shape) == 1
        assert raw_bias.shape[0] == raw_weight.shape[0]
        assert raw_weight.shape[1] == input_oper.shape[1]

        assert raw_weight.qscheme() == torch.per_tensor_affine
        if raw_weight.dtype == torch.quint8:
            unsigned_weight = raw_weight
        else:
            assert raw_weight.dtype == torch.qint8
            unsigned_weight = torch._make_per_tensor_quantized_tensor(
                (raw_weight.int_repr().int() + 128).to(torch.uint8),
                scale=raw_weight.q_scale(),
                zero_point=raw_weight.q_zero_point() + 128)
        weight_scale = unsigned_weight.q_scale()
        bias_scale = input_oper.scale * weight_scale
        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
        bias_id = self.add_tensor_operand_for_weight(int_bias)

        multiplier = input_oper.scale * weight_scale / out_scale
        assert multiplier > 0
        if multiplier >= 1:
            raise Exception(
                "Quantized fully-connected multiplier is greater than or equal to 1. "
                "This is supported by NNAPI, but not by most hardware backends. "
                "Try training a model without quantization-aware training.")

        # TODO: Transform at load time to share weights with CPU model.
        nnapi_weight_tensor = unsigned_weight.contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        out_shape = (input_oper.shape[0], weight_oper.shape[0])
        out_oper = input_oper._replace(
            shape=out_shape,
            scale=out_scale,
            zero_point=out_zero_point,
        )

        inputs = [None] * 4
        inputs[0] = input_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)

    def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
        ctype, value = self.get_constant_value(jit_bias)
        if ctype.kind() == "NoneType":
            bias_idx = 1 if transpose else 0
            nnapi_bias_tensor = torch.zeros(weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype)
            bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor)
            bias_oper = self.operands[bias_id]
            return bias_id, bias_oper
        else:
            return self.get_tensor_operand_for_weight(jit_bias)
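
    # Plain (float) conv2d: weights are constant tensors in the graph, a
    # missing bias is replaced by zeros via get_optional_bias, and the real
    # work happens in add_conv2d_common.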

    def add_conv2d(self, node):
        assert node.inputsSize() == 7
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_weight,
            jit_bias,
            jit_stride,
            jit_pad,
            jit_dilation,
            jit_groups,
        ) = node.inputs()

        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
        args = self.get_conv_pool_args_2d_from_jit(
            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups)

        return self.add_conv2d_common(
            node.outputsAt(0),
            0.0,
            0,
            jit_image,
            weight_tensor,
            bias_id,
            args,
            False,  # transpose
            NNAPI_FuseCode.FUSED_NONE,
        )

    def add_conv_underscore(self, node):
        assert node.inputsSize() == 13
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_weight,
            jit_bias,
            jit_stride,
            jit_pad,
            jit_dilation,
            jit_transpose,
            _,
            jit_groups,
            _,
            _,
            _,
            _,
        ) = node.inputs()

        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        _, transpose = self.get_constant_value(jit_transpose)
        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
        args = self.get_conv_pool_args_2d_from_jit(
            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups)

        return self.add_conv2d_common(
            node.outputsAt(0),
            0.0,
            0,
            jit_image,
            weight_tensor,
            bias_id,
            args,
            transpose,
            NNAPI_FuseCode.FUSED_NONE,
        )

    def add_log_softmax(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        (
            jit_input,
            jit_dim,
            jit_half_to_float
        ) = node.inputs()
        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
        _, dim = self.get_constant_value(jit_dim, "IntType")

        out_shape = input_oper.shape

        inputs = [None] * 3
        inputs[0] = input_id
        # specifying 1 as the scaling factor for the exponent, beta
        inputs[1] = self.add_immediate_float_scalar(1)
        inputs[2] = self.add_immediate_int_scalar(dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), input_oper._replace(shape=out_shape))
        self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)
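
    # Quantized conv: unpack Conv2dPackedParamsBase (pack format "2"), convert
    # the weight to unsigned asymmetric quantization, requantize the bias to
    # int32, then reuse add_conv2d_common.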

    def add_qconv2d(self, node, fuse_code, transpose=False):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_packed_weight,
            jit_scale,
            jit_zero_point,
        ) = node.inputs()

        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
        assert weight_ctype.name() == "Conv2dPackedParamsBase"
        (
            pack_version,
            tensors,
            opt_tensors,
        ) = packed_weight.__getstate__()[0]
        assert pack_version == "2"
        packed_config, raw_weight = tensors
        raw_bias, = opt_tensors
        assert raw_bias is not None
        args = self.get_conv_pool_args_2d_from_pack(raw_weight.shape[2:4], packed_config)

        assert raw_weight.qscheme() == torch.per_tensor_affine
        if raw_weight.dtype == torch.quint8:
            unsigned_weight = raw_weight
        else:
            assert raw_weight.dtype == torch.qint8
            unsigned_weight = torch._make_per_tensor_quantized_tensor(
                (raw_weight.int_repr().int() + 128).to(torch.uint8),
                scale=raw_weight.q_scale(),
                zero_point=raw_weight.q_zero_point() + 128)
        weight_scale = unsigned_weight.q_scale()
        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        bias_scale = image_oper.scale * weight_scale
        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
        bias_id = self.add_tensor_operand_for_weight(int_bias)

        multiplier = image_oper.scale * weight_scale / out_scale
        assert multiplier > 0
        if multiplier >= 1:
            raise Exception(
                "Quantized convolution multiplier is greater than or equal to 1. "
                "This is supported by NNAPI, but not by most hardware backends. "
                "Try training a model without quantization-aware training.")

        return self.add_conv2d_common(
            node.outputsAt(0),
            out_scale,
            out_zero_point,
            jit_image,
            unsigned_weight,
            bias_id,
            args,
            transpose,
            fuse_code,
        )
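
    # Common lowering for all 2-D conv variants. PyTorch weights are OIHW
    # (or IOHW for transposed conv); they are permuted here to the layouts
    # unpacked below: (out_c, kH, kW, in_c) for full/transposed convolution
    # and (1, kH, kW, out_c) for depthwise convolution.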

    def add_conv2d_common(
            self,
            jit_out,
            out_scale,
            out_zero_point,
            jit_image,
            weight_tensor,
            bias_id,
            args,
            transpose,
            fuse_code):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        in_c = image_oper.shape[1]

        if args.group == 1:
            # Full convolution
            depthwise = False
            if transpose:
                weight_permutation = (1, 2, 3, 0)
            else:
                weight_permutation = (0, 2, 3, 1)
        elif args.group == in_c:
            # Depthwise convolution
            depthwise = True
            weight_permutation = (1, 2, 3, 0)
        else:
            raise Exception("Group convolution not supported yet.")

        # TODO: Transform at load time to share weights with CPU model.
        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        bias_oper = self.operands[bias_id]

        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
            assert bias_oper.zero_point == 0
        else:
            raise Exception(
                "Unsupported input type for conv2d: {}"
                .format(image_oper.op_type))

        assert len(image_oper.shape) == 4
        assert len(weight_oper.shape) == 4
        assert len(bias_oper.shape) == 1

        if depthwise:
            # Depthwise convolution
            one, kern_h, kern_w, out_c = weight_oper.shape
            assert one == 1
            assert out_c % in_c == 0
            channel_multiplier = out_c // in_c
            assert channel_multiplier == 1  # Don't support multiplier
            assert out_c == in_c
        else:
            # Full convolution
            out_c, kern_h, kern_w, kern_d = weight_oper.shape
            assert kern_d == in_c

        assert out_c == bias_oper.shape[0]

        use_nchw = image_oper.use_nchw()

        if depthwise:
            num_args = 12
            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
        else:
            num_args = 11
            if transpose:
                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
            else:
                opcode = NNAPI_OperationCode.CONV_2D

        inputs = [None] * num_args
        inputs[0] = image_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
        if depthwise:
            inputs[9] = self.add_immediate_int_scalar(1)
            inputs[10] = self.add_immediate_int_scalar(fuse_code)
            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
        else:
            inputs[9] = self.add_immediate_int_scalar(fuse_code)
            inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
        out_oper = image_oper._replace(
            shape=out_shape,
            scale=out_scale,
            zero_point=out_zero_point,
        )
        out_id = self.add_tensor_operand(jit_out, out_oper)
        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)

        outputs[0] = out_id
        self.add_operation(opcode, inputs, outputs)
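
    # For flexible (size-0) conv/pool inputs, emit shape-computation lines so
    # the output H/W can be derived from the input at run time; the transposed
    # case inverts the usual output-size formula.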

    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        batch, in_ch, in_h, in_w = image_oper.shape

        if batch == 0:
            self.forward_operand_shape(out_id, 0, image_id, 0)
        if in_ch == 0:
            raise Exception("Input channels can't be flexible")
        # H & W
        if transpose:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}"
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}"
                )
        else:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1"
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1"
                )


def serialize_model(module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False):
    """Convert a TorchScript module to NNAPI and serialize it.

    Parameters:
        module: TorchScript module to convert
        inputs: Tensors used to specify input details for NNAPI
        config (optional): Optional config to attach to the module
        return_shapes (optional): Specify the shapes of the outputs if
            your module uses runtime flexible shapes, so that output
            buffer sizes can be set for NNAPI
        use_int16_for_qint16 (optional): Use PyTorch int16 to represent NNAPI qint16 values
    """
    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(module, inputs, return_shapes)
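
# A minimal usage sketch (illustrative only; `MyModel` is a stand-in, and most
# callers go through torch.backends._nnapi.prepare.convert_model_to_nnapi
# rather than calling this function directly):
#
#     import torch
#     from torch.backends._nnapi.serializer import serialize_model
#
#     model = torch.jit.script(MyModel().eval())
#     example = torch.zeros(1, 3, 224, 224)  # fixed-shape example input
#     nnapi_artifacts = serialize_model(model, [example])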