Skip to content

Commit a19ba5d

Browse files
author
linuxfl
committed
fix lbfgs predict bug
1 parent 5d22817 commit a19ba5d

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

lbfgs/include/lbfgs.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,14 @@ class LBFGSSolver{
251251
}
252252

253253
virtual void TaskPred(void) {
254-
num_fea = std::max(num_fea,dtrain->NumCol());
255-
Eigen::VectorXf weight = Eigen::VectorXf::Zero(num_fea);
256-
//float *weight = new float[num_fea];
257-
258254
std::ifstream is(model_in.c_str());
259255
std::ofstream os(pred_out.c_str());
260256

261257
CHECK(is.fail() == false) << "open model file error!";
258+
is >> num_fea;
259+
num_fea = std::max(num_fea,dtrain->NumCol());
260+
linear.SetParam("num_fea", std::to_string(num_fea).c_str());
261+
Eigen::VectorXf weight = Eigen::VectorXf::Zero(num_fea);
262262

263263
for(size_t i = 0;i < num_fea;i++)
264264
is >> weight[i];
@@ -290,6 +290,7 @@ class LBFGSSolver{
290290
std::ofstream os(model_out.c_str());
291291

292292
CHECK(os.fail() == false) << "open model file fail";
293+
os << num_fea << std::endl;
293294
for(size_t i = 0;i < num_fea;i++) {
294295
os << linear.old_weight[i] << std::endl;
295296
}
@@ -300,6 +301,7 @@ class LBFGSSolver{
300301
std::ifstream is(model_in.c_str());
301302

302303
CHECK(is.fail() == false) << "open model file fail";
304+
is >> num_fea;
303305
for(size_t i = 0;i < num_fea;++i)
304306
{
305307
is >> linear.old_weight[i];

lbfgs/include/linear.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,8 @@ struct LinearModel {
6060
const dmlc::Row<unsigned> &v) const {
6161
float sum = 0.0;
6262
for(unsigned i = 0;i < v.length;i++) {
63-
if(v.index[i] < num_fea) {
64-
sum += w[v.index[i]];
65-
}
63+
assert(num_fea > v.index[i]);
64+
sum += w[v.index[i]];
6665
}
6766
return sum;
6867
}

0 commit comments

Comments
 (0)