ext/numo/narray/narray.c in numo-narray-0.9.0.5 vs ext/numo/narray/narray.c in numo-narray-0.9.0.6

- old
+ new

@@ -30,10 +30,11 @@ static ID id_shift_left; static ID id_eq; static ID id_count_false; static ID id_axis; static ID id_nan; +static ID id_keepdims; VALUE cPointer; VALUE sym_reduce; VALUE sym_option; @@ -1040,11 +1041,11 @@ narray_t *na; narray_view_t *na1, *na2; VALUE view; VALUE reduce; - reduce = na_reduce_dimension(argc, argv, 1, &self, 0); + reduce = na_reduce_dimension(argc, argv, 1, &self, 0, 0); GetNArray(self,na); nd = na->ndim; view = na_s_allocate_view(CLASS_OF(self)); @@ -1446,11 +1447,12 @@ } } VALUE -na_reduce_dimension(int argc, VALUE *argv, int naryc, VALUE *naryv, int *propagate_nan) +na_reduce_dimension(int argc, VALUE *argv, int naryc, VALUE *naryv, + ndfunc_t *ndf, na_iter_func_t iter_nan) { int ndim, ndim0; int row_major; int i, r; long narg; @@ -1460,16 +1462,16 @@ VALUE v; narray_t *na; size_t m; VALUE reduce; VALUE kw_hash = Qnil; - ID kw_table[2] = {id_axis,id_nan}; - VALUE opts[2] = {Qundef,Qundef}; + ID kw_table[3] = {id_axis,id_nan,id_keepdims}; + VALUE opts[3] = {Qundef,Qundef,Qundef}; VALUE axes; narg = rb_scan_args(argc, argv, "*:", &axes, &kw_hash); - rb_get_kwargs(kw_hash, kw_table, 0, 2, opts); + rb_get_kwargs(kw_hash, kw_table, 0, 3, opts); // option: axis if (opts[0] != Qundef && RTEST(opts[0])) { if (narg > 0) { rb_raise(rb_eArgError,"both axis-arguments and axis-keyword are given"); @@ -1478,13 +1480,21 @@ axes = opts[0]; } else { axes = rb_ary_new3(1,opts[0]); } } - // option: ignore_none - if (propagate_nan) { - *propagate_nan = (opts[1] != Qundef && RTEST(opts[1])) ? 1 : 0; + if (ndf) { + // option: nan + if (iter_nan && opts[1] != Qundef) { + if (RTEST(opts[1])) + ndf->func = iter_nan; // replace to nan-aware iterator function + } + // option: keepdims + if (opts[2] != Qundef) { + if (RTEST(opts[2])) + ndf->flag |= NDF_KEEP_DIM; + } } if (naryc<1) { rb_raise(rb_eRuntimeError,"must be positive: naryc=%d", naryc); } @@ -1888,9 +1898,10 @@ id_shift_left = rb_intern("<<"); id_eq = rb_intern("eq"); id_count_false = rb_intern("count_false"); id_axis = rb_intern("axis"); id_nan = rb_intern("nan"); + id_keepdims = rb_intern("keepdims"); sym_reduce = ID2SYM(rb_intern("reduce")); sym_option = ID2SYM(rb_intern("option")); sym_loop_opt = ID2SYM(rb_intern("loop_opt")); sym_init = ID2SYM(rb_intern("init"));