ext/numo/narray/narray.c in numo-narray-0.9.1.5 vs ext/numo/narray/narray.c in numo-narray-0.9.1.6

- old
+ new

@@ -32,10 +32,11 @@ static ID id_eq; static ID id_count_false; static ID id_axis; static ID id_nan; static ID id_keepdims; +static ID id_source; VALUE cPointer; VALUE sym_reduce; VALUE sym_option; @@ -603,10 +604,91 @@ #define READ 1 #define WRITE 2 +static void +na_set_pointer(VALUE self, char *ptr, size_t byte_size) +{ + VALUE obj; + narray_t *na; + + if (OBJ_FROZEN(self)) { + rb_raise(rb_eRuntimeError, "cannot write to frozen NArray."); + } + + GetNArray(self,na); + + switch(NA_TYPE(na)) { + case NARRAY_DATA_T: + if (NA_SIZE(na) > 0) { + if (NA_DATA_PTR(na) != NULL && NA_DATA_OWNED(na)) { + xfree(NA_DATA_PTR(na)); + } + NA_DATA_PTR(na) = ptr; + NA_DATA_OWNED(na) = FALSE; + } + return; + case NARRAY_VIEW_T: + obj = NA_VIEW_DATA(na); + if (OBJ_FROZEN(obj)) { + rb_raise(rb_eRuntimeError, "cannot write to frozen NArray."); + } + GetNArray(obj,na); + switch(NA_TYPE(na)) { + case NARRAY_DATA_T: + if (NA_SIZE(na) > 0) { + if (NA_DATA_PTR(na) != NULL && NA_DATA_OWNED(na)) { + xfree(NA_DATA_PTR(na)); + } + NA_DATA_PTR(na) = ptr; + NA_DATA_OWNED(na) = FALSE; + } + return; + default: + rb_raise(rb_eRuntimeError,"invalid NA_TYPE of view: %d",NA_TYPE(na)); + } + default: + rb_raise(rb_eRuntimeError,"invalid NA_TYPE: %d",NA_TYPE(na)); + } +} + +static void +na_pointer_copy_on_write(VALUE self) +{ + narray_t *na; + void *ptr; + VALUE velmsz; + size_t byte_size; + + GetNArray(self,na); + if (NA_TYPE(na) == NARRAY_VIEW_T) { + self = NA_VIEW_DATA(na); + GetNArray(self,na); + } + + ptr = NA_DATA_PTR(na); + if (ptr == NULL) { + return; + } + + if (NA_DATA_OWNED(na)) { + return; + } + + velmsz = rb_const_get(rb_obj_class(self), id_element_byte_size); + if (FIXNUM_P(velmsz)) { + byte_size = NA_SIZE(na) * NUM2SIZET(velmsz); + } else { + byte_size = ceil(NA_SIZE(na) * NUM2DBL(velmsz)); + } + NA_DATA_PTR(na) = NULL; + rb_funcall(self, id_allocate, 0); + memcpy(NA_DATA_PTR(na), ptr, byte_size); + rb_ivar_set(self, id_source, Qnil); +} + static char * na_get_pointer_for_rw(VALUE self, int flag) { char *ptr; VALUE obj; @@ -618,10 +700,13 @@ GetNArray(self,na); switch(NA_TYPE(na)) { case NARRAY_DATA_T: + if (flag & WRITE) { + na_pointer_copy_on_write(self); + } ptr = NA_DATA_PTR(na); if (NA_SIZE(na) > 0 && ptr == NULL) { if (flag & READ) { rb_raise(rb_eRuntimeError,"cannot read unallocated NArray"); } @@ -634,10 +719,13 @@ case NARRAY_VIEW_T: obj = NA_VIEW_DATA(na); if ((flag & WRITE) && OBJ_FROZEN(obj)) { rb_raise(rb_eRuntimeError, "cannot write to frozen NArray."); } + if (flag & WRITE) { + na_pointer_copy_on_write(self); + } GetNArray(obj,na); switch(NA_TYPE(na)) { case NARRAY_DATA_T: ptr = NA_DATA_PTR(na); if (flag & (READ|WRITE)) { @@ -1258,11 +1346,10 @@ static VALUE nary_s_from_binary(int argc, VALUE *argv, VALUE type) { size_t len, str_len, byte_size; size_t *shape; - char *ptr; int i, nd, narg; VALUE vstr, vshape, vna; VALUE velmsz; narg = rb_scan_args(argc,argv,"11",&vstr,&vshape); @@ -1313,14 +1400,18 @@ shape = ALLOCA_N(size_t,nd); shape[0] = len; } vna = nary_new(type, nd, shape); - ptr = na_get_pointer_for_write(vna); + if (OBJ_FROZEN(vstr)) { + na_set_pointer(vna, RSTRING_PTR(vstr), byte_size); + rb_ivar_set(vna, id_source, vstr); + } else { + void *ptr = na_get_pointer_for_write(vna); + memcpy(ptr, RSTRING_PTR(vstr), byte_size); + } - memcpy(ptr, RSTRING_PTR(vstr), byte_size); - return vna; } /* Returns a new 1-D array initialized from binary raw data in a string. @@ -1331,11 +1422,10 @@ */ static VALUE nary_store_binary(int argc, VALUE *argv, VALUE self) { size_t size, str_len, byte_size, offset; - char *ptr; int narg; VALUE vstr, voffset; VALUE velmsz; narray_t *na; @@ -1361,12 +1451,17 @@ } if (byte_size > str_len) { rb_raise(rb_eArgError, "string is too short to store"); } - ptr = na_get_pointer_for_write(self); - memcpy(ptr, RSTRING_PTR(vstr)+offset, byte_size); + if (OBJ_FROZEN(vstr)) { + na_set_pointer(self, RSTRING_PTR(vstr)+offset, byte_size); + rb_ivar_set(self, id_source, vstr); + } else { + void *ptr = na_get_pointer_for_write(self); + memcpy(ptr, RSTRING_PTR(vstr)+offset, byte_size); + } return SIZET2NUM(byte_size); } /* @@ -1468,10 +1563,11 @@ rb_raise(rb_eArgError,"RObject content size mismatch"); } ptr = na_get_pointer_for_write(self); memcpy(ptr, RARRAY_PTR(v), NA_SIZE(na)*sizeof(VALUE)); } else { + rb_str_freeze(v); nary_store_binary(1,&v,self); if (TEST_BYTE_SWAPPED(self)) { rb_funcall(na_inplace(self),id_to_host,0); REVERSE_ENDIAN(self); // correct behavior?? } @@ -2007,9 +2103,10 @@ id_eq = rb_intern("eq"); id_count_false = rb_intern("count_false"); id_axis = rb_intern("axis"); id_nan = rb_intern("nan"); id_keepdims = rb_intern("keepdims"); + id_source = rb_intern("source"); sym_reduce = ID2SYM(rb_intern("reduce")); sym_option = ID2SYM(rb_intern("option")); sym_loop_opt = ID2SYM(rb_intern("loop_opt")); sym_init = ID2SYM(rb_intern("init"));