efficientnet.py
import copy
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import nn, Tensor
from torchvision.ops import StochasticDepth

from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "EfficientNet",
    "EfficientNet_B0_Weights",
    "EfficientNet_B1_Weights",
    "EfficientNet_B2_Weights",
    "EfficientNet_B3_Weights",
    "EfficientNet_B4_Weights",
    "EfficientNet_B5_Weights",
    "EfficientNet_B6_Weights",
    "EfficientNet_B7_Weights",
    "EfficientNet_V2_S_Weights",
    "EfficientNet_V2_M_Weights",
    "EfficientNet_V2_L_Weights",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_s",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
]


@dataclass
class _MBConvConfig:
    expand_ratio: float
    kernel: int
    stride: int
    input_channels: int
    out_channels: int
    num_layers: int
    block: Callable[..., nn.Module]

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        return _make_divisible(channels * width_mult, 8, min_value)
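

# For illustration (worked numbers, assuming the usual _make_divisible rounding): with the
# B2 multiplier width_mult=1.1, adjust_channels(32, 1.1) evaluates _make_divisible(35.2, 8)
# and returns 32, since rounding to the nearest multiple of 8 stays within 10% of the scaled
# value, while adjust_channels(112, 1.1) returns 120.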


class MBConvConfig(_MBConvConfig):
    # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        width_mult: float = 1.0,
        depth_mult: float = 1.0,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        input_channels = self.adjust_channels(input_channels, width_mult)
        out_channels = self.adjust_channels(out_channels, width_mult)
        num_layers = self.adjust_depth(num_layers, depth_mult)
        if block is None:
            block = MBConv
        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float):
        return int(math.ceil(num_layers * depth_mult))


class FusedMBConvConfig(_MBConvConfig):
    # Stores information listed at Table 4 of the EfficientNetV2 paper
    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        if block is None:
            block = FusedMBConv
        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)


class MBConv(nn.Module):
    def __init__(
        self,
        cnf: MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = SqueezeExcitation,
    ) -> None:
        super().__init__()

        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        # expand
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        # depthwise
        layers.append(
            Conv2dNormActivation(
                expanded_channels,
                expanded_channels,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                groups=expanded_channels,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
            )
        )

        # squeeze and excitation
        squeeze_channels = max(1, cnf.input_channels // 4)
        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

        # project
        layers.append(
            Conv2dNormActivation(
                expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result = self.stochastic_depth(result)
            result += input
        return result
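

# Sketch of exercising a single MBConv block in isolation (the config values below are
# arbitrary for illustration, not taken from an official variant):
#
#   cnf = MBConvConfig(expand_ratio=6, kernel=3, stride=1, input_channels=32,
#                      out_channels=32, num_layers=1)
#   block = MBConv(cnf, stochastic_depth_prob=0.1, norm_layer=nn.BatchNorm2d)
#   out = block(torch.rand(1, 32, 56, 56))  # same shape out; residual + stochastic depth applied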


class FusedMBConv(nn.Module):
    def __init__(
        self,
        cnf: FusedMBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__()

        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
            # fused expand
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    expanded_channels,
                    kernel_size=cnf.kernel,
                    stride=cnf.stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

            # project
            layers.append(
                Conv2dNormActivation(
                    expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
                )
            )
        else:
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.out_channels,
                    kernel_size=cnf.kernel,
                    stride=cnf.stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result = self.stochastic_depth(result)
            result += input
        return result


class EfficientNet(nn.Module):
    def __init__(
        self,
        inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
        dropout: float,
        stochastic_depth_prob: float = 0.2,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        last_channel: Optional[int] = None,
    ) -> None:
        """
        EfficientNet V1 and V2 main class

        Args:
            inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
            dropout (float): The dropout probability
            stochastic_depth_prob (float): The stochastic depth probability
            num_classes (int): Number of classes
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
            last_channel (int): The number of channels on the penultimate layer
        """
        super().__init__()
        _log_api_usage_once(self)

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
        ):
            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
            )
        )

        # building inverted residual blocks
        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
        stage_block_id = 0
        for cnf in inverted_residual_setting:
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # copy to avoid modifications. shallow copy is enough
                block_cnf = copy.copy(cnf)

                # overwrite info if not the first conv in the stage
                if stage:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
                stage_block_id += 1

            layers.append(nn.Sequential(*stage))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
        layers.append(
            Conv2dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.SiLU,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(lastconv_output_channels, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
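

# Sketch of building a model directly from block configs (the two-stage setting below is
# illustrative only, not an official variant):
#
#   cfgs = [MBConvConfig(1, 3, 1, 32, 16, 1), MBConvConfig(6, 3, 2, 16, 24, 2)]
#   model = EfficientNet(cfgs, dropout=0.2)
#   logits = model(torch.rand(1, 3, 224, 224))  # -> shape (1, 1000)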


def _efficientnet(
    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
    dropout: float,
    last_channel: Optional[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> EfficientNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


def _efficientnet_conf(
    arch: str,
    **kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
    if arch.startswith("efficientnet_b"):
        bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
        inverted_residual_setting = [
            bneck_conf(1, 3, 1, 32, 16, 1),
            bneck_conf(6, 3, 2, 16, 24, 2),
            bneck_conf(6, 5, 2, 24, 40, 2),
            bneck_conf(6, 3, 2, 40, 80, 3),
            bneck_conf(6, 5, 1, 80, 112, 3),
            bneck_conf(6, 5, 2, 112, 192, 4),
            bneck_conf(6, 3, 1, 192, 320, 1),
        ]
        last_channel = None
    elif arch.startswith("efficientnet_v2_s"):
        inverted_residual_setting = [
            FusedMBConvConfig(1, 3, 1, 24, 24, 2),
            FusedMBConvConfig(4, 3, 2, 24, 48, 4),
            FusedMBConvConfig(4, 3, 2, 48, 64, 4),
            MBConvConfig(4, 3, 2, 64, 128, 6),
            MBConvConfig(6, 3, 1, 128, 160, 9),
            MBConvConfig(6, 3, 2, 160, 256, 15),
        ]
        last_channel = 1280
    elif arch.startswith("efficientnet_v2_m"):
        inverted_residual_setting = [
            FusedMBConvConfig(1, 3, 1, 24, 24, 3),
            FusedMBConvConfig(4, 3, 2, 24, 48, 5),
            FusedMBConvConfig(4, 3, 2, 48, 80, 5),
            MBConvConfig(4, 3, 2, 80, 160, 7),
            MBConvConfig(6, 3, 1, 160, 176, 14),
            MBConvConfig(6, 3, 2, 176, 304, 18),
            MBConvConfig(6, 3, 1, 304, 512, 5),
        ]
        last_channel = 1280
    elif arch.startswith("efficientnet_v2_l"):
        inverted_residual_setting = [
            FusedMBConvConfig(1, 3, 1, 32, 32, 4),
            FusedMBConvConfig(4, 3, 2, 32, 64, 7),
            FusedMBConvConfig(4, 3, 2, 64, 96, 7),
            MBConvConfig(4, 3, 2, 96, 192, 10),
            MBConvConfig(6, 3, 1, 192, 224, 19),
            MBConvConfig(6, 3, 2, 224, 384, 25),
            MBConvConfig(6, 3, 1, 384, 640, 7),
        ]
        last_channel = 1280
    else:
        raise ValueError(f"Unsupported model type {arch}")

    return inverted_residual_setting, last_channel
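

# Each config row above reads positionally as
# (expand_ratio, kernel, stride, input_channels, out_channels, num_layers).
# For example, bneck_conf(6, 3, 2, 16, 24, 2) describes a stage of two MBConv blocks with a
# 3x3 depthwise kernel, stride 2 on the first block, and 16 -> 24 channels, before the
# width/depth multipliers are applied.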


_COMMON_META: Dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}


_COMMON_META_V1 = {
    **_COMMON_META,
    "min_size": (1, 1),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
}


_COMMON_META_V2 = {
    **_COMMON_META,
    "min_size": (33, 33),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
}


class EfficientNet_B0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 5288548,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.692,
                    "acc@5": 93.532,
                }
            },
            "_ops": 0.386,
            "_file_size": 20.451,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B1_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b1_rwightman-bac287d4.pth",
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.642,
                    "acc@5": 94.186,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.134,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.838,
                    "acc@5": 94.934,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.136,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class EfficientNet_B2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
        transforms=partial(
            ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 9109994,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.608,
                    "acc@5": 95.310,
                }
            },
            "_ops": 1.088,
            "_file_size": 35.174,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
        transforms=partial(
            ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 12233232,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.008,
                    "acc@5": 96.054,
                }
            },
            "_ops": 1.827,
            "_file_size": 47.184,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B4_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
        transforms=partial(
            ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 19341616,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.384,
                    "acc@5": 96.594,
                }
            },
            "_ops": 4.394,
            "_file_size": 74.489,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
        transforms=partial(
            ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 30389784,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.444,
                    "acc@5": 96.628,
                }
            },
            "_ops": 10.266,
            "_file_size": 116.864,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B6_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
        transforms=partial(
            ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 43040704,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.008,
                    "acc@5": 96.916,
                }
            },
            "_ops": 19.068,
            "_file_size": 165.362,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B7_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
        transforms=partial(
            ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 66347960,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.122,
                    "acc@5": 96.908,
                }
            },
            "_ops": 37.746,
            "_file_size": 254.675,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_S_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 21458488,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.228,
                    "acc@5": 96.878,
                }
            },
            "_ops": 8.366,
            "_file_size": 82.704,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_M_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 54139356,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.112,
                    "acc@5": 97.156,
                }
            },
            "_ops": 24.582,
            "_file_size": 208.01,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_L_Weights(WeightsEnum):
    # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BICUBIC,
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 118515272,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.808,
                    "acc@5": 97.788,
                }
            },
            "_ops": 56.08,
            "_file_size": 454.573,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
    *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B0_Weights
        :members:
    """
    weights = EfficientNet_B0_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
    )
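

# Usage sketch (downloading pretrained weights requires network access):
#
#   weights = EfficientNet_B0_Weights.IMAGENET1K_V1
#   model = efficientnet_b0(weights=weights).eval()
#   preprocess = weights.transforms()  # the matching resize/crop/normalization pipeline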


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1(
    *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B1_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B1_Weights
        :members:
    """
    weights = EfficientNet_B1_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2(
    *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B2_Weights
        :members:
    """
    weights = EfficientNet_B2_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.3), last_channel, weights, progress, **kwargs
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3(
    *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B3_Weights
        :members:
    """
    weights = EfficientNet_B3_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.3),
        last_channel,
        weights,
        progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4(
    *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B4_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B4_Weights
        :members:
    """
    weights = EfficientNet_B4_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.4),
        last_channel,
        weights,
        progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5(
    *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B5_Weights
        :members:
    """
    weights = EfficientNet_B5_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.4),
        last_channel,
        weights,
        progress,
        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6(
    *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B6_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B6_Weights
        :members:
    """
    weights = EfficientNet_B6_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.5),
        last_channel,
        weights,
        progress,
        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7(
    *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B7_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B7_Weights
        :members:
    """
    weights = EfficientNet_B7_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.5),
        last_channel,
        weights,
        progress,
        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
    *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-S architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
        :members:
    """
    weights = EfficientNet_V2_S_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.2),
        last_channel,
        weights,
        progress,
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
    *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-M architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
        :members:
    """
    weights = EfficientNet_V2_M_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.3),
        last_channel,
        weights,
        progress,
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
    *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-L architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
        :members:
    """
    weights = EfficientNet_V2_L_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.4),
        last_channel,
        weights,
        progress,
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )
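

# Inference sketch for the V2 variants (assumes an image tensor `img` with shape (3, H, W)
# and values in [0, 1]; weight download requires network access):
#
#   weights = EfficientNet_V2_S_Weights.DEFAULT
#   model = efficientnet_v2_s(weights=weights).eval()
#   batch = weights.transforms()(img).unsqueeze(0)
#   with torch.no_grad():
#       class_id = model(batch).argmax(dim=1).item()
#   print(weights.meta["categories"][class_id])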