ext/libsvm/svm.cpp in rb-libsvm-1.0.8 vs ext/libsvm/svm.cpp in rb-libsvm-1.0.9

- old
+ new

@@ -735,11 +735,11 @@ // reconstruct the whole gradient to calculate objective value reconstruct_gradient(); active_size = l; info("*"); } - info("\nWARNING: reaching max number of iterations"); + fprintf(stderr,"\nWARNING: reaching max number of iterations\n"); } // calculate rho si->rho = calculate_rho(); @@ -1004,11 +1004,11 @@ // // Solver for nu-svm classification and regression // // additional constraint: e^T \alpha = constant // -class Solver_NU : public Solver +class Solver_NU: public Solver { public: Solver_NU() {} void Solve(int l, const QMatrix& Q, const double *p, const schar *y, double *alpha, double Cp, double Cn, double eps, @@ -2105,16 +2105,18 @@ for(i=0;i<prob->l;i++) if(fabs(f.alpha[i]) > 0) ++nSV; model->l = nSV; model->SV = Malloc(svm_node *,nSV); model->sv_coef[0] = Malloc(double,nSV); + model->sv_indices = Malloc(int,nSV); int j = 0; for(i=0;i<prob->l;i++) if(fabs(f.alpha[i]) > 0) { model->SV[j] = prob->x[i]; model->sv_coef[0][j] = f.alpha[i]; + model->sv_indices[j] = i+1; ++j; } free(f.alpha); } @@ -2252,13 +2254,18 @@ info("Total nSV = %d\n",total_sv); model->l = total_sv; model->SV = Malloc(svm_node *,total_sv); + model->sv_indices = Malloc(int,total_sv); p = 0; for(i=0;i<l;i++) - if(nonzero[i]) model->SV[p++] = x[i]; + if(nonzero[i]) + { + model->SV[p] = x[i]; + model->sv_indices[p++] = perm[i] + 1; + } int *nz_start = Malloc(int,nr_class); nz_start[0] = 0; for(i=1;i<nr_class;i++) nz_start[i] = nz_start[i-1]+nz_count[i-1]; @@ -2440,9 +2447,21 @@ void svm_get_labels(const svm_model *model, int* label) { if (model->label != NULL) for(int i=0;i<model->nr_class;i++) label[i] = model->label[i]; +} + +void svm_get_sv_indices(const svm_model *model, int* indices) +{ + if (model->sv_indices != NULL) + for(int i=0;i<model->l;i++) + indices[i] = model->sv_indices[i]; +} + +int svm_get_nr_sv(const svm_model *model) +{ + return model->l; } double svm_get_svr_probability(const svm_model *model) { if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&