A Walkthrough of the LLaVA-OV Training Code
How the LLaVA-OV multimodal model is trained, following:
docs/LLaVA_OneVision.md
- Stage-1: Initial training on 558K samples from the LCS dataset
Script: scripts/train/pretrain_siglip.sh (older version: scripts/train/pretrain_clip.sh). Its core content:
llava/train/train_mem.py --deepspeed scripts/zero3.json --model_name_or_path Qwen/Qwen2-7B-Instruct --vision_tower google/siglip-so400m-patch14-384 --data_path /blip_558k/blip_558k_plain.json
Partial training data: HF1, LLaVA-OneVision Mid-Stage Data.
- Stage-1.5: Training on 4M high-quality samples with detailed captions, OCR and knowledge data.
TODO: I have not found any source material for this stage yet; corrections are welcome in the comments.
- Stage-2:
○ Single-Image: Training on 3.2M instruction-following image samples.
Script: scripts/train/finetune_si.sh. Its core content:
llava/train/train_mem.py --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
--image_aspect_ratio anyres_max_9 --image_grid_pinpoints "(1x1),...,(6x6)" \
--mm_patch_merge_type spatial_unpad \
--video_folder /mnt/bn/vl-research/data/llava_video --frames_upbound 32 \
--data_path /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/scripts/i18n/scale_llms/next_3p2m_single_image.yaml
Data: onevision-data
Related config: scripts/train/single_image.yaml
○ OneVision: Training on 1.6M single-image, multi-image and video samples with instructions.
Script: scripts/train/finetune_ov.sh, which is similar to the one above.
Data: the multi-image data is released in M4-Instruct Data; the video part is released along with llava-video-data (only a subset of it is used).
Related config: scripts/train/onevision.yaml.
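The `--image_grid_pinpoints "(1x1),...,(6x6)"` flag in the Stage-2 script is a range shorthand rather than a literal list. The sketch below shows one plausible way such a string can be expanded into concrete candidate resolutions. It is an illustration only: the function name `expand_grid_pinpoints` is hypothetical, and the base size of 384 is an assumption taken from the `siglip-so400m-patch14-384` vision tower name, not something read out of the repository.

```python
import re


def expand_grid_pinpoints(spec: str, base_size: int = 384) -> list[tuple[int, int]]:
    """Expand a range spec such as "(1x1),...,(6x6)" into pixel resolutions.

    Hypothetical helper: only the first and last "(AxB)" pairs are read,
    and every grid shape between them is enumerated.
    """
    matches = re.findall(r"\((\d+)x(\d+)\)", spec)
    if not matches:
        raise ValueError(f"no (AxB) pairs found in {spec!r}")
    start_a, start_b = map(int, matches[0])
    end_a, end_b = map(int, matches[-1])
    return [
        (a * base_size, b * base_size)  # assumed: each grid cell is one base-size tile
        for a in range(start_a, end_a + 1)
        for b in range(start_b, end_b + 1)
    ]


pinpoints = expand_grid_pinpoints("(1x1),...,(6x6)")
print(len(pinpoints), pinpoints[0], pinpoints[-1])  # 36 (384, 384) (2304, 2304)
```

Under these assumptions the shorthand yields 36 candidate grids, from a single 384x384 tile up to 6x6 tiles; an input image would then be matched to whichever candidate fits its aspect ratio best.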
Training code
llava/train/train_mem.py & train/train.py
# Implementation in train/train.py:
import pathlib

def train(attn_implementation=None):
    ...
    # Example: for llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mid_to_final_next_3m_am9_july14,
    # loading ends up calling model = LlavaQwenForCausalLM.from_pretrained(xxx)
    model = get_model(model_args, training_args, bnb_model_from_pretrained_args)
    ...
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    # Resume if output_dir already contains a checkpoint; otherwise start fresh.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
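The resume branch at the end of `train()` can be tried in isolation. The following is a minimal sketch, not the repository's code: `DummyTrainer` and `run_training` are hypothetical stand-ins that only record how `train()` was called, so the checkpoint check can run without a model or GPU.

```python
import pathlib
import tempfile


class DummyTrainer:
    """Hypothetical stand-in for LLaVATrainer; records each train() call."""

    def __init__(self):
        self.calls = []

    def train(self, resume_from_checkpoint=False):
        self.calls.append(resume_from_checkpoint)


def run_training(trainer, output_dir):
    # Same check as in train(): any "checkpoint-*" entry means resume.
    if list(pathlib.Path(output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()


with tempfile.TemporaryDirectory() as tmp:
    fresh = DummyTrainer()
    run_training(fresh, tmp)  # empty output dir: starts from scratch

    (pathlib.Path(tmp) / "checkpoint-500").mkdir()
    resumed = DummyTrainer()
    run_training(resumed, tmp)  # checkpoint present: resumes

print(fresh.calls, resumed.calls)  # [False] [True]
```

With the Hugging Face `Trainer`, passing `resume_from_checkpoint=True` loads the most recent checkpoint found in `output_dir`, so rerunning the same launch script after an interruption continues training instead of restarting it.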



