ext/lbfgsb/lbfgsbext.c in lbfgsb-0.1.0 vs ext/lbfgsb/lbfgsbext.c in lbfgsb-0.2.0
- old
+ new
@@ -35,10 +35,11 @@
long lsave[4];
long isave[44];
double dsave[29];
double* g_ptr;
VALUE g_val;
+ VALUE fg_arr;
VALUE ret;
GetNArray(x_val, x_nary);
if (NA_NDIM(x_nary) != 1) {
rb_raise(rb_eArgError, "x must be a 1-D array.");
@@ -118,11 +119,17 @@
setulb_(
&n, &m, x_ptr, l_ptr, u_ptr, nbd_ptr, &f, g, &factr, &pgtol, wa, iwa,
task, &iprint, csave, lsave, isave, dsave
);
if (strncmp(task, "FG", 2) == 0) {
- f = NUM2DBL(rb_funcall(self, rb_intern("fnc"), 3, fnc, x_val, args));
- g_val = rb_funcall(self, rb_intern("jcb"), 3, jcb, x_val, args);
+ if (RB_TYPE_P(jcb, T_TRUE)) {
+ fg_arr = rb_funcall(self, rb_intern("fnc"), 3, fnc, x_val, args);
+ f = NUM2DBL(rb_ary_entry(fg_arr, 0));
+ g_val = rb_ary_entry(fg_arr, 1);
+ } else {
+ f = NUM2DBL(rb_funcall(self, rb_intern("fnc"), 3, fnc, x_val, args));
+ g_val = rb_funcall(self, rb_intern("jcb"), 3, jcb, x_val, args);
+ }
n_fev += 1;
n_jev += 1;
if (CLASS_OF(g_val) != numo_cDFloat) g_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, g_val);
if (!RTEST(nary_check_contiguous(g_val))) g_val = nary_dup(g_val);
g_ptr = (double*)na_get_pointer_for_read(g_val);