|
| 1 | +from typing import Iterable |
| 2 | +import argparse |
| 3 | +import logging |
| 4 | + |
| 5 | +from catalyst import utils |
| 6 | +from catalyst.registry import REGISTRY |
| 7 | + |
| 8 | + |
| 9 | +def parse_args(): |
| 10 | + """Parses the command line arguments and returns arguments and config.""" |
| 11 | + parser = argparse.ArgumentParser() |
| 12 | + parser.add_argument( |
| 13 | + "--config", |
| 14 | + "--configs", |
| 15 | + "-C", |
| 16 | + nargs="+", |
| 17 | + default=("config.yml",), |
| 18 | + type=str, |
| 19 | + help="path to config/configs", |
| 20 | + metavar="CONFIG_PATH", |
| 21 | + dest="configs", |
| 22 | + ) |
| 23 | + |
| 24 | + utils.boolean_flag( |
| 25 | + parser, |
| 26 | + "deterministic", |
| 27 | + default=None, |
| 28 | + help="Deterministic mode if running in CuDNN backend", |
| 29 | + ) |
| 30 | + utils.boolean_flag(parser, "benchmark", default=None, help="Use CuDNN benchmark") |
| 31 | + |
| 32 | + args, unknown_args = parser.parse_known_args() |
| 33 | + return vars(args), unknown_args |
| 34 | + |
| 35 | + |
| 36 | +def run_from_config( |
| 37 | + configs: Iterable[str], |
| 38 | + deterministic: bool = None, |
| 39 | + benchmark: bool = None, |
| 40 | +) -> None: |
| 41 | + """Creates Runner from YAML configs and runs experiment.""" |
| 42 | + logger = logging.getLogger(__name__) |
| 43 | + |
| 44 | + # there is no way to set deterministic/benchmark flags with a runner, |
| 45 | + # so do it manually |
| 46 | + utils.prepare_cudnn(deterministic, benchmark) |
| 47 | + |
| 48 | + config = {} |
| 49 | + for config_path in configs: |
| 50 | + config_part = utils.load_config(config_path, ordered=True) |
| 51 | + config = utils.merge_dicts(config, config_part) |
| 52 | + # config_copy = copy.deepcopy(config) |
| 53 | + |
| 54 | + experiment_params = REGISTRY.get_from_params(**config) |
| 55 | + |
| 56 | + runner = experiment_params["runner"] |
| 57 | + for stage_params in experiment_params["run"]: |
| 58 | + name = stage_params.pop("_call_") |
| 59 | + func = getattr(runner, name) |
| 60 | + |
| 61 | + result = func(**stage_params) |
| 62 | + if result is not None: |
| 63 | + logger.info(f"{name}:\n{result}") |
| 64 | + |
| 65 | + # TODO: check if needed |
| 66 | + # logdir = getattr(runner, "logdir", getattr(runner, "_logdir"), None) |
| 67 | + # if logdir and utils.get_rank() <= 0: |
| 68 | + # utils.dump_environment(logdir=logdir, config=config_copy, configs_path=configs) |
| 69 | + |
| 70 | + |
| 71 | +def main(): |
| 72 | + """Runs the ``catalyst-run`` script.""" |
| 73 | + kwargs, unknown_args = parse_args() |
| 74 | + run_from_config(**kwargs) |
| 75 | + |
| 76 | + |
| 77 | +if __name__ == "__main__": |
| 78 | + main() |
0 commit comments