  1. r"""Importing this file includes common utility methods and base clases for
  2. checking quantization api and properties of resulting modules.
  3. """
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
  8. import torch.ao.nn.quantized as nnq
  9. import torch.ao.nn.quantized.dynamic as nnqd
  10. from torch.ao.nn.intrinsic import _FusedModule
  11. import torch.distributed as dist
  12. from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM
  13. from torch.ao.quantization import (
  14. QuantType,
  15. default_dynamic_qat_qconfig,
  16. default_embedding_qat_qconfig,
  17. default_symmetric_qnnpack_qat_qconfig,
  18. )
  19. from torch.ao.quantization import QuantWrapper, QuantStub, DeQuantStub, \
  20. default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
  21. propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \
  22. get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, quantize
  23. from torch.ao.quantization.quantization_mappings import (
  24. get_default_dynamic_quant_module_mappings,
  25. get_default_qconfig_propagation_list,
  26. get_default_qat_module_mappings,
  27. )
  28. from torch.testing._internal.common_quantized import (
  29. override_quantized_engine,
  30. )
  31. from torch.jit.mobile import _load_for_lite_interpreter
  32. try:
  33. # graph mode quantization based on fx
  34. from torch.ao.quantization.quantize_fx import (
  35. prepare_fx,
  36. prepare_qat_fx,
  37. convert_fx,
  38. convert_to_reference_fx,
  39. )
  40. from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
  41. from torch.fx.graph import Node
  42. from torch.fx import GraphModule
  43. HAS_FX = True
  44. except ImportError:
  45. HAS_FX = False
  46. import copy
  47. import io
  48. import functools
  49. import time
  50. import os
  51. import unittest
  52. import numpy as np
  53. from torch.testing import FileCheck
  54. from typing import Callable, Tuple, Dict, Any, Union, Type, Optional
  55. class NodeSpec:
  56. ''' Used for checking GraphModule Node
  57. '''
  58. def __init__(self, op, target):
  59. '''
  60. op: call_function | call_module
  61. target:
  62. for call_function, target would be a function
  63. for call_module, target would be the type of PyTorch module
  64. '''
  65. self.op = op
  66. self.target = target
  67. @classmethod
  68. def call_function(cls, target):
  69. return NodeSpec('call_function', target)
  70. @classmethod
  71. def call_method(cls, target):
  72. return NodeSpec('call_method', target)
  73. @classmethod
  74. def call_module(cls, target):
  75. return NodeSpec('call_module', target)
  76. def __hash__(self):
  77. return hash((self.op, self.target))
  78. def __eq__(self, other):
  79. if not isinstance(other, NodeSpec):
  80. return NotImplemented
  81. return self.op == other.op and self.target == other.target
  82. def __repr__(self):
  83. return repr(self.op) + " " + repr(self.target)
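# Illustrative sketch (not part of the original file): NodeSpec instances hash
# and compare by (op, target), so they can serve as dict keys when counting
# node occurrences. The helper name below is hypothetical.
def _example_nodespec_usage():
    expected = {
        NodeSpec.call_function(torch.quantize_per_tensor): 1,
        NodeSpec.call_module(nnq.Conv2d): 1,
        NodeSpec.call_method('dequantize'): 1,
    }
    # equal (op, target) pairs compare and hash identically
    assert NodeSpec.call_method('dequantize') in expected
    return expected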
def get_supported_device_types():
    return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu']
def test_only_eval_fn(model, calib_data):
    r"""
    Default evaluation function takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset
    """
    for inp in calib_data:
        output = model(*inp)

_default_loss_fn = torch.nn.CrossEntropyLoss()

def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
    r"""
    Default train function takes a torch.utils.data.Dataset and trains the model
    on the dataset
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_loss, correct, total = 0, 0, 0
    for i in range(10):
        model.train()

        for data, target in train_data:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return train_loss, correct, total

class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
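# Illustrative sketch (not part of the original file): AverageMeter keeps a
# running average weighted by the count passed to update(). Helper name
# hypothetical.
def _example_average_meter_usage():
    meter = AverageMeter('loss', ':.4f')
    meter.update(0.9, n=32)  # batch of 32 with mean loss 0.9
    meter.update(0.7, n=32)  # running avg is now (0.9 + 0.7) / 2 = 0.8
    return str(meter)        # 'loss 0.7000 (0.8000)'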
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
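# Illustrative sketch (not part of the original file): top-k accuracy over a
# batch of logits; returns percentages as 1-element tensors. Helper name
# hypothetical.
def _example_accuracy_usage():
    output = torch.randn(8, 10)           # batch of 8 samples, 10 classes
    target = torch.randint(0, 10, (8,))   # ground-truth class indices
    top1, top5 = accuracy(output, target, topk=(1, 5))
    return top1, top5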
def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
    model.train()
    cnt = 0
    for image, target in data_loader:
        start_time = time.time()
        print('.', end='')
        cnt += 1
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        if cnt >= ntrain_batches:
            return
    return
def ddp_setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def ddp_cleanup():
    dist.destroy_process_group()

def run_ddp(rank, world_size, prepared):
    ddp_setup(rank, world_size)
    prepared.cuda()
    prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank])
    prepared.to(rank)
    model_with_ddp = prepared
    optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
    # criterion and dataset are expected to be defined by the caller's scope
    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
    ddp_cleanup()

def convert_dynamic(module):
    convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)

def prepare_dynamic(model, qconfig_dict=None):
    propagate_qconfig_(model, qconfig_dict)
def _make_conv_test_input(
    batch_size, in_channels_per_group, input_feature_map_size,
    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale,
    W_zero_point, use_bias, use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min, X_value_max,
        (batch_size, in_channels,) + input_feature_map_size)
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Resize the W_scale and W_zero_point arrays to out_channels elements
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]

    # For testing, we use small values for weights and for activations so that
    # no overflow occurs in the vpmaddubsw instruction. If an overflow occurs
    # in the qconv implementation but not in the reference, we can't exactly
    # match the results with the reference. Please see the comment in the qconv
    # implementation file aten/src/ATen/native/quantized/cpu/qconv.cpp for more
    # details.
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(
        W_value_min, W_value_max,
        (out_channels, in_channels_per_group,) + kernel_size)
    b_init = torch.randint(0, 10, (out_channels,))

    if use_channelwise:
        W_shape = (-1, 1) + (1,) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = W_scales_tensor.reshape(*W_shape) * (
            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(
            W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0,
            dtype=torch.qint8)
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(
            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)

    return (X, X_q, W, W_q, b if use_bias else None)
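# Illustrative sketch (not part of the original file): build float/quantized
# activations, weights and bias for a 3x3 conv with per-tensor weight
# quantization. W_scale/W_zero_point are lists that get tiled up to
# out_channels. Helper name and parameter values are hypothetical.
def _example_make_conv_test_input():
    X, X_q, W, W_q, b = _make_conv_test_input(
        batch_size=1, in_channels_per_group=2, input_feature_map_size=(8, 8),
        out_channels_per_group=2, groups=1, kernel_size=(3, 3),
        X_scale=0.1, X_zero_point=0, W_scale=[0.2], W_zero_point=[0],
        use_bias=True, use_channelwise=False)
    return X_q, W_q, b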
def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min,
        X_value_max,
        sizes  # Infer the size of tensor to do the add
    )
    X = scale * (X_init - zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
    return X, X_q
def skipIfNoFBGEMM(fn):
    reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer.'
    if isinstance(fn, type):
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoQNNPACK(fn):
    reason = 'Quantized operations require QNNPACK.'
    if isinstance(fn, type):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoCaffe2(fn):
    # NOTE: the enclosing definition was missing; the name and reason are
    # assumed from the wrapper body, which skips when the Caffe2 ATen
    # fallback is unavailable.
    reason = 'Quantized operations require Caffe2.'

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch.onnx._CAFFE2_ATEN_FALLBACK:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper
def withQNNPACKBackend(fn):
    # TODO(future PR): consider combining with skipIfNoQNNPACK,
    # will require testing of existing callsites
    reason = 'Quantized operations require QNNPACK.'
    if isinstance(fn, type):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        with override_quantized_engine('qnnpack'):
            fn(*args, **kwargs)
    return wrapper

def skipIfNoONEDNN(fn):
    reason = 'Quantized operations require ONEDNN.'
    if isinstance(fn, type):
        if 'onednn' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'onednn' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper
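# Illustrative sketch (not part of the original file): the engine-skip
# decorators above wrap either a single test method or an entire TestCase
# subclass; withQNNPACKBackend additionally runs the decorated body under
# override_quantized_engine('qnnpack'). Class and method names below are
# hypothetical.
def _example_skip_decorator_usage():
    @skipIfNoFBGEMM
    class _ExampleFBGEMMTest(unittest.TestCase):
        @skipIfNoQNNPACK
        def test_needs_both_engines(self):
            pass
    return _ExampleFBGEMMTest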
try:
    import torchvision  # noqa: F401
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

def get_script_module(model, tracing, data):
    return torch.jit.trace(model, data) if tracing else torch.jit.script(model)

def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
    """
    Convert lengths to offsets for embedding_bag
    """
    tt = np.zeros((t.shape[0] + 1,), dtype=offset_type)
    tt[1:] = t
    tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type))
    if use_begin_offset:
        return tt[:-1]
    return tt[1:]
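# Illustrative sketch (not part of the original file): per-bag lengths
# [2, 3, 1] become begin offsets [0, 2, 5], or end offsets [2, 5, 6] when
# use_begin_offset=False. Helper name hypothetical.
def _example_lengths_to_offsets():
    lengths = np.array([2, 3, 1])
    begin = lengths_to_offsets(lengths)                        # tensor([0, 2, 5])
    end = lengths_to_offsets(lengths, use_begin_offset=False)  # tensor([2, 5, 6])
    return begin, end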
# QuantizationTestCase used as a base class for testing quantization on modules
class QuantizationTestCase(TestCase):
    def setUp(self):
        super().setUp()
        self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)]
        self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)]
        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]
        self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]
        self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]

        self.img_data_dict = {1 : self.img_data_1d,
                              2 : self.img_data_2d,
                              3 : self.img_data_3d}

        # Quant types that produce statically quantized ops
        self.static_quant_types = [QuantType.STATIC, QuantType.QAT]
        # All quant types for (fx based) graph mode quantization
        self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT]

    def checkNoPrepModules(self, module):
        r"""Checks the module does not contain child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertFalse(hasattr(module, 'quant'))
        self.assertFalse(hasattr(module, 'dequant'))

    def checkNoQconfig(self, module):
        r"""Checks the module does not contain qconfig
        """
        self.assertFalse(hasattr(module, 'qconfig'))

        for child in module.children():
            self.checkNoQconfig(child)

    def checkHasPrepModules(self, module):
        r"""Checks the module contains child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertTrue(hasattr(module, 'module'))
        self.assertTrue(hasattr(module, 'quant'))
        self.assertTrue(hasattr(module, 'dequant'))

    def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
        r"""Checks the module or module's leaf descendants
            have observers in preparation for quantization
        """
        if propagate_qconfig_list is None:
            propagate_qconfig_list = get_default_qconfig_propagation_list()
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

        # check if a module is a leaf module, ignoring activation_post_process attribute
        def is_leaf_module(module):
            submodule_name_count = 0
            for name, _ in module.named_children():
                if name != 'activation_post_process':
                    submodule_name_count += 1
            return submodule_name_count == 0

        if hasattr(module, 'qconfig') and module.qconfig is not None and \
           ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
            and type(module) in propagate_qconfig_list) or
           type(module) in float_to_observed_module_class_mapping.keys()) and \
           not isinstance(module, torch.ao.quantization.DeQuantStub):
            self.assertTrue(hasattr(module, 'activation_post_process'),
                            'module: ' + str(type(module)) + ' does not have an observer')
        # we don't need to check observers for child modules of the
        # qat modules
        if type(module) not in get_default_qat_module_mappings().values() and \
           type(module) not in float_to_observed_module_class_mapping.values() and \
           not isinstance(module, _FusedModule):
            for child in module.children():
                if type(child) in [nn.Dropout]:
                    continue
                self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)
    def checkQuantDequant(self, mod):
        r"""Checks that mod has nn.Quantize and
            nn.DeQuantize submodules inserted
        """
        self.assertEqual(type(mod.quant), nnq.Quantize)
        self.assertEqual(type(mod.dequant), nnq.DeQuantize)

    def checkWrappedQuantizedLinear(self, mod):
        r"""Checks that mod has been swapped for an nnq.Linear
            module, the bias is qint32, and that the module
            has Quantize and DeQuantize submodules
        """
        self.assertEqual(type(mod.module), nnq.Linear)
        self.checkQuantDequant(mod)

    def checkQuantizedLinear(self, mod):
        self.assertEqual(type(mod), nnq.Linear)

    def checkDynamicQuantizedLinear(self, mod, dtype):
        r"""Checks that mod has been swapped for an nnqd.Linear
            module, the bias is float.
        """
        self.assertEqual(type(mod), nnqd.Linear)
        self.assertEqual(mod._packed_params.dtype, dtype)

    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
        r"""Checks that mod has been swapped for an nniqd.LinearReLU
            module, the bias is float.
        """
        self.assertEqual(type(mod), nniqd.LinearReLU)
        self.assertEqual(mod._packed_params.dtype, dtype)

    def check_eager_serialization(self, ref_model, loaded_model, x):
        # Check state dict serialization and torch.save APIs
        model_dict = ref_model.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_model.load_state_dict(loaded_dict)
        ref_out = ref_model(*x)
        load_out = loaded_model(*x)

        def check_outputs(ref_out, load_out):
            self.assertEqual(ref_out[0], load_out[0])
            if isinstance(ref_out[1], tuple):
                self.assertEqual(ref_out[1][0], load_out[1][0])
                self.assertEqual(ref_out[1][1], load_out[1][1])
            else:
                self.assertEqual(ref_out[1], load_out[1])

        check_outputs(ref_out, load_out)
        b = io.BytesIO()
        torch.save(ref_model, b)
        b.seek(0)
        loaded = torch.load(b)
        load_out = loaded(*x)
        check_outputs(ref_out, load_out)

    def check_weight_bias_api(self, ref_model, weight_keys, bias_keys):
        weight = ref_model.get_weight()
        bias = ref_model.get_bias()
        self.assertEqual(weight_keys ^ weight.keys(), set())
        self.assertEqual(bias_keys ^ bias.keys(), set())

    def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype):
        r"""Checks that mod has been swapped for an nnqd.LSTM type
            module, the bias is float.
        """
        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
        self.assertEqual(type(mod), reference_module_type)
        for packed_params in mod._all_weight_values:
            self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])

    def checkLinear(self, mod):
        self.assertEqual(type(mod), torch.nn.Linear)

    def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype):
        r"""Checks that mod has been swapped for the given reference
            module type, the bias is float.
        """
        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
        self.assertEqual(type(mod), reference_module_type)
        if hasattr(mod, '_all_weight_values'):
            for packed_params in mod._all_weight_values:
                self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])

    def checkScriptable(self, orig_mod, calib_data, check_save_load=False):
        scripted = torch.jit.script(orig_mod)
        self._checkScriptable(orig_mod, scripted, calib_data, check_save_load)

        # Use first calib_data entry as trace input
        traced = torch.jit.trace(orig_mod, calib_data[0])
        self._checkScriptable(orig_mod, traced, calib_data, check_save_load)

    # Call this twice: once for a scripted module and once for a traced module
    def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load):
        self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data)

        # Test save/load
        buffer = io.BytesIO()
        torch.jit.save(script_mod, buffer)
        buffer.seek(0)
        loaded_mod = torch.jit.load(buffer)
        # Pending __getstate__ and __setstate__ support
        # See tracking task https://github.com/pytorch/pytorch/issues/23984
        if check_save_load:
            self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data)

    def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data):
        for inp in calib_data:
            ref_output = orig_mod(*inp)
            scripted_output = test_mod(*inp)
            self.assertEqual(scripted_output, ref_output)
    def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False,
                         check=True, eval_mode=True, dynamic=False, qconfig=None):
        if debug:
            print('Testing:', str(module))
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

        if eval_mode:
            module = module.eval()
        if dynamic:
            qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig}
        model = get_script_module(module, tracing, inputs[0]).eval()
        if debug:
            print('input graph:', model.graph)
        models = {}
        outputs = {}
        for debug in [True, False]:
            if dynamic:
                models[debug] = quantize_dynamic_jit(model, qconfig_dict, debug=debug)
                # make sure it runs
                outputs[debug] = models[debug](inputs)
            else:
                # module under test can contain in-place ops, and we depend on
                # input data staying constant for comparisons
                inputs_copy = copy.deepcopy(inputs)
                models[debug] = quantize_jit(
                    model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False,
                    debug=debug)
                # make sure it runs
                outputs[debug] = models[debug](*inputs[0])

        if debug:
            print('debug graph:', models[True].graph)
            print('non debug graph:', models[False].graph)

        if check:
            # debug and non-debug option should have the same numerics
            self.assertEqual(outputs[True], outputs[False])

            # non debug graph should produce quantized op
            FileCheck().check(quantized_op) \
                       .run(models[False].graph)

        return models[False]
    def checkGraphModuleNodes(
            self, graph_module,
            expected_node=None,
            expected_node_occurrence=None,
            expected_node_list=None):
        """ Check if GraphModule contains the target node
        Args:
            graph_module: the GraphModule instance we want to check
            expected_node, expected_node_occurrence, expected_node_list:
               see docs for checkGraphModeFxOp
        """
        nodes_in_graph = {}
        node_list = []
        modules = dict(graph_module.named_modules(remove_duplicate=False))
        for node in graph_module.graph.nodes:
            n = None
            if node.op == 'call_function' or node.op == 'call_method':
                n = NodeSpec(node.op, node.target)
            elif node.op == 'call_module':
                n = NodeSpec(node.op, type(modules[node.target]))

            if n is not None:
                node_list.append(n)
                if n in nodes_in_graph:
                    nodes_in_graph[n] += 1
                else:
                    nodes_in_graph[n] = 1

        if expected_node is not None:
            self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) +
                            ' not found in the graph module')

        if expected_node_occurrence is not None:
            for expected_node, occurrence in expected_node_occurrence.items():
                if occurrence != 0:
                    self.assertTrue(
                        expected_node in nodes_in_graph,
                        'Check failed for node:' + str(expected_node) +
                        ' not found')
                    self.assertTrue(
                        nodes_in_graph[expected_node] == occurrence,
                        'Check failed for node:' + str(expected_node) +
                        ' Expected occurrence:' + str(occurrence) +
                        ' Found occurrence:' + str(nodes_in_graph[expected_node]))
                else:
                    self.assertTrue(
                        expected_node not in nodes_in_graph,
                        'Check failed for node:' + str(expected_node) +
                        ' expected no occurrence but found')

        if expected_node_list is not None:
            cur_index = 0
            for n in node_list:
                if cur_index == len(expected_node_list):
                    return
                if n == expected_node_list[cur_index]:
                    cur_index += 1
            self.assertTrue(
                cur_index == len(expected_node_list),
                "Check failed for graph:" +
                self.printGraphModule(graph_module, print_str=False) +
                "Expected ordered list:" +
                str(expected_node_list))

    def printGraphModule(self, graph_module, print_str=True):
        modules = dict(graph_module.named_modules(remove_duplicate=False))
        node_infos = []
        for n in graph_module.graph.nodes:
            node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs]))
            if n.op == 'call_module':
                node_info += ' module type: ' + repr(type(modules[n.target]))
            node_infos.append(node_info)
        str_to_print = '\n'.join(node_infos)
        if print_str:
            print(str_to_print)
        return str_to_print
    if HAS_FX:
        def assert_types_for_matched_subgraph_pairs(
            self,
            matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
            expected_types: Dict[str, Tuple[Tuple[Callable, Callable], Tuple[Callable, Callable]]],
            gm_a: GraphModule,
            gm_b: GraphModule,
        ) -> None:
            """
            Verifies that the types specified in expected_types match
            the underlying objects pointed to by the nodes in matched_subgraph_pairs.

            An example successful test case:

              matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)}
              expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)}

            The function tests for key equivalence, and verifies types with
            instance checks.
            """

            def _get_underlying_op_type(
                node: Node, gm: GraphModule
            ) -> Union[Callable, str]:
                if node.op == 'call_module':
                    mod = getattr(gm, node.target)
                    return type(mod)
                else:
                    assert node.op in ('call_function', 'call_method')
                    return node.target

            self.assertTrue(
                len(matched_subgraph_pairs) == len(expected_types),
                'Expected length of results to match, but got %d and %d' %
                (len(matched_subgraph_pairs), len(expected_types))
            )
            for k, v in expected_types.items():
                expected_types_a, expected_types_b = v
                exp_type_start_a, exp_type_end_a = expected_types_a
                exp_type_start_b, exp_type_end_b = expected_types_b
                subgraph_a, subgraph_b = matched_subgraph_pairs[k]

                act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a)
                act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b)
                act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a)
                act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b)
                types_match = (exp_type_start_a is act_type_start_a) and \
                    (exp_type_end_a is act_type_end_a) and \
                    (exp_type_start_b is act_type_start_b) and \
                    (exp_type_end_b is act_type_end_b)
                self.assertTrue(
                    types_match,
                    'Type mismatch at %s: expected %s, got %s' %
                    (k, (exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b),
                        (act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b))
                )

        def assert_ns_compare_dict_valid(
            self,
            act_compare_dict: Dict[str, Dict[str, Dict[str, Any]]],
        ) -> None:
            """
            Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
            1. for each layer, results are recorded for two models
            2. number of seen tensors match
            3. shapes of each pair of seen tensors match
            """
            for layer_name, result_type_to_data in act_compare_dict.items():
                for result_type, layer_data in result_type_to_data.items():
                    self.assertTrue(
                        len(layer_data) == 2,
                        f"Layer {layer_name} does not have exactly two model results.")
                    model_name_0, model_name_1 = layer_data.keys()
                    for res_idx in range(len(layer_data[model_name_0])):
                        layer_data_0 = layer_data[model_name_0][res_idx]
                        layer_data_1 = layer_data[model_name_1][res_idx]
                        self.assertTrue(
                            layer_data_0['type'] == layer_data_1['type'],
                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.")

                        self.assertTrue(
                            len(layer_data_0['values']) ==
                            len(layer_data_1['values']),
                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.")

                        # F.conv1d weight has rank 3, and toq.conv1d unpacked weight
                        # has rank 4. For now, skip the length check for conv1d only.
                        is_weight_functional_conv1d = (
                            result_type == NSSingleResultValuesType.WEIGHT.value and
                            (
                                'conv1d' in layer_data_0['prev_node_target_type'] or
                                'conv1d' in layer_data_1['prev_node_target_type']
                            )
                        )
                        if not is_weight_functional_conv1d:
                            for idx in range(len(layer_data_0['values'])):
                                values_0 = layer_data_0['values'][idx]
                                values_1 = layer_data_1['values'][idx]
                                if isinstance(values_0, torch.Tensor):
                                    self.assertTrue(
                                        values_0.shape == values_1.shape,
                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
                                        f"have a shape mismatch at idx {idx}.")
                                elif isinstance(values_0, list):
                                    values_0 = values_0[0]
                                    values_1 = values_1[0]
                                    self.assertTrue(
                                        values_0.shape == values_1.shape,
                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
                                        f"have a shape mismatch at idx {idx}.")
                                else:
                                    assert isinstance(values_0, tuple), \
                                        f"unhandled type {type(values_0)}"
                                    assert len(values_0) == 2
                                    assert len(values_0[1]) == 2
                                    assert values_0[0].shape == values_1[0].shape
                                    assert values_0[1][0].shape == values_1[1][0].shape
                                    assert values_0[1][1].shape == values_1[1][1].shape

                        # verify that ref_node_name is valid
                        ref_node_name_0 = layer_data_0['ref_node_name']
                        ref_node_name_1 = layer_data_1['ref_node_name']
                        prev_node_name_0 = layer_data_0['prev_node_name']
                        prev_node_name_1 = layer_data_1['prev_node_name']
                        if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value:
                            self.assertTrue(ref_node_name_0 == prev_node_name_0)
                            self.assertTrue(ref_node_name_1 == prev_node_name_1)
                        elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value:
                            self.assertTrue(ref_node_name_0 != prev_node_name_0)
                            self.assertTrue(ref_node_name_1 != prev_node_name_1)
        def checkGraphModeFxOp(
                self,
                model,
                inputs,
                quant_type,
                expected_node=None,
                expected_node_occurrence=None,
                expected_node_list=None,
                is_reference=False,
                print_debug_info=False,
                custom_qconfig_dict=None,
                prepare_expected_node=None,
                prepare_expected_node_occurrence=None,
                prepare_expected_node_list=None,
                prepare_custom_config=None,
                backend_config=None):
            """ Quantizes model with graph mode quantization on fx and checks if the
                quantized model contains the quantized_node

                Args:
                    model: floating point torch.nn.Module
                    inputs: one positional sample input arguments for model
                    expected_node: NodeSpec
                        e.g. NodeSpec.call_function(torch.quantize_per_tensor)
                    expected_node_occurrence: a dict from NodeSpec to
                        expected number of occurrences (int)
                        e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
                              NodeSpec.call_method('dequantize'): 1}
                    expected_node_list: a list of NodeSpec, used to check the order
                        of the occurrence of Node
                        e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                              NodeSpec.call_module(nnq.Conv2d),
                              NodeSpec.call_function(F.hardtanh_),
                              NodeSpec.call_method('dequantize')]
                    is_reference: if True, enables reference mode
                    print_debug_info: if True, prints debug info
                    custom_qconfig_dict: overrides default qconfig_dict
                    prepare_expected_node: same as expected_node, but for prepare
                    prepare_expected_node_occurrence: same as
                        expected_node_occurrence, but for prepare
                    prepare_expected_node_list: same as expected_node_list, but
                        for prepare

                Returns:
                    A dictionary with the following structure:
                    {
                        "prepared": ...,  # the prepared model
                        "quantized": ...,  # the quantized non-reference model
                        "quantized_reference": ...,  # the quantized reference model
                        "quantized_output": ...,  # the output of the quantized model
                        "quantized_reference_output": ...,  # the output of the
                                                            # quantized reference model
                    }
            """
            # TODO: make img_data a single example instead of a list
            if type(inputs) == list:
                inputs = inputs[0]

            if quant_type == QuantType.QAT:
                qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
                model.train()
            elif quant_type == QuantType.STATIC:
                qconfig = get_default_qconfig(torch.backends.quantized.engine)
                model.eval()
            else:
                qconfig = default_dynamic_qconfig
                model.eval()

            if quant_type == QuantType.QAT:
                prepare = prepare_qat_fx
            else:
                prepare = prepare_fx

            qconfig_dict = {"": qconfig}
            # overwrite qconfig_dict with custom_qconfig_dict
            if custom_qconfig_dict is not None:
                qconfig_dict = custom_qconfig_dict
            prepared = prepare(
                model, qconfig_dict,
                example_inputs=inputs,
                prepare_custom_config=prepare_custom_config,
                backend_config=backend_config)
            if not quant_type == QuantType.DYNAMIC:
                prepared(*inputs)

            if print_debug_info:
                print()
                print('quant type:\n', quant_type)
                print('original model:\n', model)
                print()
                print('prepared model:\n', prepared)

            self.checkGraphModuleNodes(
                prepared, prepare_expected_node,
                prepare_expected_node_occurrence, prepare_expected_node_list)

            prepared_copy = copy.deepcopy(prepared)
            qgraph = convert_fx(copy.deepcopy(prepared))
            qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared))
            result = qgraph(*inputs)
            result_reference = qgraph_reference(*inputs)
            qgraph_copy = copy.deepcopy(qgraph)
            qgraph_reference_copy = copy.deepcopy(qgraph_reference)

            qgraph_to_check = qgraph_reference if is_reference else qgraph
            if print_debug_info:
                print()
                print('quantized model:\n', qgraph_to_check)
                self.printGraphModule(qgraph_to_check)
                print()
            self.checkGraphModuleNodes(
                qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)
            return {"prepared": prepared_copy,
                    "quantized": qgraph_copy,
                    "quantized_reference": qgraph_reference_copy,
                    "quantized_output": result,
                    "quantized_reference_output": result_reference}
    def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets,
                                    set_qconfig, is_emb_bag, dtype=torch.quint8):
        # Test serialization of dynamic EmbeddingBag module using state_dict
        if is_emb_bag:
            inputs = [indices, offsets]
        else:
            inputs = [indices]
        emb_dict = qemb.state_dict()
        b = io.BytesIO()
        torch.save(emb_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        embedding_unpack = torch.ops.quantized.embedding_bag_unpack
        # Check unpacked weight values explicitly
        for key in emb_dict:
            if isinstance(emb_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                emb_weight = embedding_unpack(emb_dict[key])
                loaded_weight = embedding_unpack(loaded_dict[key])
                self.assertEqual(emb_weight, loaded_weight)

        # Check state dict serialization and torch.save APIs
        if is_emb_bag:
            loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                           include_last_offset=True, mode='sum', dtype=dtype)
        else:
            loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
        self.check_eager_serialization(qemb, loaded_qemb, inputs)

        loaded_qemb.load_state_dict(loaded_dict)
        self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight),
                         embedding_unpack(loaded_qemb._packed_params._packed_weight))

        # Test JIT serialization
        self.checkScriptable(qemb, [inputs], check_save_load=True)

        # Test from_float call
        if is_emb_bag:
            float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                                    include_last_offset=True, scale_grad_by_freq=False, mode='sum')
        else:
            float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

        if set_qconfig:
            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                        qscheme=torch.per_channel_affine_float_qparams,
                                                                        ch_axis=0)
            float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer,
                                              weight=float_qparams_observer)

        prepare_dynamic(float_embedding)

        float_embedding(*inputs)
        if is_emb_bag:
            q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding)
            expected_name = "QuantizedEmbeddingBag"
        else:
            q_embeddingbag = nnq.Embedding.from_float(float_embedding)
            expected_name = "QuantizedEmbedding"

        q_embeddingbag(*inputs)

        self.assertTrue(expected_name in str(q_embeddingbag))
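# Illustrative sketch (not part of the original file): inside a test derived
# from QuantizationTestCase, a typical FX check pairs an expected node with an
# ordered node list; `quantized_graph_module` below is a hypothetical
# convert_fx output.
#
#     self.checkGraphModuleNodes(
#         quantized_graph_module,
#         expected_node=NodeSpec.call_module(nnq.Conv2d),
#         expected_node_list=[
#             NodeSpec.call_function(torch.quantize_per_tensor),
#             NodeSpec.call_module(nnq.Conv2d),
#             NodeSpec.call_method('dequantize'),
#         ])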
class QuantizationLiteTestCase(QuantizationTestCase):
    def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs):
        # Creates quantized model for testing mobile script modules
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            model = model_class(**kwargs)
            model = quantize(model, test_only_eval_fn, [self.calib_data])

        return model

    def _compare_script_and_mobile(self,
                                   model: torch.nn.Module,
                                   input: torch.Tensor):
        # Compares the numerical outputs for script and lite modules
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            script_module = torch.jit.script(model)
            script_module_result = script_module(input)

            max_retry = 5
            for retry in range(1, max_retry + 1):
                # retries `max_retry` times; breaks iff succeeds else throws exception
                try:
                    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
                    buffer.seek(0)
                    mobile_module = _load_for_lite_interpreter(buffer)

                    mobile_module_result = mobile_module(input)

                    torch.testing.assert_close(script_module_result, mobile_module_result)
                    mobile_module_forward_result = mobile_module.forward(input)
                    torch.testing.assert_close(script_module_result, mobile_module_forward_result)

                    mobile_module_run_method_result = mobile_module.run_method("forward", input)
                    torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
                except AssertionError as e:
                    if retry == max_retry:
                        raise e
                    else:
                        continue
                break
# Below are a series of toy models to use in testing quantization

class SingleLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)

class AnnotatedSingleLayerLinearModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)

class SingleLayerLinearDynamicModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)
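# Illustrative sketch (not part of the original file): the toy linear models
# above can be quantized with the eager dynamic API, which swaps nn.Linear for
# nnqd.Linear without calibration. Helper name hypothetical.
def _example_dynamic_quantization():
    model = SingleLayerLinearModel()
    qmodel = torch.ao.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8)
    return qmodel(*model.get_example_inputs())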
  961. class LinearAddModel(nn.Module):
  962. def __init__(self):
  963. super().__init__()
  964. self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
  965. self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
  966. def forward(self, x):
  967. x = self.fc1(x)
  968. x = torch.add(x, 5)
  969. x = self.fc2(x)
  970. return x
  971. def get_example_inputs(self) -> Tuple[Any, ...]:
  972. return (torch.rand(1, 5),)
  973. class RNNDynamicModel(torch.nn.Module):
  974. def __init__(self, mod_type):
  975. super().__init__()
  976. self.qconfig = default_dynamic_qconfig
  977. if mod_type == 'GRU':
  978. self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
  979. if mod_type == 'LSTM':
  980. self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)
  981. def forward(self, x):
  982. x = self.mod(x)
  983. return x
  984. class RNNCellDynamicModel(torch.nn.Module):
  985. def __init__(self, mod_type):
  986. super().__init__()
  987. self.qconfig = default_dynamic_qconfig
  988. if mod_type == 'GRUCell':
  989. self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float)
  990. if mod_type == 'LSTMCell':
  991. self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float)
  992. if mod_type == 'RNNReLU':
  993. self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float)
  994. if mod_type == 'RNNTanh':
  995. self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float)
  996. def forward(self, x):
  997. x = self.mod(x)
  998. return x
  999. class LSTMwithHiddenDynamicModel(torch.nn.Module):
  1000. def __init__(self, qengine='fbgemm'):
  1001. super().__init__()
  1002. self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
  1003. self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)
  1004. def forward(self, x, hid):
  1005. x, hid = self.lstm(x, hid)
  1006. return x, hid
  1007. class ConvModel(torch.nn.Module):
  1008. def __init__(self):
  1009. super().__init__()
  1010. self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
  1011. def forward(self, x):
  1012. x = self.conv(x)
  1013. return x
  1014. def get_example_inputs(self) -> Tuple[Any, ...]:
  1015. return (torch.rand(1, 3, 5, 5),)
  1016. class ConvTransposeModel(torch.nn.Module):
  1017. def __init__(self):
  1018. super().__init__()
  1019. self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
  1020. def forward(self, x):
  1021. x = self.conv(x)
  1022. return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvTransposeModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class ConvBnModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvBnModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qconfig = default_qconfig
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class ConvBnReLUModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class AnnotatedConvBnReLUModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=True)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        # TODO: remove this check and define two fuse_modules functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)
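

# A minimal usage sketch (illustrative helper, not used by any test):
# AnnotatedConvBnReLUModel is built for the eager-mode post-training
# quantization flow: fuse, observe, convert. Assumes the fbgemm backend.
def _example_ptq_conv_bn_relu():
    model = AnnotatedConvBnReLUModel(qengine='fbgemm').eval()
    model.fuse_model()  # folds conv + bn + relu into a single fused module
    prepared = torch.ao.quantization.prepare(model)  # inserts observers
    prepared(*model.get_example_inputs())  # calibration pass
    return torch.ao.quantization.convert(prepared)  # swaps in quantized modules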


class TwoLayerConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class TwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearModelWithSubmodule(nn.Module):
    def __init__(self):
        super().__init__()
        self.subm = TwoLayerLinearModel()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        x = self.subm(x)
        x = self.fc(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.subm.get_example_inputs()


class AnnotatedTwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float))
        self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class ActivationsTestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        self.quant = torch.ao.quantization.QuantStub()
        self.hardswish = torch.nn.Hardswish().to(dtype=torch.float)
        self.elu = torch.nn.ELU().to(dtype=torch.float)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.hardswish(x)
        x = self.elu(x)
        x = self.dequant(x)
        return x


class LinearReluModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearReluLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearReluAddModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        # note: re-assigns the relu submodule during forward
        self.relu = torch.nn.ReLU()
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearBnLeakyReluModel(torch.nn.Module):
    def __init__(self, with_bn=True):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.bn1d = nn.BatchNorm1d(5)
        self.leaky_relu = nn.LeakyReLU(0.01)
        self.with_bn = with_bn

    def forward(self, x):
        x = self.linear(x)
        if self.with_bn:
            x = self.bn1d(x)
        x = self.leaky_relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearTanhModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.linear(x)
        x = self.tanh(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)


class ConvBnAddReluModel(torch.nn.Module):
    def __init__(self,
                 with_bn=True,
                 with_relu=True,
                 left_conv=True,
                 two_conv=True,
                 use_torch_add=True):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, (2, 2))
        self.conv2 = nn.Conv2d(5, 5, (2, 2))
        self.bn = nn.BatchNorm2d(5)
        self.relu = nn.ReLU()
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.two_conv = two_conv
        self.left_conv = left_conv
        self.use_torch_add = use_torch_add

    def forward(self, x1, x2):
        if self.two_conv:
            if self.use_torch_add:
                if self.with_bn:
                    x = torch.add(self.bn(self.conv(x1)), self.conv2(x1))
                else:
                    x = torch.add(self.conv(x1), self.conv2(x1))
            else:
                if self.with_bn:
                    x = self.bn(self.conv(x1)) + self.conv2(x1)
                else:
                    x = self.conv(x1) + self.conv2(x1)
        else:
            if self.use_torch_add:
                if self.left_conv:
                    if self.with_bn:
                        x = torch.add(self.bn(self.conv(x1)), x2)
                    else:
                        x = torch.add(self.conv(x1), x2)
                else:
                    if self.with_bn:
                        x = torch.add(x2, self.bn(self.conv(x1)))
                    else:
                        x = torch.add(x2, self.conv(x1))
            else:
                if self.left_conv:
                    if self.with_bn:
                        x = self.bn(self.conv(x1)) + x2
                    else:
                        x = self.conv(x1) + x2
                else:
                    if self.with_bn:
                        x = x2 + self.bn(self.conv(x1))
                    else:
                        x = x2 + self.conv(x1)
        if self.with_relu:
            x = self.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2))
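

# Illustrative helper (hypothetical): the constructor flags above pick which
# conv/bn/add/relu arrangement the forward pass traces, so one class can
# cover many fusion patterns.
def _example_conv_bn_add_relu_variants():
    # two parallel convs joined through torch.add, with bn and relu
    m1 = ConvBnAddReluModel(with_bn=True, with_relu=True, two_conv=True, use_torch_add=True)
    # a single conv added to the second input with the `+` operator, no relu
    m2 = ConvBnAddReluModel(with_relu=False, two_conv=False, left_conv=True, use_torch_add=False)
    x1, x2 = m1.get_example_inputs()
    return m1(x1, x2), m2(x1, x2)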


# TODO: self.fc should be self.conv
class ConvReluModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


# TODO: self.fc should be self.conv
class ConvReluConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


# TODO: self.fc should be self.conv
class ConvReluAddModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        # note: re-assigns the relu submodule during forward
        self.relu = torch.nn.ReLU()
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class NormalizationTestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.layer_norm = torch.nn.LayerNorm(8)
        self.group_norm = torch.nn.GroupNorm(2, 8)
        self.instance_norm1d = torch.nn.InstanceNorm1d(8)
        self.instance_norm2d = torch.nn.InstanceNorm2d(8)
        self.instance_norm3d = torch.nn.InstanceNorm3d(8)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.layer_norm(x)
        x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3))
        x = self.instance_norm1d(x)
        x = self.instance_norm2d(x.unsqueeze(-1))
        x = self.instance_norm3d(x.unsqueeze(-1))
        return x


class NestedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class AnnotatedNestedModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
        if qengine == 'fbgemm':
            self.sub2.fc1.qconfig = default_per_channel_qconfig
        else:
            self.sub2.fc1.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class AnnotatedSubNestedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = QuantWrapper(TwoLayerLinearModel())
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class AnnotatedCustomConfigNestedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.qconfig = default_qconfig

        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options),
                                 weight=default_weight_observer)
        self.sub2.fc1.qconfig = custom_qconfig

        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
        self.sub2.fc2 = QuantWrapper(self.sub2.fc2)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class QuantSubModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = QuantWrapper(TwoLayerLinearModel())
        self.sub2.qconfig = default_qconfig
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.fc3.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x


class InnerModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        return self.relu2(self.fc2(self.relu1(self.fc1(x))))

    def fuse_modules(self):
        fusable_layers = []
        named_children = list(self.named_children())
        for idx, (current_name, layer) in enumerate(named_children):
            if isinstance(layer, torch.nn.Linear):
                if idx >= len(named_children) - 1:
                    break
                if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
                    fusable_layers.append([current_name,
                                           named_children[idx + 1][0]])
        # TODO: remove this check and define two fuse_modules functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)


class FunctionalLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.rand((5, 5))
        self.bias = torch.zeros(5)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 5),)
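

# Illustrative sketch: F.linear has no module to swap, so the eager-mode flow
# cannot quantize FunctionalLinear; the FX graph mode flow can. This assumes
# the torch.ao.quantization.quantize_fx entry points (prepare_fx/convert_fx).
def _example_quantize_functional_linear_fx():
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model = FunctionalLinear().eval()
    example_inputs = model.get_example_inputs()
    prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), example_inputs)
    prepared(*example_inputs)  # calibration
    return convert_fx(prepared)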


class SingleLayerFunctionalLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class TwoLayerFunctionalLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class FunctionalLinearAddModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = torch.add(x, 5)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class FunctionalLinearReluModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = FunctionalLinear()

    def forward(self, x):
        x = self.linear(x)
        x = F.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear.get_example_inputs()


class FunctionalLinearReluLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.relu = nn.ReLU()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.linear1.get_example_inputs()


class FunctionalConv2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.rand(3, 3, 3, 3)
        self.bias = torch.rand(3)
        self.stride = (1, 1)
        self.padding = (0, 0)
        self.dilation = (1, 1)
        self.groups = 1

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)


class SingleLayerFunctionalConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv1.get_example_inputs()


class TwoLayerFunctionalConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = FunctionalConv2d()
        self.conv2 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv1.get_example_inputs()


class FunctionalConvReluModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = FunctionalConv2d()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv.get_example_inputs()


class FunctionalConvReluConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = FunctionalConv2d()
        self.relu = nn.ReLU()
        self.conv2 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> Tuple[Any, ...]:
        return self.conv1.get_example_inputs()


class SkipQuantModel(torch.nn.Module):
    r"""We can skip quantization by explicitly
    setting the qconfig of a submodule to None
    """
    def __init__(self):
        super().__init__()
        self.sub = InnerModule()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        return self.fc(self.sub(x))

    def fuse_modules(self):
        self.sub.fuse_modules()


class AnnotatedSkipQuantModel(torch.nn.Module):
    r"""We can skip quantization by explicitly
    setting the qconfig of a submodule to None
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.sub = QuantWrapper(InnerModule())
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        # don't quantize this fc
        self.fc.qconfig = None

    def forward(self, x):
        return self.fc(self.sub(x))

    def fuse_modules(self):
        self.sub.module.fuse_modules()
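

# Illustrative sketch: because `self.fc.qconfig` is set to None above, the
# eager-mode flow quantizes the wrapped submodule but leaves `fc` as a float
# nn.Linear.
def _example_skip_quant():
    model = AnnotatedSkipQuantModel(qengine='fbgemm').eval()
    model.fuse_modules()
    prepared = torch.ao.quantization.prepare(model)
    prepared(torch.rand(1, 5))  # calibration
    quantized = torch.ao.quantization.convert(prepared)
    assert type(quantized.fc) is torch.nn.Linear  # fc stayed float
    return quantized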


class QuantStubModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)
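

# Illustrative sketch of the canonical eager-mode post-training quantization
# flow for a model with manual QuantStub/DeQuantStub: prepare inserts
# observers, a forward pass calibrates them, convert swaps in quantized ops.
def _example_ptq_quant_stub_model():
    model = QuantStubModel().eval()
    prepared = torch.ao.quantization.prepare(model)
    prepared(torch.rand(4, 5))  # calibration on sample data
    return torch.ao.quantization.convert(prepared)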


class ManualLinearQATModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.dequant(x)
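

# Illustrative sketch of the eager-mode quantization-aware training flow for
# the model above; the real training loop is elided to a couple of forward
# passes.
def _example_qat_linear():
    model = ManualLinearQATModel(qengine='fbgemm')
    prepared = torch.ao.quantization.prepare_qat(model.train())  # fake-quant inserted
    for _ in range(2):  # stand-in for an actual training loop
        prepared(torch.rand(4, 5))
    return torch.ao.quantization.convert(prepared.eval())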


class ManualDropoutQATModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.dropout(x)
        return self.dequant(x)


class ManualLinearDynamicQATModel(torch.nn.Module):
    r"""A Module that uses dynamic QAT by default.
    """
    def __init__(self, qconfig=None):
        super().__init__()
        self.qconfig = qconfig or default_dynamic_qat_qconfig
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


class ManualConvLinearQATModel(torch.nn.Module):
    r"""A module with manually inserted `QuantStub` and `DeQuantStub`
    that contains both linear and conv modules
    """
    def __init__(self, qconfig=None):
        super().__init__()
        self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack")
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float)
        self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = x.view(-1, 64).contiguous()
        x = self.fc1(x)
        x = self.fc2(x)
        return self.dequant(x)


class ManualConvLinearSymmQATModel(ManualConvLinearQATModel):
    r"""Same as ManualConvLinearQATModel but with symmetric quantization.
    Supported only with qnnpack.
    """
    def __init__(self):
        super().__init__(default_symmetric_qnnpack_qat_qconfig)


class ManualEmbeddingBagLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
        self.emb.qconfig = default_embedding_qat_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
        self.qconfig = get_default_qat_qconfig("qnnpack")

    def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None,
                per_sample_weights: Optional[torch.Tensor] = None):
        x = self.emb(input, offsets, per_sample_weights)
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)


class DeFusedEmbeddingBagLinear(nn.Module):
    r"""A module to simulate QAT embedding bag with a linear layer.
    Instead of a single `EmbeddingBag`, it uses separate embedding and
    bagging ops, as described in the EmbeddingBag documentation:
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
    """
    def __init__(self) -> None:
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.emb.qconfig = default_embedding_qat_qconfig
        self.bagging_op = torch.sum
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
        self.qconfig = get_default_qat_qconfig("qnnpack")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.bagging_op(self.emb(input), dim=1)
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)


class SubModelForFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
        self.bn = nn.BatchNorm2d(2).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class SubModelWithoutFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=False).to(dtype=torch.float)

    def forward(self, x):
        return self.relu(self.conv(x))


class ModelForFusion(nn.Module):
    def __init__(self, qconfig):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float)
        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.sub1 = SubModelForFusion()
        self.sub2 = SubModelWithoutFusion()
        self.fc = nn.Linear(36, 10).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.qconfig = qconfig
        self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float)
        self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
        self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
        self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
        self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
        # don't quantize sub2
        self.sub2.qconfig = None
        self.fc.qconfig = None

    def forward(self, x):
        x = x.squeeze(2)
        x = self.quant(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu4(x)
        x = x.unsqueeze(2)
        y = x.unsqueeze(2)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.sub1(x)
        x = self.dequant(x)
        x = self.sub2(x)
        x = x.reshape(-1, 36).contiguous()
        x = self.fc(x)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.bn2(y)
        y = self.relu3(y)
        y = self.dequant(y)
        return x


class ConvBNReLU(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=False)
        )


class ModelWithSequentialFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1)
        self.relu1 = nn.ReLU(inplace=False)
        layers = []
        for _ in range(3):
            layers.append(ConvBNReLU())
        self.features = nn.Sequential(*layers)
        head = [nn.Linear(300, 10), nn.ReLU(inplace=False)]
        self.classifier = nn.Sequential(*head)
        self.seq = nn.Sequential()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.features(x)
        x = torch.reshape(x, (-1, 3 * 10 * 10))
        x = self.classifier(x)
        x = self.seq(x)
        x = self.dequant(x)
        return x


class ModelForFusionWithBias(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float)
        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float)
        self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dequant(x)
        return x


class ModelForLinearBNFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 10)
        self.bn = nn.BatchNorm1d(10)
        nn.init.uniform_(self.bn.weight)
        nn.init.uniform_(self.bn.bias)

    def forward(self, x):
        return self.bn(self.fc(x))
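

# Illustrative sketch: Linear followed by BatchNorm1d can be folded into a
# single Linear for inference via fuse_modules (fusion requires eval mode).
def _example_fuse_linear_bn():
    model = ModelForLinearBNFusion().eval()
    fused = torch.ao.quantization.fuse_modules(model, [['fc', 'bn']])
    assert isinstance(fused.bn, torch.nn.Identity)  # bn was folded into fc
    return fused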


class DummyObserver(torch.nn.Module):
    def calculate_qparams(self):
        return 1.0, 0

    def forward(self, x):
        return x


class ModelForConvTransposeBNFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.ConvTranspose1d(3, 3, 1)
        self.bn1 = nn.BatchNorm1d(3)
        self.conv2 = nn.ConvTranspose2d(3, 3, 1)
        self.bn2 = nn.BatchNorm2d(3)
        self.conv3 = nn.ConvTranspose3d(3, 3, 1)
        self.bn3 = nn.BatchNorm3d(3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = x.unsqueeze(2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = x.unsqueeze(2)
        x = self.conv3(x)
        x = self.bn3(x)
        return x


class ModelWithFunctionals(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mycat = nnq.FloatFunctional()
        self.myadd = nnq.FloatFunctional()
        self.myadd_relu = nnq.FloatFunctional()
        # Tracing doesn't work yet for c10 ops with scalar inputs
        # https://github.com/pytorch/pytorch/issues/27097
        # self.my_scalar_add = nnq.FloatFunctional()
        # self.my_scalar_mul = nnq.FloatFunctional()

    def forward(self, x):
        y = self.mycat.cat([x, x, x])
        z = self.myadd.add(y, y)
        w = self.myadd_relu.add_relu(z, z)
        # Tracing doesn't work yet for c10 ops with scalar inputs
        # https://github.com/pytorch/pytorch/issues/27097
        # w = self.my_scalar_add.add_scalar(w, -0.5)
        # w = self.my_scalar_mul.mul_scalar(w, 0.5)
        return w
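

# Illustrative sketch: eager-mode quantization only sees module boundaries,
# which is why tensor ops such as torch.cat/torch.add are routed through
# nnq.FloatFunctional above; prepare can then attach observers to them and
# convert can swap them for quantized kernels.
def _example_prepare_functionals():
    model = ModelWithFunctionals().eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    prepared = torch.ao.quantization.prepare(model)
    prepared(torch.rand(2, 3))  # calibration
    return torch.ao.quantization.convert(prepared)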


class ResNetBase(torch.nn.Module):
    def __init__(self):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.downsample = torch.nn.Identity()
        self.myop = nn.quantized.FloatFunctional()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(inplanes, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        identity = self.downsample(x)
        out = self.myop.add(out, identity)
        out = self.relu2(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def fuse_model(self):
        # TODO: remove this check and define two fuse_model functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)


class ModelMultipleOps(torch.nn.Module):
    def __init__(self):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.downsample = torch.nn.Identity()
        self.skip_add = nn.quantized.FloatFunctional()
        self.cat = nn.quantized.FloatFunctional()
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(12, 6)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        identity = self.downsample(x)
        out = self.skip_add.add(out, identity)
        out = self.relu2(out)
        out = self.avgpool(out)
        out = self.conv2(out)
        out = torch.nn.functional.max_pool2d(out, 2, 2)
        out = self.cat.cat([out, out])
        out = out.reshape(-1, 3 * 2 * 2)
        out = self.fc(out)
        return out


# Model to ensure consistency of fake quant with true quant
# Average pooling and mean operations are not modelled
# accurately with fake-quant so this model does not
# contain those operations
class ModelMultipleOpsNoAvgPool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.skip_add = nn.quantized.FloatFunctional()
        self.cat = nn.quantized.FloatFunctional()
        self.maxpool = nn.MaxPool2d((4, 4))
        self.fc = nn.Linear(12, 6)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        skip = self.conv2(x)
        out = self.skip_add.add(out, skip)
        out = self.relu2(out)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = torch.nn.functional.max_pool2d(out, 2, 2)
        out = self.cat.cat([out, out])
        out = out.reshape(-1, 3 * 2 * 2)
        out = self.fc(out)
        return out


class EmbeddingBagModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                         include_last_offset=True, scale_grad_by_freq=False, mode='sum')

    def forward(self, indices, offsets, per_sample_weights):
        return self.emb(indices, offsets, per_sample_weights)


class EmbeddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

    def forward(self, indices):
        return self.emb(indices)
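

# Illustrative sketch (assumes eager-mode support for weight-only embedding
# quantization): embeddings use the float_qparams weight-only qconfig, so
# only the weight is quantized and the output stays float.
def _example_quantize_embedding():
    model = EmbeddingModule().eval()
    model.qconfig = float_qparams_weight_only_qconfig
    prepared = torch.ao.quantization.prepare(model)
    quantized = torch.ao.quantization.convert(prepared)
    return quantized(torch.tensor([0, 1, 2, 9]))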


class EmbeddingWithStaticLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12)
        self.fc = torch.nn.Linear(4, 2)
        self.emb.qconfig = float_qparams_weight_only_qconfig
        self.qconfig = default_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, indices, offsets, linear_in):
        emb = self.emb(indices, offsets)
        q_x = self.quant(linear_in)
        fc = self.fc(q_x)
        fc = self.dequant(fc)
        features = torch.cat([fc] + [emb], dim=1)
        return features


class DenseTopMLP(nn.Module):
    def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None:
        super().__init__()
        self.dense_mlp = nn.Sequential(
            nn.Linear(dense_dim, dense_out),
        )
        self.top_mlp = nn.Sequential(
            nn.Linear(dense_out + embedding_dim, top_out_in),
            nn.Linear(top_out_in, top_out_out),
        )

    def forward(
        self,
        sparse_feature: torch.Tensor,
        dense: torch.Tensor,
    ) -> torch.Tensor:
        dense_feature = self.dense_mlp(dense)
        features = torch.cat([dense_feature] + [sparse_feature], dim=1)
        out = self.top_mlp(features)
        return out


# Thin wrapper around an EmbeddingBag: tracing inside nn.EmbeddingBag is not
# supported at the moment, so this wrapper keeps the embedding bag at the top level.
class EmbBagWrapper(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum')

    def forward(self, indices, offsets):
        return self.emb_bag(indices, offsets)


class SparseNNModel(nn.Module):
    _NUM_EMBEDDINGS = 10
    _EMBEDDING_DIM = 5
    _DENSE_DIM = 4
    _DENSE_OUTPUT = 2
    _TOP_OUT_IN = 2
    _TOP_OUT_OUT = 2
    _TOP_MLP_DIM = 1

    def __init__(self) -> None:
        super().__init__()
        self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM)
        self.dense_top = DenseTopMLP(
            self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN,
            self._TOP_OUT_OUT)

    def forward(
        self,
        sparse_indices: torch.Tensor,
        sparse_offsets: torch.Tensor,
        dense: torch.Tensor,
    ) -> torch.Tensor:
        sparse_feature = self.model_sparse(sparse_indices, sparse_offsets)
        out = self.dense_top(sparse_feature, dense)
        return out
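

# Illustrative usage sketch: driving SparseNNModel with random inputs shaped
# by its class-level dimension constants.
def _example_run_sparse_nn():
    model = SparseNNModel()
    sparse_indices = torch.randint(0, SparseNNModel._NUM_EMBEDDINGS, (8,))
    sparse_offsets = torch.tensor([0, 4])  # two bags of four indices each
    dense = torch.rand(2, SparseNNModel._DENSE_DIM)
    return model(sparse_indices, sparse_offsets, dense)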