手把手教你运行Deep_Metric:基于run_train_00.sh的训练流程全解析

手把手教你运行Deep_Metric:基于run_train_00.sh的训练流程全解析

【免费下载链接】Deep_Metric Deep Metric Learning 【免费下载链接】Deep_Metric 项目地址: https://gitcode.com/gh_mirrors/de/Deep_Metric

Deep_Metric是一个专注于深度度量学习(Deep Metric Learning)的开源项目,通过run_train_00.sh脚本可以快速启动模型训练与评估流程。本文将详细解析该脚本的使用方法,帮助新手用户轻松上手深度度量学习模型的训练。

📋 准备工作:环境与依赖检查

在运行训练脚本前,需确保系统已安装以下依赖:

  • Python 3.x
  • PyTorch 1.0+
  • CUDA环境(建议8.0以上)
  • 相关Python库:numpyscipytorchvision

项目核心代码结构如下:

🔧 脚本参数解析:定制你的训练任务

run_train_00.sh是整个训练流程的入口脚本,核心参数配置如下:

1. 基础配置参数

DATA=cub                  # 数据集名称(支持cub/car196等)
DATA_ROOT=data            # 数据存放路径
LOSS=Weight               # 损失函数类型(Weight/Contrastive等)
CHECKPOINTS=ckps          # 模型 checkpoint 保存目录
NET=BN-Inception          # 网络架构
DIM=512                   # 特征维度
BATCH_SIZE=80             # 批次大小

2. 训练超参数

LR=1e-5                   # 学习率
ALPHA=40                  # 损失函数超参数
RATIO=0.16                # 训练数据随机裁剪比例
EPOCH=600                 # 训练轮数
SAVE_STEP=50              # 模型保存间隔(每50轮保存一次)

所有参数在脚本中均有明确定义,用户可根据硬件条件和任务需求调整。例如显存较小的情况下,建议减小BATCH_SIZE至32或16。

🚀 训练流程详解:从启动到评估

1. 一键启动训练

在项目根目录下执行以下命令启动训练:

bash run_train_00.sh

脚本会自动完成以下操作:

  • 创建必要的目录结构(checkpoints/result等)
  • 设置CUDA可见设备(默认使用GPU 0)
  • 调用train.py开始训练流程

2. 训练过程解析

训练核心逻辑在train.py中实现,主要流程包括:

  1. 模型初始化:加载预训练的BN-Inception模型
  2. 数据加载:使用DataSet/中的数据集类加载训练数据
  3. 优化器配置:采用Adam优化器,对不同层设置不同学习率
  4. 损失计算:根据配置加载对应的损失函数(如Weighted Triplet Loss)
  5. 模型保存:按间隔保存训练好的模型参数

关键代码片段(train.py第38-52行):

CUDA_VISIBLE_DEVICES=0 python train.py --net ${NET} \
--data $DATA \
--data_root ${DATA_ROOT} \
--init random \
--lr $LR \
--dim $DIM \
--alpha $ALPHA \
--num_instances   5 \
--batch_size ${BatchSize} \
--epoch 600 \
--loss $LOSS \
--width 227 \
--save_dir ${SAVE_DIR} \
--save_step 50 \
--ratio ${RATIO}

3. 自动评估流程

训练完成后,脚本会自动触发评估流程:

  1. 加载每50轮保存的模型(50/100/.../600轮)
  2. 调用test.py计算Recall@K指标
  3. 将结果保存至result目录下的日志文件

评估核心逻辑(test.py第36-44行):

gallery_feature, gallery_labels, query_feature, query_labels = \
    Model2Feature(data=args.data, root=args.data_root, width=args.width, net=args.net, checkpoint=checkpoint,
                   dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature)

sim_mat = pairwise_similarity(query_feature, gallery_feature)
recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels, gallery_ids=gallery_labels, data=args.data)

📊 结果查看与分析

训练和评估结果会保存在以下路径:

  • 模型文件:ckps/Weight/cub/
  • 评估日志:result/Weight/cub/

日志文件格式示例:

Epoch-50 0.7820 0.8910 0.9450 0.9680
Epoch-100 0.8010 0.9050 0.9520 0.9720

其中数字依次表示Recall@1、Recall@2、Recall@4、Recall@8指标。

❗ 常见问题解决

1. 显存不足

  • 解决方案:减小BATCH_SIZE参数,如改为40或20
  • 位置:run_train_00.sh第29行

2. 数据集路径错误

  • 解决方案:修改DATA_ROOT参数指向正确的数据目录
  • 位置:run_train_00.sh第3行

3. 评估指标为0

  • 检查数据是否正确加载
  • 确认gallery_eq_query参数设置是否正确(默认为True)

📝 总结

通过run_train_00.sh脚本,Deep_Metric实现了训练流程的高度自动化。用户只需简单配置参数,即可完成从数据加载、模型训练到性能评估的全流程。项目模块化的设计也使得扩展新的数据集或损失函数变得简单。

如需进一步定制训练流程,可以修改以下核心文件:

希望本文能帮助你快速掌握Deep_Metric的训练流程,开始你的深度度量学习之旅!

【免费下载链接】Deep_Metric Deep Metric Learning 【免费下载链接】Deep_Metric 项目地址: https://gitcode.com/gh_mirrors/de/Deep_Metric

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值