解决 LLaMA-Factory 中 argparse 格式化占位符错误(TypeError: not enough arguments for format string)

在配置 LLaMA-Factory 时,使用 llamafactory-cli train -h 校验安装是否正常,结果遇到了如下 bug:

llamafactory-cli train -h
[INFO|2026-04-28 16:18:08] llamafactory.launcher:144 >> Initializing 2 distributed tasks at: 127.0.0.1:57113
W0428 16:18:10.084000 3765674 site-packages/torch/distributed/run.py:766] 
W0428 16:18:10.084000 3765674 site-packages/torch/distributed/run.py:766] *****************************************
W0428 16:18:10.084000 3765674 site-packages/torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0428 16:18:10.084000 3765674 site-packages/torch/distributed/run.py:766] *****************************************
Traceback (most recent call last):
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/launcher.py", line 185, in <module>
    run_exp()
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/train/tuner.py", line 130, in run_exp
    get_train_args(args)
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/hparams/parser.py", line 290, in get_train_args
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
                                                                             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/hparams/parser.py", line 244, in _parse_train_args
    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/hparams/parser.py", line 93, in _parse_args
    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/site-packages/transformers/hf_argparser.py", line 338, in parse_args_into_dataclasses
    namespace, remaining_args = self.parse_known_args(args=args)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 1907, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2128, in _parse_known_args
    start_index = consume_optional(start_index)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2068, in consume_optional
    take_action(action, args, option_string)
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 1983, in take_action
    action(self, namespace, argument_values, option_string)
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 1122, in __call__
    parser.print_help()
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2611, in print_help
    self._print_message(self.format_help(), file)
                        ^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2595, in format_help
    return formatter.format_help()
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 287, in format_help
    help = self._root_section.format_help()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in format_help
    item_help = join([func(*args) for func, args in self.items])
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in <listcomp>
    item_help = join([func(*args) for func, args in self.items])
                      ^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in format_help
    item_help = join([func(*args) for func, args in self.items])
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in <listcomp>
    item_help = join([func(*args) for func, args in self.items])
                      ^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 546, in _format_action
    help_text = self._expand_help(action)
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 643, in _expand_help
    return self._get_help_string(action) % params
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~
TypeError: not enough arguments for format string
Traceback (most recent call last):
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/launcher.py", line 185, in <module>
    run_exp()
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/train/tuner.py", line 130, in run_exp
    get_train_args(args)
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/hparams/parser.py", line 290, in get_train_args
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
                                                                             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/hparams/parser.py", line 244, in _parse_train_args
    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/hparams/parser.py", line 93, in _parse_args
    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/site-packages/transformers/hf_argparser.py", line 338, in parse_args_into_dataclasses
    namespace, remaining_args = self.parse_known_args(args=args)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 1907, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2128, in _parse_known_args
    start_index = consume_optional(start_index)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2068, in consume_optional
    take_action(action, args, option_string)
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 1983, in take_action
    action(self, namespace, argument_values, option_string)
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 1122, in __call__
    parser.print_help()
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2611, in print_help
    self._print_message(self.format_help(), file)
                        ^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 2595, in format_help
    return formatter.format_help()
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 287, in format_help
    help = self._root_section.format_help()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in format_help
    item_help = join([func(*args) for func, args in self.items])
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in <listcomp>
    item_help = join([func(*args) for func, args in self.items])
                      ^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in format_help
    item_help = join([func(*args) for func, args in self.items])
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 217, in <listcomp>
    item_help = join([func(*args) for func, args in self.items])
                      ^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 546, in _format_action
    help_text = self._expand_help(action)
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/argparse.py", line 643, in _expand_help
    return self._get_help_string(action) % params
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~
TypeError: not enough arguments for format string
E0428 16:18:17.516000 3765674 site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 0 (pid: 3765958) of binary: /home/xyz/anaconda3/envs/llama_factory/bin/python3.11
Traceback (most recent call last):
  File "/home/xyz/anaconda3/envs/llama_factory/bin/torchrun", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/site-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/site-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/home/xyz/llm/LLaMA-Factory/src/llamafactory/launcher.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2026-04-28_16:18:17
  host      : xyz-Super-Server
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3765959)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-04-28_16:18:17
  host      : xyz-Super-Server
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3765958)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Traceback (most recent call last):
  File "/home/xyz/anaconda3/envs/llama_factory/bin/llamafactory-cli", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/cli.py", line 24, in main
    launcher.launch()
  File "/home/xyz/llm/LLaMA-Factory/src/llamafactory/launcher.py", line 115, in launch
    process = subprocess.run(
              ^^^^^^^^^^^^^^^
  File "/home/xyz/anaconda3/envs/llama_factory/lib/python3.11/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['torchrun', '--nnodes', '1', '--node_rank', '0', '--nproc_per_node', '2', '--master_addr', '127.0.0.1', '--master_port', '57113', '/home/xyz/llm/LLaMA-Factory/src/llamafactory/launcher.py', '-h']' returned non-zero exit status 1.

错误分析

在执行 llamafactory-cli train -h 时,argparse 解析帮助文本时触发了错误,错误信息具体为:

TypeError: not enough arguments for format string

通过查看错误的堆栈信息,可以发现错误发生在 argparse.py 的 _expand_help 方法中,错误触发的位置为:

return self._get_help_string(action) % params

argparse 在渲染帮助文本时使用旧式的 % 字符串格式化,会把帮助文本中的 %(default)s、%(prog)s 这类占位符替换为实际的参数值;按照 argparse 的约定,帮助文本中的字面百分号必须写成 %%。当帮助文本中包含如 ~20% 或 ~60% 这样未转义的内容时,argparse 会错误地将其中的 % 当作格式化占位符来处理,并试图查找相应的参数来填充,但并没有提供足够的参数,导致触发了 TypeError: not enough arguments for format string 错误。

问题根源

该问题的根源在于 transformers 库(HfArgumentParser)在帮助文本中新增了带有 % 符号的内容,如 ~20%。argparse 在渲染帮助文本时会将这些百分号字符当作格式化占位符处理。由于没有额外的参数来填充这些占位符,argparse 抛出了 TypeError 错误。

具体来说,错误是由以下原因触发的:

argparse 使用旧式字符串格式化(% 格式化),它期望字符串中的占位符(如 %d, %s)被适当的参数替换。
~20% 这类内容中的 % 没有被转义为 %%,又没有对应的填充参数,因此被 argparse 当作格式化占位符处理,导致崩溃。

解决方案

为了解决这个问题,我们需要确保 argparse 不会把这些裸 % 符号当作格式化占位符处理。最直接的方式是使用一个自定义的帮助格式化器,这个格式化器能够识别并安全地处理带有 % 的帮助文本,避免它们触发错误。

修改了 LLaMA-Factory/src/llamafactory/hparams/parser.py

主要是加了

# Matches an already-escaped "%%" pair, or a lone "%" that does not open a "%(name)..."
# placeholder. Pairs are consumed first, so an existing escape is never split in half.
_ESCAPE_PERCENT_PATTERN = re.compile(r"%%|%(?!\()")


class _SafeHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """A help formatter that tolerates bare percent signs in help strings.

    Python's argparse expands help text with old-style ``%`` formatting, so help text like
    "~20%" is parsed as a conversion specifier and can crash `-h`. This formatter escapes
    every "%" that is neither part of "%%" nor the start of a "%(name)..." placeholder
    before argparse expands the text.

    Only ``_get_help_string`` is overridden; the expansion itself is left to argparse, so
    its own handling of ``%(type)s``, ``%(choices)s`` and suppressed values still applies.
    """

    def _get_help_string(self, action: argparse.Action) -> str:  # type: ignore[override]
        """Return the help template for `action` with stray percent signs escaped.

        Args:
            action: The argparse action whose help text is being rendered.

        Returns:
            The template produced by the parent formatter (including its appended default
            placeholder) with every lone "%" doubled. Empty or missing help is returned
            unchanged.
        """
        help_string = super()._get_help_string(action)
        if not help_string:
            return help_string

        return _ESCAPE_PERCENT_PATTERN.sub("%%", help_string)

并在之后每次创建 HfArgumentParser 时,通过 formatter_class 参数传入这个自定义的帮助格式化器即可,例如:

    parser = HfArgumentParser(_TRAIN_ARGS, formatter_class=_SafeHelpFormatter)

完整的parser.py代码如下:


# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional

import argparse
import torch
import transformers
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available

from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from ..extras.packages import is_mcore_adapter_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .training_args import RayArguments, TrainingArguments


# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Runs the project's dependency check once, when this module is first imported.
check_dependencies()


# Matches an already-escaped "%%" pair, or a lone "%" that does not open a "%(name)..."
# placeholder. Pairs are consumed first, so an existing escape is never split in half.
_ESCAPE_PERCENT_PATTERN = re.compile(r"%%|%(?!\()")


class _SafeHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """A help formatter that tolerates bare percent signs in help strings.

    Python's argparse expands help text with old-style ``%`` formatting, so help text like
    "~20%" is parsed as a conversion specifier and can crash `-h`. This formatter escapes
    every "%" that is neither part of "%%" nor the start of a "%(name)..." placeholder
    before argparse expands the text.

    Only ``_get_help_string`` is overridden; the expansion itself is left to argparse, so
    its own handling of ``%(type)s``, ``%(choices)s`` and suppressed values still applies.
    """

    def _get_help_string(self, action: argparse.Action) -> str:  # type: ignore[override]
        """Return the help template for `action` with stray percent signs escaped.

        Args:
            action: The argparse action whose help text is being rendered.

        Returns:
            The template produced by the parent formatter (including its appended default
            placeholder) with every lone "%" doubled. Empty or missing help is returned
            unchanged.
        """
        help_string = super()._get_help_string(action)
        if not help_string:
            return help_string

        return _ESCAPE_PERCENT_PATTERN.sub("%%", help_string)


# Argument dataclass groups handed to HfArgumentParser by the parse helpers in this module.
# Each `*_ARGS` list is paired with a `*_CLS` tuple type used to annotate the parsed result.
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]

# The MCA variant swaps in the TrainingArguments class from `mcore_adapter`. That package is
# imported only when it is available and `is_env_enabled("USE_MCA")` is true.
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
    from mcore_adapter import TrainingArguments as McaTrainingArguments

    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
    _TRAIN_MCA_CLS = tuple[
        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
    ]
else:
    # Empty placeholders so both names always exist at module level.
    _TRAIN_MCA_ARGS = []
    _TRAIN_MCA_CLS = tuple()


def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
    r"""Get arguments from the command line or a config file."""
    if args is not None:
        return args

    if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
        override_config = OmegaConf.from_cli(sys.argv[2:])
        dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
    elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
        override_config = OmegaConf.from_cli(sys.argv[2:])
        dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
    else:
        return sys.argv[1:]


def _parse_args(
    parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
) -> tuple[Any]:
    r"""Resolve the arguments via `read_args` and parse them with `parser`.

    Args:
        parser: The parser that turns the arguments into dataclass instances.
        args: Arguments forwarded to `read_args`.
        allow_extra_keys: Whether arguments the parser does not recognise are tolerated.

    Returns:
        A tuple of the parsed dataclass instances.

    Raises:
        ValueError: If command-line style arguments contain unknown entries and
            `allow_extra_keys` is false. The parser help is printed first.
    """
    resolved = read_args(args)
    if not isinstance(resolved, dict):
        *dataclass_objs, leftovers = parser.parse_args_into_dataclasses(
            args=resolved, return_remaining_strings=True
        )
        if leftovers and not allow_extra_keys:
            print(parser.format_help())
            print(f"Got unknown args, potentially deprecated arguments: {leftovers}")
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {leftovers}")

        return tuple(dataclass_objs)

    return parser.parse_dict(resolved, allow_extra_keys=allow_extra_keys)


def _verify_trackio_args(training_args: "TrainingArguments") -> None:
    """Check the Trackio-related settings carried by the training arguments.

    Does nothing unless "trackio" is among the `report_to` targets.

    Args:
        training_args: TrainingArguments instance (not a dictionary)

    Raises:
        ValueError: If Trackio reporting is requested without a project.
    """
    targets = training_args.report_to
    if not targets:
        return

    # A single target may be given as a bare string.
    targets = [targets] if isinstance(targets, str) else targets
    if "trackio" not in targets:
        return

    # Trackio requires a project.
    project = training_args.project
    if not project:
        raise ValueError("`--project` must be specified when using Trackio.")

    # Any space id other than the literal "trackio" is expected to contain a "/".
    space_id = training_args.trackio_space_id
    if space_id and space_id != "trackio" and "/" not in space_id:
        logger.warning(
            f"trackio_space_id '{space_id}' should typically be in format "
            "'org/space' for Hugging Face Spaces deployment."
        )

    # Point out that the default project name is in use.
    if project == "huggingface":
        logger.info(
            "Using default project name 'huggingface'. "
            "Consider setting a custom project name with --project "
            "for better organization."
        )

    # Note when the hub repository is flagged as private.
    if training_args.hub_private_repo:
        logger.info("Repository will be created as private on Hugging Face Hub.")

    # Suggest a run name when none was given.
    if not training_args.run_name:
        logger.warning("Consider setting --run_name for better experiment tracking clarity.")


def _set_transformers_logging() -> None:
    """Enable info-level transformers logging when LLAMAFACTORY_VERBOSITY is DEBUG or INFO.

    The variable defaults to "INFO" when unset; any other value leaves logging untouched.
    """
    verbosity = os.environ.get("LLAMAFACTORY_VERBOSITY", "INFO")
    if verbosity not in ("DEBUG", "INFO"):
        return

    hf_logging = transformers.utils.logging
    hf_logging.set_verbosity_info()
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()


def _set_env_vars() -> None:
    """Apply NPU-specific runtime settings; does nothing when no torch NPU is available."""
    if not is_torch_npu_available():
        return

    # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
    jit_compile = is_env_enabled("NPU_JIT_COMPILE")
    torch.npu.set_compile_mode(jit_compile=jit_compile)
    # avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def _verify_model_args(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
    """Reject incompatible combinations of adapter, quantization and finetuning settings.

    Args:
        model_args: Supplies the adapter paths, quantization bit and vocab-resize flag.
        data_args: Not read by this function.
        finetuning_args: Supplies the finetuning type, PiSSA and new-adapter flags.

    Raises:
        ValueError: On the first incompatible combination found.
    """
    adapters = model_args.adapter_name_or_path
    if adapters is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Adapter is only valid for the LoRA method.")

    # Every remaining check only concerns quantized models.
    if model_args.quantization_bit is None:
        return

    if finetuning_args.finetuning_type not in ("lora", "oft"):
        raise ValueError("Quantization is only compatible with the LoRA or OFT method.")

    if finetuning_args.pissa_init:
        raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")

    if model_args.resize_vocab:
        raise ValueError("Cannot resize embedding layers of a quantized model.")

    if adapters is None:
        return

    if finetuning_args.create_new_adapter:
        raise ValueError("Cannot create new adapter upon a quantized model.")

    if len(adapters) != 1:
        raise ValueError("Quantized model only accepts a single adapter. Merge them first.")


def _check_extra_dependencies(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    training_args: Optional["TrainingArguments"] = None,
) -> None:
    """Run `check_version` for every optional package the given settings rely on.

    Args:
        model_args: Model settings whose enabled features select packages to check.
        finetuning_args: Finetuning settings whose enabled features select packages to check.
        training_args: Optional training settings; skipped entirely when None.
    """
    # Each entry pairs a "feature is enabled" flag with the requirement it needs.
    model_requirements = (
        (model_args.use_kt, "ktransformers"),
        (model_args.use_unsloth, "unsloth"),
        (model_args.enable_liger_kernel, "liger-kernel"),
        (model_args.mixture_of_depths is not None, "mixture-of-depth>=1.1.6"),
    )
    for enabled, requirement in model_requirements:
        if enabled:
            check_version(requirement, mandatory=True)

    # Inference backends get a version-range check followed by a mandatory presence check.
    backend = model_args.infer_backend
    if backend == EngineName.VLLM:
        check_version("vllm>=0.4.3,<=0.11.0")
        check_version("vllm", mandatory=True)
    elif backend == EngineName.SGLANG:
        check_version("sglang>=0.4.5")
        check_version("sglang", mandatory=True)

    finetuning_requirements = (
        (finetuning_args.use_galore, "galore_torch"),
        (finetuning_args.use_apollo, "apollo_torch"),
        (finetuning_args.use_badam, "badam>=1.2.1"),
        (finetuning_args.use_adam_mini, "adam-mini"),
        (finetuning_args.use_swanlab, "swanlab"),
        (finetuning_args.plot_loss, "matplotlib"),
    )
    for enabled, requirement in finetuning_requirements:
        if enabled:
            check_version(requirement, mandatory=True)

    if training_args is None:
        return

    if training_args.deepspeed:
        check_version("deepspeed", mandatory=True)

    if training_args.predict_with_generate:
        for requirement in ("jieba", "nltk", "rouge_chinese"):
            check_version(requirement, mandatory=True)


def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
    """Parse the training argument groups; extra keys are allowed when ALLOW_EXTRA_ARGS is enabled."""
    return _parse_args(
        HfArgumentParser(_TRAIN_ARGS, formatter_class=_SafeHelpFormatter),
        args,
        allow_extra_keys=is_env_enabled("ALLOW_EXTRA_ARGS"),
    )


def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
    """Parse the MCA training argument groups, then patch them for MCA training.

    Extra keys are allowed when ALLOW_EXTRA_ARGS is enabled.
    """
    parsed = _parse_args(
        HfArgumentParser(_TRAIN_MCA_ARGS, formatter_class=_SafeHelpFormatter),
        args,
        allow_extra_keys=is_env_enabled("ALLOW_EXTRA_ARGS"),
    )
    model_args, data_args, training_args, finetuning_args, generating_args = parsed

    # Mutates training_args and finetuning_args in place.
    _configure_mca_training_args(training_args, data_args, finetuning_args)

    return model_args, data_args, training_args, finetuning_args, generating_args


def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
    """Patch training args to avoid args checking errors and sync MCA settings.

    Overwrites the generation-related fields on `training_args` (the max generation
    length is taken from `data_args.cutoff_len`) and sets `use_mca` on both
    `training_args` and `finetuning_args`.
    """
    training_overrides = (
        ("predict_with_generate", False),
        ("generation_max_length", data_args.cutoff_len),
        ("generation_num_beams", 1),
        ("use_mca", True),
    )
    for field_name, value in training_overrides:
        setattr(training_args, field_name, value)

    setattr(finetuning_args, "use_mca", True)


def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
    """Parse the inference argument groups; extra keys are allowed when ALLOW_EXTRA_ARGS is enabled."""
    return _parse_args(
        HfArgumentParser(_INFER_ARGS, formatter_class=_SafeHelpFormatter),
        args,
        allow_extra_keys=is_env_enabled("ALLOW_EXTRA_ARGS"),
    )


def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
    """Parse the evaluation argument groups; extra keys are allowed when ALLOW_EXTRA_ARGS is enabled."""
    return _parse_args(
        HfArgumentParser(_EVAL_ARGS, formatter_class=_SafeHelpFormatter),
        args,
        allow_extra_keys=is_env_enabled("ALLOW_EXTRA_ARGS"),
    )


def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
    """Parse and return the Ray arguments; unrecognised keys are always tolerated."""
    ray_parser = HfArgumentParser(RayArguments, formatter_class=_SafeHelpFormatter)
    # Exactly one dataclass is expected back; the unpacking enforces that.
    [ray_args] = _parse_args(ray_parser, args, allow_extra_keys=True)
    return ray_args


def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
    """Parse, validate and post-process the arguments of a training run.

    The MCA parser is used when the ``USE_MCA`` environment flag is enabled,
    the regular training parser otherwise. After parsing, the function

    1. rejects incompatible option combinations with ``ValueError``,
    2. calls ``_set_env_vars`` and the shared verification helpers,
    3. emits warnings for legal but questionable configurations,
    4. post-processes the training arguments (generation settings, label
       names, ``report_to``, DDP flags, checkpoint resumption),
    5. post-processes the model arguments (compute dtype, device map, maximum
       length, packing), and
    6. logs a per-process summary and seeds ``transformers``.

    Args:
        args: Arguments as a dict, a list of CLI strings, or ``None``; passed
            through unchanged to the underlying parser.

    Returns:
        The tuple ``(model_args, data_args, training_args, finetuning_args,
        generating_args)``.

    Raises:
        ValueError: If the parsed options are mutually incompatible, or if the
            output directory already holds checkpoint files but no resumable
            checkpoint and ``overwrite_output_dir`` is not set.
    """
    if is_env_enabled("USE_MCA"):
        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
    else:
        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
        finetuning_args.use_mca = False

    # Setup logging
    if training_args.should_log:
        _set_transformers_logging()

    # Check arguments
    if finetuning_args.stage != "sft":
        if training_args.predict_with_generate:
            raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

        if data_args.neat_packing:
            raise ValueError("`neat_packing` cannot be set as True except SFT.")

        if data_args.train_on_prompt or data_args.mask_history:
            raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")

    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

    if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
        raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")

    if finetuning_args.stage == "ppo":
        if not training_args.do_train:
            raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")

        if model_args.shift_attn:
            raise ValueError("PPO training is incompatible with S^2-Attn.")

        if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
            raise ValueError("KTransformers does not support lora reward model.")

        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
            raise ValueError("Unsloth does not support lora reward model.")

        if training_args.report_to and any(
            logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
        ):
            raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")

    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")

    if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
        raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")

    if training_args.max_steps == -1 and data_args.streaming:
        raise ValueError("Please specify `max_steps` in streaming mode.")

    if training_args.do_train and data_args.dataset is None:
        raise ValueError("Please specify dataset for training.")

    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
        data_args.eval_dataset is None and data_args.val_size < 1e-6
    ):
        raise ValueError("Please make sure eval_dataset be provided or val_size >1e-6")

    if training_args.predict_with_generate:
        if is_deepspeed_zero3_enabled():
            raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")

        if finetuning_args.compute_accuracy:
            raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")

    if training_args.do_train and model_args.quantization_device_map == "auto":
        raise ValueError("Cannot use device map for quantized models in training.")

    if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
        raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")

    if finetuning_args.pure_bf16:
        if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
            raise ValueError("This device does not support `pure_bf16`.")

        if is_deepspeed_zero3_enabled():
            raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")

    if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
        if finetuning_args.use_galore and finetuning_args.galore_layerwise:
            raise ValueError("Distributed training does not support layer-wise GaLore.")

        if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
            raise ValueError("Distributed training does not support layer-wise APOLLO.")

        if finetuning_args.use_badam:
            if finetuning_args.badam_mode == "ratio":
                # Message fixed: this branch is the "ratio" mode (was misspelled "Radio-based").
                raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
            elif not is_deepspeed_zero3_enabled():
                raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")

    if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
        raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

    if not finetuning_args.use_mca and training_args.fp8 and model_args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

    if model_args.infer_backend != EngineName.HF:
        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

    if model_args.use_kt and is_deepspeed_zero3_enabled():
        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")

    _set_env_vars()
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)
    _verify_trackio_args(training_args)

    if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
        # The guard above (and the quantization check earlier) read `training_args.fp8`,
        # so that is the flag that has to be switched on for the warning to be truthful.
        training_args.fp8 = True
        # NOTE(review): the original code only set `model_args.fp8`, which nothing in this
        # function reads. Kept for backward compatibility - confirm whether it can be dropped.
        model_args.fp8 = True

    if (
        training_args.do_train
        and finetuning_args.finetuning_type == "lora"
        and model_args.quantization_bit is None
        and model_args.resize_vocab
        and finetuning_args.additional_target is None
    ):
        logger.warning_rank0(
            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
        )

    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")

    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
        logger.warning_rank0("We recommend enable mixed precision training.")

    if (
        training_args.do_train
        and (finetuning_args.use_galore or finetuning_args.use_apollo)
        and not finetuning_args.pure_bf16
    ):
        logger.warning_rank0(
            "Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
        )

    if (not training_args.do_train) and model_args.quantization_bit is not None:
        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

    if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

    # Post-process training arguments
    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
    training_args.remove_unused_columns = False  # important for multimodal dataset

    if finetuning_args.finetuning_type == "lora":
        # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
        training_args.label_names = training_args.label_names or ["labels"]

    # Guard against an empty/None `report_to`, mirroring the PPO logger check above.
    if training_args.report_to and "swanlab" in training_args.report_to and finetuning_args.use_swanlab:
        training_args.report_to.remove("swanlab")

    if (
        training_args.parallel_mode == ParallelMode.DISTRIBUTED
        and training_args.ddp_find_unused_parameters is None
        and finetuning_args.finetuning_type == "lora"
    ):
        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
        training_args.ddp_find_unused_parameters = False

    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
        can_resume_from_checkpoint = False
        if training_args.resume_from_checkpoint is not None:
            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
            training_args.resume_from_checkpoint = None
    else:
        can_resume_from_checkpoint = True

    # Auto-detect a checkpoint to resume from when the output directory already exists.
    if (
        training_args.resume_from_checkpoint is None
        and training_args.do_train
        and os.path.isdir(training_args.output_dir)
        and not getattr(training_args, "overwrite_output_dir", False)  # for mca training args and transformers >= 5.0
        and can_resume_from_checkpoint
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and any(
            os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
        ):
            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

        if last_checkpoint is not None:
            training_args.resume_from_checkpoint = last_checkpoint
            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")

    if (
        finetuning_args.stage in ["rm", "ppo"]
        and finetuning_args.finetuning_type == "lora"
        and training_args.resume_from_checkpoint is not None
    ):
        logger.warning_rank0(
            f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
        )

    # Post-process model arguments
    if training_args.bf16 or finetuning_args.pure_bf16:
        model_args.compute_dtype = torch.bfloat16
    elif training_args.fp16:
        model_args.compute_dtype = torch.float16

    model_args.device_map = {"": get_current_device()}
    model_args.model_max_length = data_args.cutoff_len
    model_args.block_diag_attn = data_args.neat_packing
    # Packing defaults to on only for the "pt" stage when the user left it unset.
    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"

    # Log on each process the small summary
    logger.info(
        f"Process rank: {training_args.process_index}, "
        f"world size: {training_args.world_size}, device: {training_args.device}, "
        f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
        f"compute dtype: {str(model_args.compute_dtype)}"
    )
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, generating_args


def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
    """Parse, validate and post-process the arguments of an inference run.

    Returns:
        The tuple ``(model_args, data_args, finetuning_args, generating_args)``.

    Raises:
        ValueError: If the ``"vllm"`` backend is selected together with a
            non-SFT stage, a ``quantization_bit``, RoPE scaling, or an adapter
            list whose length is not exactly one.
    """
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

    # Logging is configured unconditionally for inference.
    _set_transformers_logging()

    # vLLM-only restrictions, expressed as flat guard clauses.
    uses_vllm = model_args.infer_backend == "vllm"

    if uses_vllm and finetuning_args.stage != "sft":
        raise ValueError("vLLM engine only supports auto-regressive models.")

    if uses_vllm and model_args.quantization_bit is not None:
        raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")

    if uses_vllm and model_args.rope_scaling is not None:
        raise ValueError("vLLM engine does not support RoPE scaling.")

    if uses_vllm and (adapters := model_args.adapter_name_or_path) is not None and len(adapters) != 1:
        raise ValueError("vLLM only accepts a single adapter. Merge them first.")

    _set_env_vars()
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args)

    # Decide where the model is placed.
    exporting_on_cpu = model_args.export_dir is not None and model_args.export_device == "cpu"
    if not exporting_on_cpu:
        model_args.device_map = "auto"
    else:
        model_args.device_map = {"": torch.device("cpu")}
        # Only propagate cutoff_len when it differs from the dataclass default.
        default_cutoff_len = DataArguments().cutoff_len
        if data_args.cutoff_len != default_cutoff_len:
            model_args.model_max_length = data_args.cutoff_len

    return model_args, data_args, finetuning_args, generating_args


def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
    """Parse, validate and post-process the arguments of an evaluation run.

    Returns:
        The tuple ``(model_args, data_args, eval_args, finetuning_args)``.

    Raises:
        ValueError: If ``model_args.infer_backend`` is not ``EngineName.HF``.
    """
    parsed = _parse_eval_args(args)
    model_args, data_args, eval_args, finetuning_args = parsed

    # Logging is configured unconditionally for evaluation.
    _set_transformers_logging()

    # Only the HF backend is accepted here.
    if model_args.infer_backend != EngineName.HF:
        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

    _set_env_vars()
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args)

    model_args.device_map = "auto"
    transformers.set_seed(eval_args.seed)

    return model_args, data_args, eval_args, finetuning_args
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Beiwen_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值