123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447 |
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- import contextlib
- import re
- import shutil
- import sys
- from difflib import get_close_matches
- from pathlib import Path
- from types import SimpleNamespace
- from typing import Dict, List, Union
- from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML,
- IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
- yaml_print)
# Valid CLI modes and tasks
MODES = ('train', 'val', 'predict', 'export', 'track', 'benchmark')
TASKS = ('detect', 'segment', 'classify', 'pose')

# Per-task lookup tables: default dataset (used when 'data' is omitted), default pretrained model
# (used when 'model' is omitted) and the metric key associated with each task
TASK2DATA = {
    'detect': 'coco8.yaml',
    'segment': 'coco8-seg.yaml',
    'classify': 'imagenet100',
    'pose': 'coco8-pose.yaml',
}
TASK2MODEL = {
    'detect': 'yolov8n.pt',
    'segment': 'yolov8n-seg.pt',
    'classify': 'yolov8n-cls.pt',
    'pose': 'yolov8n-pose.pt',
}
TASK2METRIC = {
    'detect': 'metrics/mAP50-95(B)',
    'segment': 'metrics/mAP50-95(M)',
    'classify': 'metrics/accuracy_top1',
    'pose': 'metrics/mAP50-95(P)',
}
# Help text shown for a bare 'yolo' / 'yolo help' and appended to CLI argument errors.
# Built once at import time, so it embeds the sys.argv seen when this module was first imported.
CLI_HELP_MSG = f"""
Arguments received: {['yolo'] + sys.argv[1:]!s}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
Where TASK (optional) is one of {TASKS}
MODE (required) is one of {MODES}
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
3. Val a pretrained detection model at batch-size 1 and image size 640:
yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
5. Run special commands:
yolo help
yolo checks
yolo version
yolo settings
yolo copy-cfg
yolo cfg
Docs: https://docs.ultralytics.com
Community: https://community.ultralytics.com
GitHub: https://github.com/ultralytics/ultralytics
"""
# Keys whose values are type-checked by get_cfg():
# - float keys accept any int or float
CFG_FLOAT_KEYS = ('warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear')
# - fraction keys accept an int or float in the closed range 0.0 - 1.0 (each key listed once)
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
                     'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
                     'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'fraction')
# - int keys must be int
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
                'line_width', 'workspace', 'nbs', 'save_period')
# - bool keys must be bool
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
                 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
                 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
                 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
def cfg2dict(cfg):
    """
    Normalize a configuration source into a dictionary.

    Args:
        cfg (str | Path | dict | SimpleNamespace): A YAML file path (read with `yaml_load`), a namespace (converted
            with `vars`, so the result is the namespace's own `__dict__`, not a copy) or a dict. Any other value,
            including a dict, is returned unchanged.

    Returns:
        cfg (dict): Configuration object in dictionary format.
    """
    if isinstance(cfg, (str, Path)):
        return yaml_load(cfg)  # load dict from file
    return vars(cfg) if isinstance(cfg, SimpleNamespace) else cfg
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
    """
    Load and merge configuration data from a file or dictionary.

    Neither `cfg` nor `overrides` is modified: both are shallow-copied before any key is changed.

    Args:
        cfg (str | Path | Dict | SimpleNamespace): Configuration data.
        overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.

    Returns:
        (SimpleNamespace): Training arguments namespace.

    Raises:
        SyntaxError: If `overrides` contains keys not present in `cfg` (raised by check_dict_alignment).
        TypeError: If a float/fraction/int/bool key holds a value of the wrong type.
        ValueError: If a fraction key holds a value outside 0.0 - 1.0.
    """
    # Shallow-copy so neither the caller's dict/namespace nor the shared DEFAULT_CFG_DICT default is mutated below
    cfg = dict(cfg2dict(cfg))

    # Merge overrides
    if overrides:
        overrides = dict(cfg2dict(overrides))  # private copy: the pop/deprecation rewrites below stay local
        if 'save_dir' not in cfg:
            overrides.pop('save_dir', None)  # special override keys to ignore
        check_dict_alignment(cfg, overrides)  # raises on unknown keys; rewrites deprecated keys in-place
        cfg = {**cfg, **overrides}  # merge cfg and overrides dicts (prefer overrides)

    # Special handling for numeric project/name
    for k in 'project', 'name':
        if k in cfg and isinstance(cfg[k], (int, float)):
            cfg[k] = str(cfg[k])
    if cfg.get('name') == 'model':  # assign model to 'name' arg
        # str(... or '') tolerates a missing, None or Path-valued 'model' instead of raising AttributeError
        cfg['name'] = str(cfg.get('model') or '').split('.')[0]
        LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")

    # Type and Value checks
    for k, v in cfg.items():
        if v is None:  # None values may be from optional args
            continue
        if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
            raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                            f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
        elif k in CFG_FRACTION_KEYS:
            if not isinstance(v, (int, float)):
                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
            if not (0.0 <= v <= 1.0):
                raise ValueError(f"'{k}={v}' is an invalid value. "
                                 f"Valid '{k}' values are between 0.0 and 1.0.")
        elif k in CFG_INT_KEYS and not isinstance(v, int):
            raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                            f"'{k}' must be an int (i.e. '{k}=8')")
        elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
            raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                            f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")

    # Return instance
    return IterableSimpleNamespace(**cfg)
def _handle_deprecation(custom):
    """
    Rewrite deprecated configuration keys in `custom` to their replacements, in-place.

    Mappings:
        'hide_labels'    -> 'show_labels' (inverted)
        'hide_conf'      -> 'show_conf'   (inverted)
        'line_thickness' -> 'line_width'  (value moved as-is)

    A deprecation warning is emitted for every deprecated key found. Replacement keys are inserted at the end of
    the dict.

    Args:
        custom (dict): Configuration dictionary to update. Modified in-place.

    Returns:
        (dict): The same `custom` dictionary.
    """

    def _inverted(value):
        # Deprecated flags arrive as bools (from Python callers or smart_value()) or as raw strings; the old
        # `value == 'False'` test only worked for strings and turned hide_labels=False into show_labels=False.
        if isinstance(value, str):
            return value.lower() == 'false'
        return not value

    for key in list(custom):  # snapshot keys: the dict is modified inside the loop
        if key == 'hide_labels':
            deprecation_warn(key, 'show_labels')
            custom['show_labels'] = _inverted(custom.pop('hide_labels'))
        elif key == 'hide_conf':
            deprecation_warn(key, 'show_conf')
            custom['show_conf'] = _inverted(custom.pop('hide_conf'))
        elif key == 'line_thickness':
            deprecation_warn(key, 'line_width')
            custom['line_width'] = custom.pop('line_thickness')
    return custom
def check_dict_alignment(base: Dict, custom: Dict, e=None):
    """
    Check that every key of a custom configuration dict exists in a base configuration dict.

    Deprecated keys in `custom` are first rewritten in-place to their replacements (via `_handle_deprecation`).
    If any remaining key of `custom` is not a key of `base`, a SyntaxError is raised. Its message lists each
    offending key, in `custom`'s own order, together with similar keys from `base`. Those similar keys are shown as
    'key=value' when the base value is not None. The CLI help text is appended at the end.

    Args:
        base (dict): A dictionary of base configuration options.
        custom (dict): A dictionary of custom configuration options. May be modified in-place.
        e (Exception, optional): Exception to chain as the cause of the raised SyntaxError.

    Raises:
        SyntaxError: If `custom` contains keys not present in `base`.
    """
    custom = _handle_deprecation(custom)
    base_keys = set(base.keys())
    # Iterate `custom` itself (not a set of its keys) so mismatches are reported in a deterministic order
    mismatched = [k for k in custom if k not in base_keys]
    if mismatched:
        messages = []
        for x in mismatched:
            matches = get_close_matches(x, base_keys)  # key list
            matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches]
            match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
            messages.append(f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n")
        raise SyntaxError(''.join(messages) + CLI_HELP_MSG) from e
def merge_equals_args(args: List[str]) -> List[str]:
    """
    Merge arguments around isolated '=' tokens in a list of strings.

    Handles the split forms produced when a command line has spaces around '=':
        ['arg', '=', 'val'] -> ['arg=val']
        ['arg=', 'val']     -> ['arg=val']   (only when 'val' itself contains no '=')
        ['arg', '=val']     -> ['arg=val']

    The input list is not modified.

    Args:
        args (List[str]): A list of strings where each element is an argument.

    Returns:
        List[str]: A new list where the arguments around isolated '=' are merged.
    """
    new_args = []
    i, n = 0, len(args)
    while i < n:
        arg = args[i]
        has_next = i < n - 1
        if arg == '=' and i > 0 and has_next:  # merge ['arg', '=', 'val']
            new_args[-1] += f'={args[i + 1]}'
            i += 2  # the value token has been consumed
        elif arg.endswith('=') and has_next and '=' not in args[i + 1]:  # merge ['arg=', 'val']
            new_args.append(f'{arg}{args[i + 1]}')
            i += 2
        elif arg.startswith('=') and i > 0:  # merge ['arg', '=val']
            new_args[-1] += arg
            i += 1
        else:
            new_args.append(arg)
            i += 1
    return new_args
def handle_yolo_hub(args: List[str]) -> None:
    """
    Handle Ultralytics HUB command-line interface (CLI) commands.

    This function processes Ultralytics HUB CLI commands such as login and logout.
    It should be called when executing a script with arguments related to HUB authentication.

    Args:
        args (List[str]): A list of command line arguments, starting with the sub-command ('login' or 'logout').
            For 'login' the optional second element is the API key. An empty list or an unrecognized
            sub-command does nothing.

    Example:
        ```bash
        python my_script.py hub login your_api_key
        ```
    """
    if not args:  # no sub-command given: do nothing instead of raising IndexError on args[0]
        return

    from ultralytics import hub  # imported lazily, only when there is a sub-command to run

    command = args[0]
    if command == 'login':
        key = args[1] if len(args) > 1 else ''
        hub.login(key)  # Log in to Ultralytics HUB using the provided API key
    elif command == 'logout':
        hub.logout()  # Log out from Ultralytics HUB
def handle_yolo_settings(args: List[str]) -> None:
    """
    Handle YOLO settings command-line interface (CLI) commands.

    With 'reset' as the first argument the settings file is deleted (if present) and the settings are reset.
    Otherwise every argument is parsed as a 'key=value' pair, validated against the existing settings keys and
    saved. In all cases a help URL and the current settings file are printed afterwards. Any error is logged as a
    warning rather than raised.

    Args:
        args (List[str]): A list of command line arguments for YOLO settings management.

    Example:
        ```bash
        python my_script.py yolo settings reset
        ```
    """
    url = 'https://docs.ultralytics.com/quickstart/#ultralytics-settings'  # help URL
    try:
        if any(args):
            if args[0] == 'reset':
                # A missing file must not abort the reset (the error would be swallowed below and SETTINGS left as-is)
                with contextlib.suppress(FileNotFoundError):
                    SETTINGS_YAML.unlink()  # delete the settings file
                SETTINGS.reset()  # create new settings
                LOGGER.info('Settings reset successfully')  # inform the user that settings have been reset
            else:  # save a new setting
                new = dict(parse_key_value_pair(a) for a in args)
                check_dict_alignment(SETTINGS, new)
                SETTINGS.update(new)
        LOGGER.info(f'💡 Learn about settings at {url}')
        yaml_print(SETTINGS_YAML)  # print the current settings
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
def parse_key_value_pair(pair):
    """
    Parse one 'key=value' string and return the key and the converted value.

    Spaces around '=' signs are removed first (i.e. 'imgsz = 320' -> 'imgsz=320'). The string is then split on the
    first '=', and the value is converted with smart_value().

    Args:
        pair (str): The 'key=value' string.

    Returns:
        (tuple): (key, converted value).

    Raises:
        ValueError: If `pair` contains no '=' sign.
        AssertionError: If the value after '=' is empty.
    """
    pair = re.sub(r' *= *', '=', pair)  # remove spaces around equals sign
    k, v = pair.split('=', 1)  # split on first '=' sign
    if not v:
        # Explicit raise instead of `assert` so the check survives `python -O`; callers catch AssertionError
        raise AssertionError(f"missing '{k}' value")
    return k, smart_value(v)
def smart_value(v):
    """
    Convert a CLI string value to an underlying Python type.

    'none', 'true' and 'false' (case-insensitive) map to None, True and False. Any other string is evaluated as a
    Python expression (i.e. '320' -> 320, '0.5' -> 0.5, '640,480' -> (640, 480), '[0, 2]' -> [0, 2]). If evaluation
    fails, the original string is returned (i.e. 'yolov8n.pt', 'cpu').

    Args:
        v (str): The string value to convert.

    Returns:
        (Any): The converted value, or `v` unchanged if it cannot be evaluated.
    """
    lowered = v.lower()
    if lowered == 'none':
        return None
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    # Evaluate with empty globals and no builtins, so bare words such as 'max', 'sys' or 'checks' stay plain strings
    # instead of resolving to builtins or to objects imported by this module.
    # NOTE(security): this is still eval() on command-line text and is NOT a sandbox; do not pass untrusted input.
    with contextlib.suppress(Exception):
        return eval(v, {'__builtins__': {}}, {})
    return v
def entrypoint(debug=''):
    """
    Ultralytics package CLI entrypoint: parse command-line arguments and dispatch the requested command.

    Supports:
        - special commands 'help', 'checks', 'version', 'settings', 'cfg', 'hub', 'login' and 'copy-cfg'. These are
          also accepted with a leading '-' or '--', as a first-letter abbreviation, or in singular form. A letter
          shared by several commands goes to the first one listed, i.e. '-h' -> 'help', '-c' -> 'checks'.
        - an optional TASK (one of TASKS) and a MODE (one of MODES)
        - any number of 'key=value' overrides of the default configuration, including 'cfg=custom.yaml'
        - bare names of boolean config keys, which set that key to True (i.e. 'yolo show')

    The composed overrides are passed as keyword arguments to the selected model's MODE method.

    Args:
        debug (str, optional): Whitespace-separated command line to parse instead of sys.argv. It includes the
            leading program name, i.e. 'yolo predict model=yolov8n.pt'. Default '' uses sys.argv.

    Raises:
        SyntaxError: For unknown arguments, or a valid argument given without an '=' value.
        ValueError: For an invalid 'mode' or 'task'.
    """
    # split() without a separator collapses repeated whitespace, so 'yolo  predict' yields no empty '' argument
    args = (debug.split() if debug else sys.argv)[1:]
    if not args:  # no arguments passed
        LOGGER.info(CLI_HELP_MSG)
        return

    special = {
        'help': lambda: LOGGER.info(CLI_HELP_MSG),
        'checks': checks.check_yolo,
        'version': lambda: LOGGER.info(__version__),
        'settings': lambda: handle_yolo_settings(args[1:]),
        'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
        'hub': lambda: handle_yolo_hub(args[1:]),
        'login': lambda: handle_yolo_hub(args),
        'copy-cfg': copy_default_cfg}
    full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}

    # Define common mis-uses of special commands, i.e. -h, -help, --help
    # Aliases never overwrite an existing entry, so the first command claims a shared letter
    # ('h' -> 'help' rather than 'hub', which would crash on a missing sub-command)
    for k, v in list(special.items()):
        special.setdefault(k[0], v)  # first-letter abbreviation
    for k, v in list(special.items()):
        if len(k) > 1 and k.endswith('s'):
            special.setdefault(k[:-1], v)  # singular
    special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}

    overrides = {}  # basic overrides, i.e. imgsz=320
    for a in merge_equals_args(args):  # merge spaces around '=' sign
        if a.startswith('--'):
            LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
            a = a[2:]
        if a.endswith(','):
            LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
            a = a[:-1]
        if '=' in a:
            try:
                k, v = parse_key_value_pair(a)
                if k == 'cfg':  # custom.yaml passed
                    LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
                    # NOTE: this replaces (does not merge) any overrides parsed from earlier arguments
                    overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
                else:
                    overrides[k] = v
            except (NameError, SyntaxError, ValueError, AssertionError) as e:
                check_dict_alignment(full_args_dict, {a: ''}, e)
        elif a in TASKS:
            overrides['task'] = a
        elif a in MODES:
            overrides['mode'] = a
        elif a.lower() in special:
            special[a.lower()]()
            return
        elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
            overrides[a] = True  # auto-True for default bool args, i.e. 'yolo show' sets show=True
        elif a in DEFAULT_CFG_DICT:
            raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
                              f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
        else:
            check_dict_alignment(full_args_dict, {a: ''})

    # Check keys
    check_dict_alignment(full_args_dict, overrides)

    # Mode
    mode = overrides.get('mode')
    if mode is None:
        mode = DEFAULT_CFG.mode or 'predict'
        LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
    elif mode not in MODES:
        # Compared both as a string and as the `checks` module: smart_value() may evaluate the bare word 'checks'
        # against this module's globals, where `checks` is the imported module
        if mode not in ('checks', checks):
            raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
        LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
        checks.check_yolo()
        return

    # Task
    task = overrides.pop('task', None)
    if task:
        if task not in TASKS:
            raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
        if 'model' not in overrides:
            overrides['model'] = TASK2MODEL[task]

    # Model
    model = overrides.pop('model', DEFAULT_CFG.model)
    if model is None:
        model = 'yolov8n.pt'
        LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
    overrides['model'] = model
    model_name = str(model).lower()  # str() guards against non-string values such as numbers
    if 'rtdetr' in model_name:  # guess architecture
        from ultralytics import RTDETR
        model = RTDETR(model)  # no task argument
    elif 'fastsam' in model_name:  # checked before 'sam', which it also contains
        from ultralytics import FastSAM
        model = FastSAM(model)
    elif 'sam' in model_name:
        from ultralytics import SAM
        model = SAM(model)
    else:
        from ultralytics import YOLO
        model = YOLO(model, task=task)
    if isinstance(overrides.get('pretrained'), str):
        model.load(overrides['pretrained'])

    # Task Update
    if task != model.task:
        if task:
            LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
                           f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
        task = model.task

    # Mode
    if mode in ('predict', 'track') and 'source' not in overrides:
        overrides['source'] = DEFAULT_CFG.source or ASSETS
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
    elif mode in ('train', 'val'):
        if 'data' not in overrides and 'resume' not in overrides:
            overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
    elif mode == 'export':
        if 'format' not in overrides:
            overrides['format'] = DEFAULT_CFG.format or 'torchscript'
            LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")

    # Run command in python: unspecified args fall back to the model's own defaults
    getattr(model, mode)(**overrides)
- # Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():
    """
    Copy the default configuration file into the current working directory and log an example command.

    The copy's file name is the default file's name with '.yaml' replaced by '_copy.yaml'.
    """
    source = DEFAULT_CFG_PATH
    target = Path.cwd() / source.name.replace('.yaml', '_copy.yaml')
    shutil.copy2(source, target)
    usage = f"Example YOLO command with this new custom cfg:\n yolo cfg='{target}' imgsz=320 batch=8"
    LOGGER.info(f'{source} copied to {target}\n' + usage)
if __name__ == '__main__':
    # Parse the real sys.argv. To debug, pass a full command string instead,
    # i.e. entrypoint(debug='yolo predict model=yolov8n.pt')
    entrypoint()
|