ext/numo/narray/narray.c in numo-narray-0.9.0.5 vs ext/numo/narray/narray.c in numo-narray-0.9.0.6
- old
+ new
@@ -30,10 +30,11 @@
static ID id_shift_left;
static ID id_eq;
static ID id_count_false;
static ID id_axis;
static ID id_nan;
+static ID id_keepdims;
VALUE cPointer;
VALUE sym_reduce;
VALUE sym_option;
@@ -1040,11 +1041,11 @@
narray_t *na;
narray_view_t *na1, *na2;
VALUE view;
VALUE reduce;
- reduce = na_reduce_dimension(argc, argv, 1, &self, 0);
+ reduce = na_reduce_dimension(argc, argv, 1, &self, 0, 0);
GetNArray(self,na);
nd = na->ndim;
view = na_s_allocate_view(CLASS_OF(self));
@@ -1446,11 +1447,12 @@
}
}
VALUE
-na_reduce_dimension(int argc, VALUE *argv, int naryc, VALUE *naryv, int *propagate_nan)
+na_reduce_dimension(int argc, VALUE *argv, int naryc, VALUE *naryv,
+ ndfunc_t *ndf, na_iter_func_t iter_nan)
{
int ndim, ndim0;
int row_major;
int i, r;
long narg;
@@ -1460,16 +1462,16 @@
VALUE v;
narray_t *na;
size_t m;
VALUE reduce;
VALUE kw_hash = Qnil;
- ID kw_table[2] = {id_axis,id_nan};
- VALUE opts[2] = {Qundef,Qundef};
+ ID kw_table[3] = {id_axis,id_nan,id_keepdims};
+ VALUE opts[3] = {Qundef,Qundef,Qundef};
VALUE axes;
narg = rb_scan_args(argc, argv, "*:", &axes, &kw_hash);
- rb_get_kwargs(kw_hash, kw_table, 0, 2, opts);
+ rb_get_kwargs(kw_hash, kw_table, 0, 3, opts);
// option: axis
if (opts[0] != Qundef && RTEST(opts[0])) {
if (narg > 0) {
rb_raise(rb_eArgError,"both axis-arguments and axis-keyword are given");
@@ -1478,13 +1480,21 @@
axes = opts[0];
} else {
axes = rb_ary_new3(1,opts[0]);
}
}
- // option: ignore_none
- if (propagate_nan) {
- *propagate_nan = (opts[1] != Qundef && RTEST(opts[1])) ? 1 : 0;
+ if (ndf) {
+ // option: nan
+ if (iter_nan && opts[1] != Qundef) {
+ if (RTEST(opts[1]))
+ ndf->func = iter_nan; // replace to nan-aware iterator function
+ }
+ // option: keepdims
+ if (opts[2] != Qundef) {
+ if (RTEST(opts[2]))
+ ndf->flag |= NDF_KEEP_DIM;
+ }
}
if (naryc<1) {
rb_raise(rb_eRuntimeError,"must be positive: naryc=%d", naryc);
}
@@ -1888,9 +1898,10 @@
id_shift_left = rb_intern("<<");
id_eq = rb_intern("eq");
id_count_false = rb_intern("count_false");
id_axis = rb_intern("axis");
id_nan = rb_intern("nan");
+ id_keepdims = rb_intern("keepdims");
sym_reduce = ID2SYM(rb_intern("reduce"));
sym_option = ID2SYM(rb_intern("option"));
sym_loop_opt = ID2SYM(rb_intern("loop_opt"));
sym_init = ID2SYM(rb_intern("init"));