# Ultralytics YOLO 🚀, AGPL-3.0 license

import ast
import contextlib
import json
import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
from urllib.parse import urlparse

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from ultralytics.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml
from ultralytics.utils.downloads import attempt_download_asset, is_url
from ultralytics.utils.ops import xywh2xyxy


def check_class_names(names):
    """Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts."""
    if isinstance(names, list):  # names is a list
        names = dict(enumerate(names))  # convert to dict
    if isinstance(names, dict):
        # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
        names = {int(k): str(v) for k, v in names.items()}
        n = len(names)
        if max(names.keys()) >= n:
            raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
                           f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
        if isinstance(names[0], str) and names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764'
            map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map']  # human-readable names
            names = {k: map[v] for k, v in names.items()}
    return names
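
# Example (illustrative): check_class_names() accepts either a list or a dict and always returns
# an int -> str mapping, e.g.
#   check_class_names(['person', 'bicycle'])      # {0: 'person', 1: 'bicycle'}
#   check_class_names({'0': 'person', 1: True})   # {0: 'person', 1: 'True'}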


class AutoBackend(nn.Module):

    def __init__(self,
                 weights='yolov8n.pt',
                 device=torch.device('cpu'),
                 dnn=False,
                 data=None,
                 fp16=False,
                 fuse=True,
                 verbose=True):
        """
        MultiBackend class for python inference on various platforms using Ultralytics YOLO.

        Args:
            weights (str): The path to the weights file. Default: 'yolov8n.pt'
            device (torch.device): The device to run the model on.
            dnn (bool): Use OpenCV DNN module for inference if True, defaults to False.
            data (str | Path | optional): Additional data.yaml file for class names.
            fp16 (bool): If True, use half precision. Default: False
            fuse (bool): Whether to fuse the model or not. Default: True
            verbose (bool): Whether to run in verbose mode or not. Default: True

        Supported formats and their naming conventions:
            | Format                | Suffix           |
            |-----------------------|------------------|
            | PyTorch               | *.pt             |
            | TorchScript           | *.torchscript    |
            | ONNX Runtime          | *.onnx           |
            | ONNX OpenCV DNN       | *.onnx dnn=True  |
            | OpenVINO              | *.xml            |
            | CoreML                | *.mlpackage      |
            | TensorRT              | *.engine         |
            | TensorFlow SavedModel | *_saved_model    |
            | TensorFlow GraphDef   | *.pb             |
            | TensorFlow Lite       | *.tflite         |
            | TensorFlow Edge TPU   | *_edgetpu.tflite |
            | PaddlePaddle          | *_paddle_model   |
            | ncnn                  | *_ncnn_model     |
        """
        super().__init__()
        w = str(weights[0] if isinstance(weights, list) else weights)
        nn_module = isinstance(weights, torch.nn.Module)
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
            self._model_type(w)
        fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
        nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCHW)
        stride = 32  # default stride
        model, metadata = None, None

        # Set device
        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
        if cuda and not any([nn_module, pt, jit, engine]):  # GPU dataloader formats
            device = torch.device('cpu')
            cuda = False

        # Download if not local
        if not (pt or triton or nn_module):
            w = attempt_download_asset(w)

        # Load model
        if nn_module:  # in-memory PyTorch model
            model = weights.to(device)
            model = model.fuse(verbose=verbose) if fuse else model
            if hasattr(model, 'kpt_shape'):
                kpt_shape = model.kpt_shape  # pose-only
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            model.half() if fp16 else model.float()
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
            pt = True
        elif pt:  # PyTorch
            from ultralytics.nn.tasks import attempt_load_weights
            model = attempt_load_weights(weights if isinstance(weights, list) else w,
                                         device=device,
                                         inplace=True,
                                         fuse=fuse)
            if hasattr(model, 'kpt_shape'):
                kpt_shape = model.kpt_shape  # pose-only
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            model.half() if fp16 else model.float()
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        elif jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
            model.half() if fp16 else model.float()
            if extra_files['config.txt']:  # load metadata dict
                metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
        elif dnn:  # ONNX OpenCV DNN
            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
            check_requirements('opencv-python>=4.5.4')
            net = cv2.dnn.readNetFromONNX(w)
        elif onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(w, providers=providers)
            output_names = [x.name for x in session.get_outputs()]
            metadata = session.get_modelmeta().custom_metadata_map  # metadata
        elif xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements('openvino>=2023.0')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
            from openvino.runtime import Core, Layout, get_batch  # noqa
            core = Core()
            w = Path(w)
            if not w.is_file():  # if not *.xml
                w = next(w.glob('*.xml'))  # get *.xml file from *_openvino_model dir
            ov_model = core.read_model(model=str(w), weights=w.with_suffix('.bin'))
            if ov_model.get_parameters()[0].get_layout().empty:
                ov_model.get_parameters()[0].set_layout(Layout('NCHW'))
            batch_dim = get_batch(ov_model)
            if batch_dim.is_static:
                batch_size = batch_dim.get_length()
            ov_compiled_model = core.compile_model(ov_model, device_name='AUTO')  # AUTO selects best available device
            metadata = w.parent / 'metadata.yaml'
        elif engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            try:
                import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download
            except ImportError:
                if LINUX:
                    check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
                import tensorrt as trt  # noqa
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            if device.type == 'cpu':
                device = torch.device('cuda:0')
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            # Read file
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                meta_len = int.from_bytes(f.read(4), byteorder='little')  # read metadata length
                metadata = json.loads(f.read(meta_len).decode('utf-8'))  # read metadata
                model = runtime.deserialize_cuda_engine(f.read())  # read engine
            context = model.create_execution_context()
            bindings = OrderedDict()
            output_names = []
            fp16 = False  # default updated below
            dynamic = False
            for i in range(model.num_bindings):
                name = model.get_binding_name(i)
                dtype = trt.nptype(model.get_binding_dtype(i))
                if model.binding_is_input(i):
                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                        dynamic = True
                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
                    if dtype == np.float16:
                        fp16 = True
                else:  # output
                    output_names.append(name)
                shape = tuple(context.get_binding_shape(i))
                im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
            batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
        elif coreml:  # CoreML
            LOGGER.info(f'Loading {w} for CoreML inference...')
            import coremltools as ct
            model = ct.models.MLModel(w)
            metadata = dict(model.user_defined_metadata)
        elif saved_model:  # TF SavedModel
            LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
            import tensorflow as tf
            keras = False  # assume TF1 saved_model
            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
            metadata = Path(w) / 'metadata.yaml'
        elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
            LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
            import tensorflow as tf
            from ultralytics.engine.exporter import gd_outputs

            def wrap_frozen_graph(gd, inputs, outputs):
                """Wrap frozen graphs for deployment."""
                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
                ge = x.graph.as_graph_element
                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

            gd = tf.Graph().as_graph_def()  # TF GraphDef
            with open(w, 'rb') as f:
                gd.ParseFromString(f.read())
            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
        elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                from tflite_runtime.interpreter import Interpreter, load_delegate
            except ImportError:
                import tensorflow as tf
                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
            if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
                LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                delegate = {
                    'Linux': 'libedgetpu.so.1',
                    'Darwin': 'libedgetpu.1.dylib',
                    'Windows': 'edgetpu.dll'}[platform.system()]
                interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
            else:  # TFLite
                LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                interpreter = Interpreter(model_path=w)  # load TFLite model
            interpreter.allocate_tensors()  # allocate
            input_details = interpreter.get_input_details()  # inputs
            output_details = interpreter.get_output_details()  # outputs
            # Load metadata
            with contextlib.suppress(zipfile.BadZipFile):
                with zipfile.ZipFile(w, 'r') as model:
                    meta_file = model.namelist()[0]
                    metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
        elif tfjs:  # TF.js
            raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.')
        elif paddle:  # PaddlePaddle
            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
            import paddle.inference as pdi  # noqa
            w = Path(w)
            if not w.is_file():  # if not *.pdmodel
                w = next(w.rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
            config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
            if cuda:
                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
            predictor = pdi.create_predictor(config)
            input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
            output_names = predictor.get_output_names()
            metadata = w.parents[1] / 'metadata.yaml'
        elif ncnn:  # ncnn
            LOGGER.info(f'Loading {w} for ncnn inference...')
            check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires ncnn
            import ncnn as pyncnn
            net = pyncnn.Net()
            net.opt.use_vulkan_compute = cuda
            w = Path(w)
            if not w.is_file():  # if not *.param
                w = next(w.glob('*.param'))  # get *.param file from *_ncnn_model dir
            net.load_param(str(w))
            net.load_model(str(w.with_suffix('.bin')))
            metadata = w.parent / 'metadata.yaml'
        elif triton:  # NVIDIA Triton Inference Server
            """TODO
            check_requirements('tritonclient[all]')
            from utils.triton import TritonRemoteModel
            model = TritonRemoteModel(url=w)
            nhwc = model.runtime.startswith("tensorflow")
            """
            raise NotImplementedError('Triton Inference Server is not currently supported.')
        else:
            from ultralytics.engine.exporter import export_formats
            raise TypeError(f"model='{w}' is not a supported model format. "
                            'See https://docs.ultralytics.com/modes/predict for help.'
                            f'\n\n{export_formats()}')

        # Load external metadata YAML
        if isinstance(metadata, (str, Path)) and Path(metadata).exists():
            metadata = yaml_load(metadata)
        if metadata:
            for k, v in metadata.items():
                if k in ('stride', 'batch'):
                    metadata[k] = int(v)
                elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
                    metadata[k] = eval(v)
            stride = metadata['stride']
            task = metadata['task']
            batch = metadata['batch']
            imgsz = metadata['imgsz']
            names = metadata['names']
            kpt_shape = metadata.get('kpt_shape')
        elif not (pt or triton or nn_module):
            LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")

        # Check names
        if 'names' not in locals():  # names missing
            names = self._apply_default_class_names(data)
        names = check_class_names(names)

        self.__dict__.update(locals())  # assign all variables to self
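        # Note: the locals() update above is what exposes the backend flags (self.pt, self.onnx,
        # self.engine, self.tflite, ...) together with stride, names and metadata as attributes;
        # forward() dispatches on these flags to select the matching inference path.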

    def forward(self, im, augment=False, visualize=False):
        """
        Runs inference on the YOLOv8 MultiBackend model.

        Args:
            im (torch.Tensor): The image tensor to perform inference on.
            augment (bool): whether to perform data augmentation during inference, defaults to False
            visualize (bool): whether to visualize the output predictions, defaults to False

        Returns:
            (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
        """
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.fp16 and im.dtype != torch.float16:
            im = im.half()  # to FP16
        if self.nhwc:
            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

        if self.pt or self.nn_module:  # PyTorch
            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
        elif self.jit:  # TorchScript
            y = self.model(im)
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()  # FP32
            y = list(self.ov_compiled_model(im).values())
        elif self.engine:  # TensorRT
            if self.dynamic and im.shape != self.bindings['images'].shape:
                i = self.model.get_binding_index('images')
                self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
            s = self.bindings['images'].shape
            assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = [self.bindings[x].data for x in sorted(self.output_names)]
        elif self.coreml:  # CoreML
            im = im[0].cpu().numpy()
            im_pil = Image.fromarray((im * 255).astype('uint8'))
            # im = im.resize((192, 320), Image.BILINEAR)
            y = self.model.predict({'image': im_pil})  # coordinates are xywh normalized
            if 'confidence' in y:
                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            elif len(y) == 1:  # classification model
                y = list(y.values())
            elif len(y) == 2:  # segmentation model
                y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
        elif self.paddle:  # PaddlePaddle
            im = im.cpu().numpy().astype(np.float32)
            self.input_handle.copy_from_cpu(im)
            self.predictor.run()
            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
        elif self.ncnn:  # ncnn
            mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
            ex = self.net.create_extractor()
            input_names, output_names = self.net.input_names(), self.net.output_names()
            ex.input(input_names[0], mat_in)
            y = []
            for output_name in output_names:
                mat_out = self.pyncnn.Mat()
                ex.extract(output_name, mat_out)
                y.append(np.array(mat_out)[None])
        elif self.triton:  # NVIDIA Triton Inference Server
            y = self.model(im)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.cpu().numpy()
            if self.saved_model:  # SavedModel
                y = self.model(im, training=False) if self.keras else self.model(im)
                if not isinstance(y, list):
                    y = [y]
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im))
                if len(y) == 2 and len(self.names) == 999:  # segments and names not defined
                    ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0)  # index of protos, boxes
                    nc = y[ib].shape[1] - y[ip].shape[3] - 4  # y = (1, 160, 160, 32), (1, 116, 8400)
                    self.names = {i: f'class{i}' for i in range(nc)}
            else:  # Lite or Edge TPU
                details = self.input_details[0]
                integer = details['dtype'] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model
                if integer:
                    scale, zero_point = details['quantization']
                    im = (im / scale + zero_point).astype(details['dtype'])  # de-scale
                self.interpreter.set_tensor(details['index'], im)
                self.interpreter.invoke()
                y = []
                for output in self.output_details:
                    x = self.interpreter.get_tensor(output['index'])
                    if integer:
                        scale, zero_point = output['quantization']
                        x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                    if x.ndim > 2:  # if task is not classification
                        # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                        # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                        x[:, [0, 2]] *= w
                        x[:, [1, 3]] *= h
                    y.append(x)
            # TF segment fixes: export is reversed vs ONNX export and protos are transposed
            if len(y) == 2:  # segment with (det, proto) output order reversed
                if len(y[1].shape) != 4:
                    y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
                y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

        # for x in y:
        #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
        if isinstance(y, (list, tuple)):
            return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
        else:
            return self.from_numpy(y)
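
    # Example (illustrative): a single forward pass, assuming a 640x640 letterboxed input and a
    # YOLOv8 detection head, returns a tensor shaped roughly (1, 4 + nc, 8400):
    #   im = torch.zeros(1, 3, 640, 640)   # BCHW, float32 in [0, 1]
    #   y = backend.forward(im)            # torch.Tensor, or a list of tensors for segmentation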

    def from_numpy(self, x):
        """
        Convert a numpy array to a tensor.

        Args:
            x (np.ndarray): The array to be converted.

        Returns:
            (torch.Tensor): The converted tensor
        """
        return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x

    def warmup(self, imgsz=(1, 3, 640, 640)):
        """
        Warm up the model by running one forward pass with a dummy input.

        Args:
            imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)

        Returns:
            (None): This method runs the forward pass and does not return any value.
        """
        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
        if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
            im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
            for _ in range(2 if self.jit else 1):
                self.forward(im)  # warmup
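
    # Example (illustrative): warm up once before timing or streaming inference, e.g.
    #   backend.warmup(imgsz=(1, 3, 640, 640))  # TorchScript models run two passes, others one;
    #   the call is a no-op on CPU for non-Triton backends.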

    @staticmethod
    def _apply_default_class_names(data):
        """Applies default class names to an input YAML file or returns numerical class names."""
        with contextlib.suppress(Exception):
            return yaml_load(check_yaml(data))['names']
        return {i: f'class{i}' for i in range(999)}  # return default if above errors

    @staticmethod
    def _model_type(p='path/to/model.pt'):
        """
        This function takes a path to a model file and returns the model type.

        Args:
            p: path to the model file. Defaults to path/to/model.pt
        """
        # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
        # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
        from ultralytics.engine.exporter import export_formats
        sf = list(export_formats().Suffix)  # export suffixes
        if not is_url(p, check=False) and not isinstance(p, str):
            check_suffix(p, sf)  # checks
        name = Path(p).name
        types = [s in name for s in sf]
        types[5] |= name.endswith('.mlmodel')  # retain support for older Apple CoreML *.mlmodel formats
        types[8] &= not types[9]  # tflite &= not edgetpu
        if any(types):
            triton = False
        else:
            url = urlparse(p)  # if url may be Triton inference server
            triton = all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
        return types + [triton]
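

# Illustrative usage sketch, not executed on import (assumes a local or downloadable 'yolov8n.pt';
# letterboxing/normalization is normally handled by the Ultralytics predictor pipeline):
if __name__ == '__main__':
    backend = AutoBackend(weights='yolov8n.pt', device=torch.device('cpu'), fp16=False)
    backend.warmup(imgsz=(1, 3, 640, 640))  # skipped on CPU for non-Triton backends
    dummy = torch.zeros(1, 3, 640, 640)  # BCHW float32 stand-in for a preprocessed image
    preds = backend.forward(dummy)  # raw predictions, format depends on the loaded backend
    LOGGER.info(f'stride={backend.stride}, num_classes={len(backend.names)}')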