# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit

Format                  | `format=argument`         | Model
---                     | ---                       | ---
PyTorch                 | -                         | yolov8n.pt
TorchScript             | `torchscript`             | yolov8n.torchscript
ONNX                    | `onnx`                    | yolov8n.onnx
OpenVINO                | `openvino`                | yolov8n_openvino_model/
TensorRT                | `engine`                  | yolov8n.engine
CoreML                  | `coreml`                  | yolov8n.mlpackage
TensorFlow SavedModel   | `saved_model`             | yolov8n_saved_model/
TensorFlow GraphDef     | `pb`                      | yolov8n.pb
TensorFlow Lite         | `tflite`                  | yolov8n.tflite
TensorFlow Edge TPU     | `edgetpu`                 | yolov8n_edgetpu.tflite
TensorFlow.js           | `tfjs`                    | yolov8n_web_model/
PaddlePaddle            | `paddle`                  | yolov8n_paddle_model/
ncnn                    | `ncnn`                    | yolov8n_ncnn_model/

Requirements:
    $ pip install "ultralytics[export]"

Python:
    from ultralytics import YOLO
    model = YOLO('yolov8n.pt')
    results = model.export(format='onnx')

CLI:
    $ yolo mode=export model=yolov8n.pt format=onnx

Inference:
    $ yolo predict model=yolov8n.pt                 # PyTorch
                         yolov8n.torchscript        # TorchScript
                         yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                         yolov8n_openvino_model     # OpenVINO
                         yolov8n.engine             # TensorRT
                         yolov8n.mlpackage          # CoreML (macOS-only)
                         yolov8n_saved_model        # TensorFlow SavedModel
                         yolov8n.pb                 # TensorFlow GraphDef
                         yolov8n.tflite             # TensorFlow Lite
                         yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
                         yolov8n_paddle_model       # PaddlePaddle

TensorFlow.js:
    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
    $ npm install
    $ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
    $ npm start
"""
import json
import os
import shutil
import subprocess
import time
import warnings
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import torch

from ultralytics.cfg import get_cfg
from ultralytics.nn.autobackend import check_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks,
                               colorstr, get_default_args, yaml_save)
from ultralytics.utils.checks import check_imgsz, check_requirements, check_version
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
from ultralytics.utils.files import file_size, spaces_in_path
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode


def export_formats():
    """YOLOv8 export formats."""
    import pandas
    x = [
        ['PyTorch', '-', '.pt', True, True],
        ['TorchScript', 'torchscript', '.torchscript', True, True],
        ['ONNX', 'onnx', '.onnx', True, True],
        ['OpenVINO', 'openvino', '_openvino_model', True, False],
        ['TensorRT', 'engine', '.engine', False, True],
        ['CoreML', 'coreml', '.mlpackage', True, False],
        ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
        ['TensorFlow GraphDef', 'pb', '.pb', True, True],
        ['TensorFlow Lite', 'tflite', '.tflite', True, False],
        ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
        ['TensorFlow.js', 'tfjs', '_web_model', True, False],
        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
        ['ncnn', 'ncnn', '_ncnn_model', True, True], ]
    return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
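
# Example (a minimal sketch): query the formats table for GPU-capable export arguments.
#     df = export_formats()
#     print(df.loc[df['GPU'], 'Argument'].tolist())
#     # expected: ['-', 'torchscript', 'onnx', 'engine', 'saved_model', 'pb', 'paddle', 'ncnn']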


def gd_outputs(gd):
    """TensorFlow GraphDef model output node names."""
    name_list, input_list = [], []
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        name_list.append(node.name)
        input_list.extend(node.input)
    return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))


def try_export(inner_func):
    """YOLOv8 export decorator, i.e. @try_export."""
    inner_args = get_default_args(inner_func)

    def outer_func(*args, **kwargs):
        """Export a model."""
        prefix = inner_args['prefix']
        try:
            with Profile() as dt:
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
            return f, model
        except Exception as e:
            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
            raise e

    return outer_func
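
# Usage sketch (hedged): a method decorated with @try_export must expose a `prefix` keyword
# default (read via get_default_args above) and return a (filepath, model) tuple, e.g.
#     @try_export
#     def export_foo(self, prefix=colorstr('Foo:')):  # 'foo' is a hypothetical format name
#         return 'model.foo', None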


class Exporter:
    """
    A class for exporting a model.

    Attributes:
        args (SimpleNamespace): Configuration for the exporter.
        callbacks (list, optional): List of callback functions. Defaults to None.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the Exporter class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (list, optional): List of callback functions. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        callbacks.add_integration_callbacks(self)
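
    # Example (a minimal sketch, assuming `model` is a loaded Ultralytics nn model):
    #     exporter = Exporter(overrides={'format': 'onnx', 'imgsz': 640})
    #     f = exporter(model=model)  # path(s) to the exported artifact(s)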

    @smart_inference_mode()
    def __call__(self, model=None):
        """Returns list of exported files/dirs after running callbacks."""
        self.run_callbacks('on_export_start')
        t = time.time()
        format = self.args.format.lower()  # to lowercase
        if format in ('tensorrt', 'trt'):  # 'engine' aliases
            format = 'engine'
        if format in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios'):  # 'coreml' aliases
            format = 'coreml'
        fmts = tuple(export_formats()['Argument'][1:])  # available export formats
        flags = [x == format for x in fmts]
        if sum(flags) != 1:
            raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags  # export booleans

        # Load PyTorch model
        self.device = select_device('cpu' if self.args.device is None else self.args.device)

        # Checks
        model.names = check_class_names(model.names)
        if self.args.half and onnx and self.device.type == 'cpu':
            LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
            self.args.half = False
            assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
        self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
        if self.args.optimize:
            assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
            assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
        if edgetpu and not LINUX:
            raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')

        # Input
        im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
        file = Path(
            getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
        if file.suffix in ('.yaml', '.yml'):
            file = Path(file.name)

        # Update model
        model = deepcopy(model).to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        model.float()
        model = model.fuse()
        for m in model.modules():
            if isinstance(m, (Detect, RTDETRDecoder)):  # Segment and Pose use Detect base class
                m.dynamic = self.args.dynamic
                m.export = True
                m.format = self.args.format
            elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
                # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                m.forward = m.forward_split

        y = None
        for _ in range(2):
            y = model(im)  # dry runs
        if self.args.half and (engine or onnx) and self.device.type != 'cpu':
            im, model = im.half(), model.half()  # to FP16

        # Filter warnings
        warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
        warnings.filterwarnings('ignore', category=UserWarning)  # suppress shape prim::Constant missing ONNX warning
        warnings.filterwarnings('ignore', category=DeprecationWarning)  # suppress CoreML np.bool deprecation warning

        # Assign
        self.im = im
        self.model = model
        self.file = file
        self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
            tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
        self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
        data = model.args['data'] if hasattr(model, 'args') and isinstance(model.args, dict) else ''
        description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}'
        self.metadata = {
            'description': description,
            'author': 'Ultralytics',
            'license': 'AGPL-3.0 https://ultralytics.com/license',
            'date': datetime.now().isoformat(),
            'version': __version__,
            'stride': int(max(model.stride)),
            'task': model.task,
            'batch': self.args.batch,
            'imgsz': self.imgsz,
            'names': model.names}  # model metadata
        if model.task == 'pose':
            self.metadata['kpt_shape'] = model.model[-1].kpt_shape

        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
                    f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')

        # Exports
        f = [''] * len(fmts)  # exported filenames
        if jit or ncnn:  # TorchScript
            f[0], _ = self.export_torchscript()
        if engine:  # TensorRT required before ONNX
            f[1], _ = self.export_engine()
        if onnx or xml:  # OpenVINO requires ONNX
            f[2], _ = self.export_onnx()
        if xml:  # OpenVINO
            f[3], _ = self.export_openvino()
        if coreml:  # CoreML
            f[4], _ = self.export_coreml()
        if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
            self.args.int8 |= edgetpu
            f[5], keras_model = self.export_saved_model()
            if pb or tfjs:  # pb prerequisite to tfjs
                f[6], _ = self.export_pb(keras_model=keras_model)
            if tflite:
                f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
            if edgetpu:
                f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
            if tfjs:
                f[9], _ = self.export_tfjs()
        if paddle:  # PaddlePaddle
            f[10], _ = self.export_paddle()
        if ncnn:  # ncnn
            f[11], _ = self.export_ncnn()

        # Finish
        f = [str(x) for x in f if x]  # filter out '' and None
        if any(f):
            f = str(Path(f[-1]))
            square = self.imgsz[0] == self.imgsz[1]
            s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
                                  f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
            imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
            predict_data = f'data={data}' if model.task == 'segment' and format == 'pb' else ''
            LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                        f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                        f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {predict_data}'
                        f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {s}'
                        f'\nVisualize: https://netron.app')

        self.run_callbacks('on_export_end')
        return f  # return list of exported files/dirs

    @try_export
    def export_torchscript(self, prefix=colorstr('TorchScript:')):
        """YOLOv8 TorchScript model export."""
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = self.file.with_suffix('.torchscript')

        ts = torch.jit.trace(self.model, self.im, strict=False)
        extra_files = {'config.txt': json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
        if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            LOGGER.info(f'{prefix} optimizing for mobile...')
            from torch.utils.mobile_optimizer import optimize_for_mobile
            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
        else:
            ts.save(str(f), _extra_files=extra_files)
        return f, None
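
    # Loading sketch (hedged): read the embedded metadata back at inference time.
    #     extra_files = {'config.txt': ''}  # values are populated in place by torch.jit.load
    #     ts = torch.jit.load(f, _extra_files=extra_files)
    #     metadata = json.loads(extra_files['config.txt'])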

    @try_export
    def export_onnx(self, prefix=colorstr('ONNX:')):
        """YOLOv8 ONNX export."""
        requirements = ['onnx>=1.12.0']
        if self.args.simplify:
            requirements += ['onnxsim>=0.4.33', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
        check_requirements(requirements)
        import onnx  # noqa

        opset_version = self.args.opset or get_latest_opset()
        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
        f = str(self.file.with_suffix('.onnx'))

        output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
        dynamic = self.args.dynamic
        if dynamic:
            dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
            if isinstance(self.model, SegmentationModel):
                dynamic['output0'] = {0: 'batch', 2: 'anchors'}  # shape(1, 116, 8400)
                dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
            elif isinstance(self.model, DetectionModel):
                dynamic['output0'] = {0: 'batch', 2: 'anchors'}  # shape(1, 84, 8400)

        torch.onnx.export(
            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
            self.im.cpu() if dynamic else self.im,
            f,
            verbose=False,
            opset_version=opset_version,
            do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
            input_names=['images'],
            output_names=output_names,
            dynamic_axes=dynamic or None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        # onnx.checker.check_model(model_onnx)  # check onnx model

        # Simplify
        if self.args.simplify:
            try:
                import onnxsim

                LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
                # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
                model_onnx, check = onnxsim.simplify(model_onnx)
                assert check, 'Simplified ONNX model could not be validated'
            except Exception as e:
                LOGGER.info(f'{prefix} simplifier failure: {e}')

        # Metadata
        for k, v in self.metadata.items():
            meta = model_onnx.metadata_props.add()
            meta.key, meta.value = k, str(v)

        onnx.save(model_onnx, f)
        return f, model_onnx
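
    # Inference sketch (hedged, assumes the onnxruntime package is installed):
    #     import onnxruntime as ort
    #     session = ort.InferenceSession(f, providers=['CPUExecutionProvider'])
    #     outputs = session.run(None, {'images': self.im.cpu().numpy()})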

    @try_export
    def export_openvino(self, prefix=colorstr('OpenVINO:')):
        """YOLOv8 OpenVINO export."""
        check_requirements('openvino-dev>=2023.0')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        import openvino.runtime as ov  # noqa
        from openvino.tools import mo  # noqa

        LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
        f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
        f_onnx = self.file.with_suffix('.onnx')
        f_ov = str(Path(f) / self.file.with_suffix('.xml').name)

        ov_model = mo.convert_model(f_onnx,
                                    model_name=self.pretty_name,
                                    framework='onnx',
                                    compress_to_fp16=self.args.half)  # export

        # Set RT info
        ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
        ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
        ov_model.set_rt_info(114, ['model_info', 'pad_value'])
        ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
        ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
        ov_model.set_rt_info([v.replace(' ', '_') for k, v in sorted(self.model.names.items())],
                             ['model_info', 'labels'])
        if self.model.task != 'classify':
            ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])

        ov.serialize(ov_model, f_ov)  # save
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return f, None
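
    # Loading sketch (hedged): the serialized IR can be reopened with the OpenVINO runtime.
    #     core = ov.Core()
    #     compiled = core.compile_model(core.read_model(f_ov), 'CPU')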

    @try_export
    def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
        """YOLOv8 Paddle export."""
        check_requirements(('paddlepaddle', 'x2paddle'))
        import x2paddle  # noqa
        from x2paddle.convert import pytorch2paddle  # noqa

        LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
        f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')

        pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im])  # export
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return f, None

    @try_export
    def export_ncnn(self, prefix=colorstr('ncnn:')):
        """YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx."""
        check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires ncnn
        import ncnn  # noqa

        LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...')
        f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
        f_ts = self.file.with_suffix('.torchscript')

        pnnx_filename = 'pnnx.exe' if WINDOWS else 'pnnx'
        if Path(pnnx_filename).is_file():
            pnnx = pnnx_filename
        elif (ROOT / pnnx_filename).is_file():
            pnnx = ROOT / pnnx_filename
        else:
            LOGGER.warning(
                f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
                'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
                f'or in {ROOT}. See PNNX repo for full installation instructions.')
            _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
            system = 'macos' if MACOS else 'ubuntu' if LINUX else 'windows'  # operating system
            asset = [x for x in assets if system in x][0] if assets else \
                f'https://github.com/pnnx/pnnx/releases/download/20230816/pnnx-20230816-{system}.zip'  # fallback
            asset = attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
            unzip_dir = Path(asset).with_suffix('')
            pnnx = ROOT / pnnx_filename  # new location
            (unzip_dir / pnnx_filename).rename(pnnx)  # move binary to ROOT
            shutil.rmtree(unzip_dir)  # delete unzip dir
            Path(asset).unlink()  # delete zip
            pnnx.chmod(0o777)  # set read, write, and execute permissions for everyone

        use_ncnn = True
        ncnn_args = [
            f'ncnnparam={f / "model.ncnn.param"}',
            f'ncnnbin={f / "model.ncnn.bin"}',
            f'ncnnpy={f / "model_ncnn.py"}', ] if use_ncnn else []

        use_pnnx = False
        pnnx_args = [
            f'pnnxparam={f / "model.pnnx.param"}',
            f'pnnxbin={f / "model.pnnx.bin"}',
            f'pnnxpy={f / "model_pnnx.py"}',
            f'pnnxonnx={f / "model.pnnx.onnx"}', ] if use_pnnx else []

        cmd = [
            str(pnnx),
            str(f_ts),
            *ncnn_args,
            *pnnx_args,
            f'fp16={int(self.args.half)}',
            f'device={self.device.type}',
            f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
        f.mkdir(exist_ok=True)  # make ncnn_model directory
        LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
        subprocess.run(cmd, check=True)

        for f_debug in 'debug.bin', 'debug.param', 'debug2.bin', 'debug2.param':  # remove debug files
            Path(f_debug).unlink(missing_ok=True)

        yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return str(f), None

    @try_export
    def export_coreml(self, prefix=colorstr('CoreML:')):
        """YOLOv8 CoreML export."""
        mlmodel = self.args.format.lower() == 'mlmodel'  # legacy *.mlmodel export format requested
        check_requirements('coremltools>=6.0,<=6.2' if mlmodel else 'coremltools>=7.0.b1')
        import coremltools as ct  # noqa

        LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
        f = self.file.with_suffix('.mlmodel' if mlmodel else '.mlpackage')
        if f.is_dir():
            shutil.rmtree(f)

        bias = [0.0, 0.0, 0.0]
        scale = 1 / 255
        classifier_config = None
        if self.model.task == 'classify':
            classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
            model = self.model
        elif self.model.task == 'detect':
            model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
        else:
            if self.args.nms:
                LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolov8n.pt'.")
                # TODO CoreML Segment and Pose model pipelining
            model = self.model

        ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
        ct_model = ct.convert(ts,
                              inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
                              classifier_config=classifier_config,
                              convert_to='neuralnetwork' if mlmodel else 'mlprogram')
        bits, mode = (8, 'kmeans') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
        if bits < 32:
            if 'kmeans' in mode:
                check_requirements('scikit-learn')  # scikit-learn package required for k-means quantization
            if mlmodel:
                ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
            elif bits == 8:  # mlprogram already quantized to FP16
                import coremltools.optimize.coreml as cto
                op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=bits, weight_threshold=512)
                config = cto.OptimizationConfig(global_config=op_config)
                ct_model = cto.palettize_weights(ct_model, config=config)
        if self.args.nms and self.model.task == 'detect':
            if mlmodel:
                import platform

                # coremltools<=6.2 NMS export requires Python<3.11
                check_version(platform.python_version(), '<3.11', name='Python ', hard=True)
                weights_dir = None
            else:
                ct_model.save(str(f))  # save otherwise weights_dir does not exist
                weights_dir = str(f / 'Data/com.apple.CoreML/weights')
            ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)

        m = self.metadata  # metadata dict
        ct_model.short_description = m.pop('description')
        ct_model.author = m.pop('author')
        ct_model.license = m.pop('license')
        ct_model.version = m.pop('version')
        ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
        try:
            ct_model.save(str(f))  # save *.mlpackage
        except Exception as e:
            LOGGER.warning(
                f'{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. '
                f'Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.')
            f = f.with_suffix('.mlmodel')
            ct_model.save(str(f))
        return f, ct_model
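
    # Loading sketch (hedged, macOS-only): reopen the saved model with coremltools for a quick
    # check; assumes a 640x640 export, since the PIL image size must match the exported imgsz.
    #     from PIL import Image
    #     reloaded = ct.models.MLModel(str(f))
    #     out = reloaded.predict({'image': Image.new('RGB', (640, 640))})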

    @try_export
    def export_engine(self, prefix=colorstr('TensorRT:')):
        """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
        assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
        try:
            import tensorrt as trt  # noqa
        except ImportError:
            if LINUX:
                check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
            import tensorrt as trt  # noqa

        check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0

        self.args.simplify = True
        f_onnx, _ = self.export_onnx()

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
        f = self.file.with_suffix('.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if self.args.verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = self.args.workspace * 1 << 30
        # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(f_onnx):
            raise RuntimeError(f'failed to load ONNX file: {f_onnx}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        for inp in inputs:
            LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

        if self.args.dynamic:
            shape = self.im.shape
            if shape[0] <= 1:
                LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
            profile = builder.create_optimization_profile()
            for inp in inputs:
                profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
            config.add_optimization_profile(profile)

        LOGGER.info(
            f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
        if builder.platform_has_fast_fp16 and self.args.half:
            config.set_flag(trt.BuilderFlag.FP16)

        # Write file
        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
            # Metadata
            meta = json.dumps(self.metadata)
            t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
            t.write(meta.encode())
            # Model
            t.write(engine.serialize())

        return f, None
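
    # Reading sketch (hedged): recover the length-prefixed metadata written above.
    #     with open(f, 'rb') as t:
    #         meta_len = int.from_bytes(t.read(4), byteorder='little')  # 4-byte little-endian length
    #         metadata = json.loads(t.read(meta_len).decode())
    #         engine_bytes = t.read()  # remaining bytes are the serialized engine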

    @try_export
    def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
        """YOLOv8 TensorFlow SavedModel export."""
        cuda = torch.cuda.is_available()
        try:
            import tensorflow as tf  # noqa
        except ImportError:
            check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
            import tensorflow as tf  # noqa
        check_requirements(
            ('onnx', 'onnx2tf>=1.15.4', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.33', 'onnx_graphsurgeon>=0.3.26',
             'tflite_support', 'onnxruntime-gpu' if cuda else 'onnxruntime'),
            cmds='--extra-index-url https://pypi.ngc.nvidia.com')  # onnx_graphsurgeon only on NVIDIA

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
        if f.is_dir():
            shutil.rmtree(f)  # delete output folder

        # Export to ONNX
        self.args.simplify = True
        f_onnx, _ = self.export_onnx()

        # Export to TF
        tmp_file = f / 'tmp_tflite_int8_calibration_images.npy'  # int8 calibration images file
        if self.args.int8:
            verbosity = '--verbosity info'
            if self.args.data:
                import numpy as np

                from ultralytics.data.dataset import YOLODataset
                from ultralytics.data.utils import check_det_dataset

                # Generate calibration data for integer quantization
                LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
                data = check_det_dataset(self.args.data)
                dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
                images = []
                n_images = 100  # maximum number of images
                for n, batch in enumerate(dataset):
                    if n >= n_images:
                        break
                    im = batch['img'].permute(1, 2, 0)[None]  # list to nparray, CHW to BHWC
                    images.append(im)
                f.mkdir()
                images = torch.cat(images, 0).float()
                # mean = images.view(-1, 3).mean(0)  # imagenet mean [123.675, 116.28, 103.53]
                # std = images.view(-1, 3).std(0)  # imagenet std [58.395, 57.12, 57.375]
                np.save(str(tmp_file), images.numpy())  # BHWC
                int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
            else:
                int8 = '-oiqt -qt per-tensor'
        else:
            verbosity = '--non_verbose'
            int8 = ''

        cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip()
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd, shell=True)
        yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml

        # Remove/rename TFLite models
        if self.args.int8:
            tmp_file.unlink(missing_ok=True)
            for file in f.rglob('*_dynamic_range_quant.tflite'):
                file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
            for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
                file.unlink()  # delete extra fp16 activation TFLite files

        # Add TFLite metadata
        for file in f.rglob('*.tflite'):
            file.unlink() if 'quant_with_int16_act.tflite' in str(file) else self._add_tflite_metadata(file)

        return str(f), tf.saved_model.load(f, tags=None, options=None)  # load saved_model as Keras model

    @try_export
    def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
        """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
        import tensorflow as tf  # noqa
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = self.file.with_suffix('.pb')

        m = tf.function(lambda x: keras_model(x))  # full model
        m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
        frozen_func = convert_variables_to_constants_v2(m)
        frozen_func.graph.as_graph_def()
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
        return f, None

    @try_export
    def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
        """YOLOv8 TensorFlow Lite export."""
        import tensorflow as tf  # noqa

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
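        # Note (hedged): onnx2tf already wrote the TFLite variants during export_saved_model(),
        # so this method only resolves which filename matches the requested precision.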
        saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
        if self.args.int8:
            f = saved_model / f'{self.file.stem}_int8.tflite'  # fp32 in/out
        elif self.args.half:
            f = saved_model / f'{self.file.stem}_float16.tflite'  # fp32 in/out
        else:
            f = saved_model / f'{self.file.stem}_float32.tflite'
        return str(f), None

    @try_export
    def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
        """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
        LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')

        cmd = 'edgetpu_compiler --version'
        help_url = 'https://coral.ai/docs/edgetpu/compiler/'
        assert LINUX, f'export only supported on Linux. See {help_url}'
        if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
            LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
            sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
            for c in (
                    'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
                    'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
                    'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
                subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

        LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
        f = str(tflite_model).replace('.tflite', '_edgetpu.tflite')  # Edge TPU model

        cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd, shell=True)
        self._add_tflite_metadata(f)
        return f, None

    @try_export
    def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
        """YOLOv8 TensorFlow.js export."""
        check_requirements('tensorflowjs')
        import tensorflow as tf
        import tensorflowjs as tfjs  # noqa

        LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
        f = str(self.file).replace(self.file.suffix, '_web_model')  # js dir
        f_pb = str(self.file.with_suffix('.pb'))  # *.pb path

        gd = tf.Graph().as_graph_def()  # TF GraphDef
        with open(f_pb, 'rb') as file:
            gd.ParseFromString(file.read())
        outputs = ','.join(gd_outputs(gd))
        LOGGER.info(f'\n{prefix} output node names: {outputs}')

        with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
            cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} "{fpb_}" "{f_}"'
            LOGGER.info(f"{prefix} running '{cmd}'")
            subprocess.run(cmd, shell=True)

        if ' ' in str(f):
            LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")

        # f_json = Path(f) / 'model.json'  # *.json path
        # with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
        #     subst = re.sub(
        #         r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}}}',
        #         r'{"outputs": {"Identity": {"name": "Identity"}, '
        #         r'"Identity_1": {"name": "Identity_1"}, '
        #         r'"Identity_2": {"name": "Identity_2"}, '
        #         r'"Identity_3": {"name": "Identity_3"}}}',
        #         f_json.read_text(),
        #     )
        #     j.write(subst)
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return f, None

    def _add_tflite_metadata(self, file):
        """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
        from tflite_support import flatbuffers  # noqa
        from tflite_support import metadata as _metadata  # noqa
        from tflite_support import metadata_schema_py_generated as _metadata_fb  # noqa

        # Create model info
        model_meta = _metadata_fb.ModelMetadataT()
        model_meta.name = self.metadata['description']
        model_meta.version = self.metadata['version']
        model_meta.author = self.metadata['author']
        model_meta.license = self.metadata['license']

        # Label file
        tmp_file = Path(file).parent / 'temp_meta.txt'
        with open(tmp_file, 'w') as f:
            f.write(str(self.metadata))

        label_file = _metadata_fb.AssociatedFileT()
        label_file.name = tmp_file.name
        label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS

        # Create input info
        input_meta = _metadata_fb.TensorMetadataT()
        input_meta.name = 'image'
        input_meta.description = 'Input image to be detected.'
        input_meta.content = _metadata_fb.ContentT()
        input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
        input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
        input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties

        # Create output info
        output1 = _metadata_fb.TensorMetadataT()
        output1.name = 'output'
        output1.description = 'Coordinates of detected objects, class labels, and confidence score'
        output1.associatedFiles = [label_file]
        if self.model.task == 'segment':
            output2 = _metadata_fb.TensorMetadataT()
            output2.name = 'output'
            output2.description = 'Mask protos'
            output2.associatedFiles = [label_file]

        # Create subgraph info
        subgraph = _metadata_fb.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [input_meta]
        subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
        model_meta.subgraphMetadata = [subgraph]

        b = flatbuffers.Builder(0)
        b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        metadata_buf = b.Output()

        populator = _metadata.MetadataPopulator.with_model_file(str(file))
        populator.load_metadata_buffer(metadata_buf)
        populator.load_associated_files([str(tmp_file)])
        populator.populate()
        tmp_file.unlink()

    def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr('CoreML Pipeline:')):
        """YOLOv8 CoreML pipeline."""
        import coremltools as ct  # noqa

        LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
        _, _, h, w = list(self.im.shape)  # BCHW

        # Output shapes
        spec = model.get_spec()
        out0, out1 = iter(spec.description.output)
        if MACOS:
            from PIL import Image
            img = Image.new('RGB', (w, h))  # w=192, h=320
            out = model.predict({'image': img})
            out0_shape = out[out0.name].shape  # (3780, 80)
            out1_shape = out[out1.name].shape  # (3780, 4)
        else:  # linux and windows can not run model.predict(), get sizes from PyTorch model output y
            out0_shape = self.output_shape[2], self.output_shape[1] - 4  # (3780, 80)
            out1_shape = self.output_shape[2], 4  # (3780, 4)

        # Checks
        names = self.metadata['names']
        nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
        _, nc = out0_shape  # number of anchors, number of classes
        # _, nc = out0.type.multiArrayType.shape
        assert len(names) == nc, f'{len(names)} names found for nc={nc}'  # check

        # Define output shapes (missing)
        out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
        out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)
        # spec.neuralNetwork.preprocessing[0].featureName = '0'

        # Flexible input shapes
        # from coremltools.models.neural_network import flexible_shape_utils
        # s = []  # shapes
        # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
        # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384))  # (height, width)
        # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
        # r = flexible_shape_utils.NeuralNetworkImageSizeRange()  # shape ranges
        # r.add_height_range((192, 640))
        # r.add_width_range((192, 640))
        # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)

        # Print
        # print(spec.description)

        # Model from spec
        model = ct.models.MLModel(spec, weights_dir=weights_dir)

        # 3. Create NMS protobuf
        nms_spec = ct.proto.Model_pb2.Model()
        nms_spec.specificationVersion = 5
        for i in range(2):
            decoder_output = model._spec.description.output[i].SerializeToString()
            nms_spec.description.input.add()
            nms_spec.description.input[i].ParseFromString(decoder_output)
            nms_spec.description.output.add()
            nms_spec.description.output[i].ParseFromString(decoder_output)
        nms_spec.description.output[0].name = 'confidence'
        nms_spec.description.output[1].name = 'coordinates'

        output_sizes = [nc, 4]
        for i in range(2):
            ma_type = nms_spec.description.output[i].type.multiArrayType
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[0].lowerBound = 0
            ma_type.shapeRange.sizeRanges[0].upperBound = -1
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
            ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
            del ma_type.shape[:]

        nms = nms_spec.nonMaximumSuppression
        nms.confidenceInputFeatureName = out0.name  # 1x507x80
        nms.coordinatesInputFeatureName = out1.name  # 1x507x4
        nms.confidenceOutputFeatureName = 'confidence'
        nms.coordinatesOutputFeatureName = 'coordinates'
        nms.iouThresholdInputFeatureName = 'iouThreshold'
        nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
        nms.iouThreshold = 0.45
        nms.confidenceThreshold = 0.25
        nms.pickTop.perClass = True
        nms.stringClassLabels.vector.extend(names.values())
        nms_model = ct.models.MLModel(nms_spec)

        # 4. Pipeline models together
        pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
                                                               ('iouThreshold', ct.models.datatypes.Double()),
                                                               ('confidenceThreshold', ct.models.datatypes.Double())],
                                               output_features=['confidence', 'coordinates'])
        pipeline.add_model(model)
        pipeline.add_model(nms_model)

        # Correct datatypes
        pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
        pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
        pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())

        # Update metadata
        pipeline.spec.specificationVersion = 5
        pipeline.spec.description.metadata.userDefined.update({
            'IoU threshold': str(nms.iouThreshold),
            'Confidence threshold': str(nms.confidenceThreshold)})

        # Save the model
        model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
        model.input_description['image'] = 'Input image'
        model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
        model.input_description['confidenceThreshold'] = \
            f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
        model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
        model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
        LOGGER.info(f'{prefix} pipeline success')
        return model

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Execute all callbacks for a given event."""
        for callback in self.callbacks.get(event, []):
            callback(self)
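
    # Example (a minimal sketch): attach a custom callback before exporting.
    #     exporter.add_callback('on_export_end', lambda ex: print(ex.metadata['description']))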


class IOSDetectModel(torch.nn.Module):
    """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""

    def __init__(self, model, im):
        """Initialize the IOSDetectModel class with a YOLO model and example image."""
        super().__init__()
        _, _, h, w = im.shape  # batch, channel, height, width
        self.model = model
        self.nc = len(model.names)  # number of classes
        if w == h:
            self.normalize = 1.0 / w  # scalar
        else:
            self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])  # broadcast (slower, smaller)

    def forward(self, x):
        """Normalize predictions of object detection model with input size-dependent factors."""
        xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
        return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)
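

# Usage note (hedged): IOSDetectModel is applied internally by Exporter.export_coreml() when
# nms=True and task == 'detect'; it rescales the xywh outputs to image-relative coordinates so
# the downstream CoreML NMS layer receives normalized boxes.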