regnet.py

import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface

__all__ = [
    "RegNet",
    "RegNet_Y_400MF_Weights",
    "RegNet_Y_800MF_Weights",
    "RegNet_Y_1_6GF_Weights",
    "RegNet_Y_3_2GF_Weights",
    "RegNet_Y_8GF_Weights",
    "RegNet_Y_16GF_Weights",
    "RegNet_Y_32GF_Weights",
    "RegNet_Y_128GF_Weights",
    "RegNet_X_400MF_Weights",
    "RegNet_X_800MF_Weights",
    "RegNet_X_1_6GF_Weights",
    "RegNet_X_3_2GF_Weights",
    "RegNet_X_8GF_Weights",
    "RegNet_X_16GF_Weights",
    "RegNet_X_32GF_Weights",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_1_6gf",
    "regnet_y_3_2gf",
    "regnet_y_8gf",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_y_128gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_1_6gf",
    "regnet_x_3_2gf",
    "regnet_x_8gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
]

class SimpleStemIN(Conv2dNormActivation):
    """Simple stem for ImageNet: 3x3, BN, ReLU."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__(
            width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer
        )

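# Illustrative usage (not part of the original module). Assuming a 224x224 RGB
# input and the default stem width of 32, the stride-2 stem halves the spatial
# resolution:
#
#     stem = SimpleStemIN(3, 32, nn.BatchNorm2d, nn.ReLU)
#     out = stem(torch.rand(1, 3, 224, 224))  # out.shape == (1, 32, 112, 112)
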
class BottleneckTransform(nn.Sequential):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float],
    ) -> None:
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        w_b = int(round(width_out * bottleneck_multiplier))
        g = w_b // group_width

        layers["a"] = Conv2dNormActivation(
            width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
        )
        layers["b"] = Conv2dNormActivation(
            w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
        )

        if se_ratio:
            # The SE reduction ratio is defined with respect to the
            # beginning of the block
            width_se_out = int(round(se_ratio * width_in))
            layers["se"] = SqueezeExcitation(
                input_channels=w_b,
                squeeze_channels=width_se_out,
                activation=activation_layer,
            )

        layers["c"] = Conv2dNormActivation(
            w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
        )
        super().__init__(layers)

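# Illustrative channel arithmetic (not part of the original module). With
# width_out=64, bottleneck_multiplier=1.0 and group_width=8, the bottleneck
# width is w_b = 64 and the grouped 3x3 conv uses g = 64 // 8 = 8 groups:
#
#     t = BottleneckTransform(32, 64, 2, nn.BatchNorm2d, nn.ReLU, 8, 1.0, None)
#     t(torch.rand(1, 32, 112, 112)).shape  # (1, 64, 56, 56)
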
class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int = 1,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
    ) -> None:
        super().__init__()

        # Use skip connection with projection if shape changes
        self.proj = None
        should_proj = (width_in != width_out) or (stride != 1)
        if should_proj:
            self.proj = Conv2dNormActivation(
                width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
            )
        self.f = BottleneckTransform(
            width_in,
            width_out,
            stride,
            norm_layer,
            activation_layer,
            group_width,
            bottleneck_multiplier,
            se_ratio,
        )
        self.activation = activation_layer(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        if self.proj is not None:
            x = self.proj(x) + self.f(x)
        else:
            x = x + self.f(x)
        return self.activation(x)

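# Illustrative (not part of the original module): the 1x1 projection exists
# only when the residual path changes shape, e.g. on a downsampling block:
#
#     down = ResBottleneckBlock(32, 64, stride=2, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU, group_width=8)
#     down.proj is not None  # True: width and resolution both change
#     same = ResBottleneckBlock(64, 64, stride=1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU, group_width=8)
#     same.proj is None  # True: identity skip connection
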
class AnyStage(nn.Sequential):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        block_constructor: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float] = None,
        stage_index: int = 0,
    ) -> None:
        super().__init__()

        for i in range(depth):
            block = block_constructor(
                width_in if i == 0 else width_out,
                width_out,
                stride if i == 0 else 1,
                norm_layer,
                activation_layer,
                group_width,
                bottleneck_multiplier,
                se_ratio,
            )
            self.add_module(f"block{stage_index}-{i}", block)

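# Illustrative (not part of the original module): only the first block of a
# stage downsamples and widens; the remaining depth - 1 blocks keep the
# stage's output shape:
#
#     stage = AnyStage(32, 64, 2, 3, ResBottleneckBlock, nn.BatchNorm2d, nn.ReLU, 8, 1.0)
#     [name for name, _ in stage.named_children()]  # ['block0-0', 'block0-1', 'block0-2']
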
class BlockParams:
    def __init__(
        self,
        depths: List[int],
        widths: List[int],
        group_widths: List[int],
        bottleneck_multipliers: List[float],
        strides: List[int],
        se_ratio: Optional[float] = None,
    ) -> None:
        self.depths = depths
        self.widths = widths
        self.group_widths = group_widths
        self.bottleneck_multipliers = bottleneck_multipliers
        self.strides = strides
        self.se_ratio = se_ratio

    @classmethod
    def from_init_params(
        cls,
        depth: int,
        w_0: int,
        w_a: float,
        w_m: float,
        group_width: int,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
        **kwargs: Any,
    ) -> "BlockParams":
        """
        Programmatically compute all the per-block settings,
        given the RegNet parameters.

        The first step is to compute the quantized linear block parameters,
        in log space. Key parameters are:
            - `w_a` is the width progression slope
            - `w_0` is the initial width
            - `w_m` is the width stepping in the log space

        In other terms
        `log(block_width) = log(w_0) + w_m * block_capacity`,
        with `block_capacity` ramping up following the w_0 and w_a params.
        This block width is finally quantized to multiples of 8.

        The second step is to compute the parameters per stage,
        taking into account the skip connection and the final 1x1 convolutions.
        We use the fact that the output width is constant within a stage.
        """
        QUANT = 8
        STRIDE = 2

        if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
            raise ValueError("Invalid RegNet settings")
        # Compute the block widths. Each stage has one unique block width
        widths_cont = torch.arange(depth) * w_a + w_0
        block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
        block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()
        num_stages = len(set(block_widths))

        # Convert to per stage parameters
        split_helper = zip(
            block_widths + [0],
            [0] + block_widths,
            block_widths + [0],
            [0] + block_widths,
        )
        splits = [w != wp or r != rp for w, wp, r, rp in split_helper]

        stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
        stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist()

        strides = [STRIDE] * num_stages
        bottleneck_multipliers = [bottleneck_multiplier] * num_stages
        group_widths = [group_width] * num_stages

        # Adjust the compatibility of stage widths and group widths
        stage_widths, group_widths = cls._adjust_widths_groups_compatibility(
            stage_widths, bottleneck_multipliers, group_widths
        )

        return cls(
            depths=stage_depths,
            widths=stage_widths,
            group_widths=group_widths,
            bottleneck_multipliers=bottleneck_multipliers,
            strides=strides,
            se_ratio=se_ratio,
        )

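    # Illustrative (not part of the original module): generating the per-stage
    # configuration for the RegNetY-400MF settings used further below; the
    # resulting lists have one entry per stage:
    #
    #     p = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25)
    #     len(p.widths) == len(p.depths) == len(p.group_widths)  # True
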
    def _get_expanded_params(self):
        return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)

    @staticmethod
    def _adjust_widths_groups_compatibility(
        stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
    ) -> Tuple[List[int], List[int]]:
        """
        Adjusts the compatibility of widths and groups,
        depending on the bottleneck ratio.
        """
        # Compute all widths for the current settings
        widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
        group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]

        # Compute the adjusted widths so that stage and group widths fit
        ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)]
        stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
        return stage_widths, group_widths_min

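# Illustrative (not part of the original module): with a bottleneck ratio of
# 1.0, a raw stage width of 30 cannot be split into equal groups of width 8,
# so _make_divisible rounds it to the nearest compatible width:
#
#     BlockParams._adjust_widths_groups_compatibility([30], [1.0], [8])  # ([32], [8])
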
class RegNet(nn.Module):
    def __init__(
        self,
        block_params: BlockParams,
        num_classes: int = 1000,
        stem_width: int = 32,
        stem_type: Optional[Callable[..., nn.Module]] = None,
        block_type: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if stem_type is None:
            stem_type = SimpleStemIN
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if block_type is None:
            block_type = ResBottleneckBlock
        if activation is None:
            activation = nn.ReLU

        # Ad hoc stem
        self.stem = stem_type(
            3,  # width_in
            stem_width,
            norm_layer,
            activation,
        )

        current_width = stem_width

        blocks = []
        for i, (
            width_out,
            stride,
            depth,
            group_width,
            bottleneck_multiplier,
        ) in enumerate(block_params._get_expanded_params()):
            blocks.append(
                (
                    f"block{i+1}",
                    AnyStage(
                        current_width,
                        width_out,
                        stride,
                        depth,
                        block_type,
                        norm_layer,
                        activation,
                        group_width,
                        bottleneck_multiplier,
                        block_params.se_ratio,
                        stage_index=i + 1,
                    ),
                )
            )

            current_width = width_out

        self.trunk_output = nn.Sequential(OrderedDict(blocks))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=current_width, out_features=num_classes)

        # Performs ResNet-style weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Note that there is no bias due to BN
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.trunk_output(x)
        x = self.avgpool(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        return x

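# Illustrative end-to-end usage (not part of the original module): building a
# RegNet directly from BlockParams and running a forward pass:
#
#     params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25)
#     model = RegNet(params)
#     logits = model(torch.rand(1, 3, 224, 224))  # logits.shape == (1, 1000)
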
def _regnet(
    block_params: BlockParams,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> RegNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
    model = RegNet(block_params, norm_layer=norm_layer, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

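# Illustrative (not part of the original module): the builder lets callers
# override the norm layer via kwargs; everything else is forwarded to RegNet:
#
#     params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8)
#     model = _regnet(params, weights=None, progress=True, norm_layer=partial(nn.BatchNorm2d, eps=1e-3))
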
_COMMON_META: Dict[str, Any] = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}

_COMMON_SWAG_META = {
    **_COMMON_META,
    "recipe": "https://github.com/facebookresearch/SWAG",
    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
}

class RegNet_Y_400MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.046,
                    "acc@5": 91.716,
                }
            },
            "_ops": 0.402,
            "_file_size": 16.806,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.804,
                    "acc@5": 92.742,
                }
            },
            "_ops": 0.402,
            "_file_size": 16.806,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

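# Illustrative (not part of the original module): a weights enum bundles the
# checkpoint URL, its preprocessing preset and its metadata, so inference
# transforms always match the checkpoint:
#
#     weights = RegNet_Y_400MF_Weights.DEFAULT  # resolves to IMAGENET1K_V2
#     preprocess = weights.transforms()         # 232-resize / 224-crop preset
#     weights.meta["_metrics"]["ImageNet-1K"]["acc@1"]  # 75.804
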
class RegNet_Y_800MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.420,
                    "acc@5": 93.136,
                }
            },
            "_ops": 0.834,
            "_file_size": 24.774,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.828,
                    "acc@5": 94.502,
                }
            },
            "_ops": 0.834,
            "_file_size": 24.774,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_1_6GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.950,
                    "acc@5": 93.966,
                }
            },
            "_ops": 1.612,
            "_file_size": 43.152,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.876,
                    "acc@5": 95.444,
                }
            },
            "_ops": 1.612,
            "_file_size": 43.152,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_3_2GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.948,
                    "acc@5": 94.576,
                }
            },
            "_ops": 3.176,
            "_file_size": 74.567,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.982,
                    "acc@5": 95.972,
                }
            },
            "_ops": 3.176,
            "_file_size": 74.567,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_8GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.032,
                    "acc@5": 95.048,
                }
            },
            "_ops": 8.473,
            "_file_size": 150.701,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.828,
                    "acc@5": 96.330,
                }
            },
            "_ops": 8.473,
            "_file_size": 150.701,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_16GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.424,
                    "acc@5": 95.240,
                }
            },
            "_ops": 15.912,
            "_file_size": 319.49,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.886,
                    "acc@5": 96.328,
                }
            },
            "_ops": 15.912,
            "_file_size": 319.49,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.012,
                    "acc@5": 98.054,
                }
            },
            "_ops": 46.735,
            "_file_size": 319.49,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.976,
                    "acc@5": 97.244,
                }
            },
            "_ops": 15.912,
            "_file_size": 319.49,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

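# Illustrative (not part of the original module): the SWAG variants ship their
# own presets, so picking a different enum member transparently switches the
# inference resolution:
#
#     e2e = RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
#     e2e.transforms()  # bicubic 384x384 preset, unlike the 224-crop V1/V2 presets
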
class RegNet_Y_32GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.878,
                    "acc@5": 95.340,
                }
            },
            "_ops": 32.28,
            "_file_size": 554.076,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.368,
                    "acc@5": 96.498,
                }
            },
            "_ops": 32.28,
            "_file_size": 554.076,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.838,
                    "acc@5": 98.362,
                }
            },
            "_ops": 94.826,
            "_file_size": 554.076,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.622,
                    "acc@5": 97.480,
                }
            },
            "_ops": 32.28,
            "_file_size": 554.076,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_128GF_Weights(WeightsEnum):
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.228,
                    "acc@5": 98.682,
                }
            },
            "_ops": 374.57,
            "_file_size": 2461.564,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.068,
                    "acc@5": 97.844,
                }
            },
            "_ops": 127.518,
            "_file_size": 2461.564,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_SWAG_E2E_V1


class RegNet_X_400MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.834,
                    "acc@5": 90.950,
                }
            },
            "_ops": 0.414,
            "_file_size": 21.258,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.864,
                    "acc@5": 92.322,
                }
            },
            "_ops": 0.414,
            "_file_size": 21.257,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_800MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.212,
                    "acc@5": 92.348,
                }
            },
            "_ops": 0.8,
            "_file_size": 27.945,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.522,
                    "acc@5": 93.826,
                }
            },
            "_ops": 0.8,
            "_file_size": 27.945,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_1_6GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.040,
                    "acc@5": 93.440,
                }
            },
            "_ops": 1.603,
            "_file_size": 35.339,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.668,
                    "acc@5": 94.922,
                }
            },
            "_ops": 1.603,
            "_file_size": 35.339,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_3_2GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.364,
                    "acc@5": 93.992,
                }
            },
            "_ops": 3.177,
            "_file_size": 58.756,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.196,
                    "acc@5": 95.430,
                }
            },
            "_ops": 3.177,
            "_file_size": 58.756,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_8GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.344,
                    "acc@5": 94.686,
                }
            },
            "_ops": 7.995,
            "_file_size": 151.456,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.682,
                    "acc@5": 95.678,
                }
            },
            "_ops": 7.995,
            "_file_size": 151.456,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_16GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.058,
                    "acc@5": 94.944,
                }
            },
            "_ops": 15.941,
            "_file_size": 207.627,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.716,
                    "acc@5": 96.196,
                }
            },
            "_ops": 15.941,
            "_file_size": 207.627,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_32GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.622,
                    "acc@5": 95.248,
                }
            },
            "_ops": 31.736,
            "_file_size": 412.039,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.014,
                    "acc@5": 96.288,
                }
            },
            "_ops": 31.736,
            "_file_size": 412.039,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_400MF_Weights
        :members:
    """
    weights = RegNet_Y_400MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

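# Illustrative usage (not part of the original module): constructing the model
# with pretrained weights and classifying a (here random) batch:
#
#     model = regnet_y_400mf(weights=RegNet_Y_400MF_Weights.IMAGENET1K_V2)
#     model.eval()
#     with torch.no_grad():
#         probs = model(torch.rand(1, 3, 224, 224)).softmax(dim=1)
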
  1101. @register_model()
  1102. @handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
  1103. def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
  1104. """
  1105. Constructs a RegNetY_800MF architecture from
  1106. `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
  1107. Args:
  1108. weights (:class:`~torchvision.models.RegNet_Y_800MF_Weights`, optional): The pretrained weights to use.
  1109. See :class:`~torchvision.models.RegNet_Y_800MF_Weights` below for more details and possible values.
  1110. By default, no pretrained weights are used.
  1111. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  1112. **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
  1113. ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
  1114. <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
  1115. for more detail about the classes.
  1116. .. autoclass:: torchvision.models.RegNet_Y_800MF_Weights
  1117. :members:
  1118. """
  1119. weights = RegNet_Y_800MF_Weights.verify(weights)
  1120. params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
  1121. return _regnet(params, weights, progress, **kwargs)
  1122. @register_model()
  1123. @handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
  1124. def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
  1125. """
  1126. Constructs a RegNetY_1.6GF architecture from
  1127. `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
  1128. Args:
  1129. weights (:class:`~torchvision.models.RegNet_Y_1_6GF_Weights`, optional): The pretrained weights to use.
  1130. See :class:`~torchvision.models.RegNet_Y_1_6GF_Weights` below for more details and possible values.
  1131. By default, no pretrained weights are used.
  1132. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  1133. **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
  1134. ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
  1135. <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
  1136. for more detail about the classes.
  1137. .. autoclass:: torchvision.models.RegNet_Y_1_6GF_Weights
  1138. :members:
  1139. """
  1140. weights = RegNet_Y_1_6GF_Weights.verify(weights)
  1141. params = BlockParams.from_init_params(
  1142. depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
  1143. )
  1144. return _regnet(params, weights, progress, **kwargs)
  1145. @register_model()
  1146. @handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
  1147. def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
  1148. """
  1149. Constructs a RegNetY_3.2GF architecture from
  1150. `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
  1151. Args:
  1152. weights (:class:`~torchvision.models.RegNet_Y_3_2GF_Weights`, optional): The pretrained weights to use.
  1153. See :class:`~torchvision.models.RegNet_Y_3_2GF_Weights` below for more details and possible values.
  1154. By default, no pretrained weights are used.
  1155. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  1156. **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
  1157. ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
  1158. <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
  1159. for more detail about the classes.
  1160. .. autoclass:: torchvision.models.RegNet_Y_3_2GF_Weights
  1161. :members:
  1162. """
  1163. weights = RegNet_Y_3_2GF_Weights.verify(weights)
  1164. params = BlockParams.from_init_params(
  1165. depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
  1166. )
  1167. return _regnet(params, weights, progress, **kwargs)
  1168. @register_model()
  1169. @handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
  1170. def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
  1171. """
  1172. Constructs a RegNetY_8GF architecture from
  1173. `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
  1174. Args:
  1175. weights (:class:`~torchvision.models.RegNet_Y_8GF_Weights`, optional): The pretrained weights to use.
  1176. See :class:`~torchvision.models.RegNet_Y_8GF_Weights` below for more details and possible values.
  1177. By default, no pretrained weights are used.
  1178. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  1179. **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
  1180. ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
  1181. <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
  1182. for more detail about the classes.
  1183. .. autoclass:: torchvision.models.RegNet_Y_8GF_Weights
  1184. :members:
  1185. """
  1186. weights = RegNet_Y_8GF_Weights.verify(weights)
  1187. params = BlockParams.from_init_params(
  1188. depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
  1189. )
  1190. return _regnet(params, weights, progress, **kwargs)
  1191. @register_model()
  1192. @handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
  1193. def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
  1194. """
  1195. Constructs a RegNetY_16GF architecture from
  1196. `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
  1197. Args:
  1198. weights (:class:`~torchvision.models.RegNet_Y_16GF_Weights`, optional): The pretrained weights to use.
  1199. See :class:`~torchvision.models.RegNet_Y_16GF_Weights` below for more details and possible values.
  1200. By default, no pretrained weights are used.
  1201. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  1202. **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
  1203. ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
  1204. <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
  1205. for more detail about the classes.
  1206. .. autoclass:: torchvision.models.RegNet_Y_16GF_Weights
  1207. :members:
  1208. """
  1209. weights = RegNet_Y_16GF_Weights.verify(weights)
  1210. params = BlockParams.from_init_params(
  1211. depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
  1212. )
  1213. return _regnet(params, weights, progress, **kwargs)
  1214. @register_model()
  1215. @handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
  1216. def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
  1217. """
  1218. Constructs a RegNetY_32GF architecture from
  1219. `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
  1220. Args:
  1221. weights (:class:`~torchvision.models.RegNet_Y_32GF_Weights`, optional): The pretrained weights to use.
  1222. See :class:`~torchvision.models.RegNet_Y_32GF_Weights` below for more details and possible values.
  1223. By default, no pretrained weights are used.
  1224. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  1225. **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
  1226. ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
  1227. <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
  1228. for more detail about the classes.
  1229. .. autoclass:: torchvision.models.RegNet_Y_32GF_Weights
  1230. :members:
  1231. """
  1232. weights = RegNet_Y_32GF_Weights.verify(weights)
  1233. params = BlockParams.from_init_params(
  1234. depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
  1235. )
  1236. return _regnet(params, weights, progress, **kwargs)
  1237. @register_model()
  1238. @handle_legacy_interface(weights=("pretrained", None))
  1239. def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
  1240. """
  1241. Constructs a RegNetY_128GF architecture from
  1242. `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
  1243. Args:
  1244. weights (:class:`~torchvision.models.RegNet_Y_128GF_Weights`, optional): The pretrained weights to use.
  1245. See :class:`~torchvision.models.RegNet_Y_128GF_Weights` below for more details and possible values.
  1246. By default, no pretrained weights are used.
  1247. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  1248. **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
  1249. ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
  1250. <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
  1251. for more detail about the classes.
  1252. .. autoclass:: torchvision.models.RegNet_Y_128GF_Weights
  1253. :members:
  1254. """
  1255. weights = RegNet_Y_128GF_Weights.verify(weights)
  1256. params = BlockParams.from_init_params(
  1257. depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
  1258. )
  1259. return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_400MF_Weights
        :members:
    """
    weights = RegNet_X_400MF_Weights.verify(weights)
    params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
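
# The handle_legacy_interface decorator keeps the pre-0.13 boolean flag alive:
# the deprecated spelling below should be remapped (with a deprecation
# warning) to the enum value named in the decorator. A sketch of the two
# equivalent calls:
#
#   old = regnet_x_400mf(pretrained=True)                               # deprecated
#   new = regnet_x_400mf(weights=RegNet_X_400MF_Weights.IMAGENET1K_V1)  # preferred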


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_800MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_800MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_800MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_800MF_Weights
        :members:
    """
    weights = RegNet_X_800MF_Weights.verify(weights)
    params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
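
# A common fine-tuning sketch, assuming the classifier head is exposed as
# ``model.fc`` (as in this file's RegNet definition) and a hypothetical
# 10-class target task:
#
#   import torch.nn as nn
#
#   model = regnet_x_800mf(weights=RegNet_X_800MF_Weights.IMAGENET1K_V1)
#   for p in model.parameters():
#       p.requires_grad = False                      # freeze the backbone
#   model.fc = nn.Linear(model.fc.in_features, 10)   # new, trainable head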


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_1.6GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_1_6GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_1_6GF_Weights
        :members:
    """
    weights = RegNet_X_1_6GF_Weights.verify(weights)
    params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
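
# For intermediate features rather than logits, torchvision's FX-based
# feature-extraction utilities can be pointed at internal nodes. The node name
# "trunk_output" below is an assumption to be checked against the output of
# get_graph_node_names:
#
#   import torch
#   from torchvision.models.feature_extraction import (
#       create_feature_extractor,
#       get_graph_node_names,
#   )
#
#   model = regnet_x_1_6gf()
#   print(get_graph_node_names(model)[1])            # eval-mode node names
#   extractor = create_feature_extractor(model, {"trunk_output": "feat"})
#   feat = extractor(torch.randn(1, 3, 224, 224))["feat"]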


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_3.2GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_3_2GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_3_2GF_Weights
        :members:
    """
    weights = RegNet_X_3_2GF_Weights.verify(weights)
    params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
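
# BlockParams.from_init_params can also be called directly to inspect the
# stage layout that a given (depth, w_0, w_a, w_m) quantized-linear
# parameterization produces; the attribute names below follow this file's
# BlockParams class:
#
#   p = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48)
#   print(p.depths, p.widths)   # per-stage block counts and channel widths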


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_8GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_8GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_8GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_8GF_Weights
        :members:
    """
    weights = RegNet_X_8GF_Weights.verify(weights)
    params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
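
# The classification models in this module are kept TorchScript-friendly;
# a minimal export sketch, assuming scripting succeeds for this variant
# (the output path is hypothetical):
#
#   import torch
#
#   scripted = torch.jit.script(regnet_x_8gf())
#   scripted.save("regnet_x_8gf.pt")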


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_16GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_16GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_16GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_16GF_Weights
        :members:
    """
    weights = RegNet_X_16GF_Weights.verify(weights)
    params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
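
# Each weights enum member carries a ``meta`` dict describing its checkpoint;
# a sketch of reading it, assuming the standard torchvision meta keys:
#
#   w = RegNet_X_16GF_Weights.IMAGENET1K_V1
#   print(w.meta["num_params"])        # parameter count of the checkpoint
#   print(len(w.meta["categories"]))   # 1000 ImageNet class names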


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_32GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_32GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_32GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_32GF_Weights
        :members:
    """
    weights = RegNet_X_32GF_Weights.verify(weights)
    params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
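
# The @register_model decorators above make every builder discoverable through
# the model registry; a sketch assuming torchvision v0.14+, where list_models
# and get_model are available:
#
#   from torchvision.models import get_model, list_models
#
#   print([name for name in list_models() if name.startswith("regnet")])
#   model = get_model("regnet_x_32gf", weights=None)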