ext/numo/narray/narray.c in numo-narray-0.9.1.5 vs ext/numo/narray/narray.c in numo-narray-0.9.1.6
- old
+ new
@@ -32,10 +32,11 @@
static ID id_eq;
static ID id_count_false;
static ID id_axis;
static ID id_nan;
static ID id_keepdims;
+static ID id_source;
VALUE cPointer;
VALUE sym_reduce;
VALUE sym_option;
@@ -603,10 +604,91 @@
#define READ 1
#define WRITE 2
+static void
+na_set_pointer(VALUE self, char *ptr, size_t byte_size)
+{
+ VALUE obj;
+ narray_t *na;
+
+ if (OBJ_FROZEN(self)) {
+ rb_raise(rb_eRuntimeError, "cannot write to frozen NArray.");
+ }
+
+ GetNArray(self,na);
+
+ switch(NA_TYPE(na)) {
+ case NARRAY_DATA_T:
+ if (NA_SIZE(na) > 0) {
+ if (NA_DATA_PTR(na) != NULL && NA_DATA_OWNED(na)) {
+ xfree(NA_DATA_PTR(na));
+ }
+ NA_DATA_PTR(na) = ptr;
+ NA_DATA_OWNED(na) = FALSE;
+ }
+ return;
+ case NARRAY_VIEW_T:
+ obj = NA_VIEW_DATA(na);
+ if (OBJ_FROZEN(obj)) {
+ rb_raise(rb_eRuntimeError, "cannot write to frozen NArray.");
+ }
+ GetNArray(obj,na);
+ switch(NA_TYPE(na)) {
+ case NARRAY_DATA_T:
+ if (NA_SIZE(na) > 0) {
+ if (NA_DATA_PTR(na) != NULL && NA_DATA_OWNED(na)) {
+ xfree(NA_DATA_PTR(na));
+ }
+ NA_DATA_PTR(na) = ptr;
+ NA_DATA_OWNED(na) = FALSE;
+ }
+ return;
+ default:
+ rb_raise(rb_eRuntimeError,"invalid NA_TYPE of view: %d",NA_TYPE(na));
+ }
+ default:
+ rb_raise(rb_eRuntimeError,"invalid NA_TYPE: %d",NA_TYPE(na));
+ }
+}
+
+static void
+na_pointer_copy_on_write(VALUE self)
+{
+ narray_t *na;
+ void *ptr;
+ VALUE velmsz;
+ size_t byte_size;
+
+ GetNArray(self,na);
+ if (NA_TYPE(na) == NARRAY_VIEW_T) {
+ self = NA_VIEW_DATA(na);
+ GetNArray(self,na);
+ }
+
+ ptr = NA_DATA_PTR(na);
+ if (ptr == NULL) {
+ return;
+ }
+
+ if (NA_DATA_OWNED(na)) {
+ return;
+ }
+
+ velmsz = rb_const_get(rb_obj_class(self), id_element_byte_size);
+ if (FIXNUM_P(velmsz)) {
+ byte_size = NA_SIZE(na) * NUM2SIZET(velmsz);
+ } else {
+ byte_size = ceil(NA_SIZE(na) * NUM2DBL(velmsz));
+ }
+ NA_DATA_PTR(na) = NULL;
+ rb_funcall(self, id_allocate, 0);
+ memcpy(NA_DATA_PTR(na), ptr, byte_size);
+ rb_ivar_set(self, id_source, Qnil);
+}
+
static char *
na_get_pointer_for_rw(VALUE self, int flag)
{
char *ptr;
VALUE obj;
@@ -618,10 +700,13 @@
GetNArray(self,na);
switch(NA_TYPE(na)) {
case NARRAY_DATA_T:
+ if (flag & WRITE) {
+ na_pointer_copy_on_write(self);
+ }
ptr = NA_DATA_PTR(na);
if (NA_SIZE(na) > 0 && ptr == NULL) {
if (flag & READ) {
rb_raise(rb_eRuntimeError,"cannot read unallocated NArray");
}
@@ -634,10 +719,13 @@
case NARRAY_VIEW_T:
obj = NA_VIEW_DATA(na);
if ((flag & WRITE) && OBJ_FROZEN(obj)) {
rb_raise(rb_eRuntimeError, "cannot write to frozen NArray.");
}
+ if (flag & WRITE) {
+ na_pointer_copy_on_write(self);
+ }
GetNArray(obj,na);
switch(NA_TYPE(na)) {
case NARRAY_DATA_T:
ptr = NA_DATA_PTR(na);
if (flag & (READ|WRITE)) {
@@ -1258,11 +1346,10 @@
static VALUE
nary_s_from_binary(int argc, VALUE *argv, VALUE type)
{
size_t len, str_len, byte_size;
size_t *shape;
- char *ptr;
int i, nd, narg;
VALUE vstr, vshape, vna;
VALUE velmsz;
narg = rb_scan_args(argc,argv,"11",&vstr,&vshape);
@@ -1313,14 +1400,18 @@
shape = ALLOCA_N(size_t,nd);
shape[0] = len;
}
vna = nary_new(type, nd, shape);
- ptr = na_get_pointer_for_write(vna);
+ if (OBJ_FROZEN(vstr)) {
+ na_set_pointer(vna, RSTRING_PTR(vstr), byte_size);
+ rb_ivar_set(vna, id_source, vstr);
+ } else {
+ void *ptr = na_get_pointer_for_write(vna);
+ memcpy(ptr, RSTRING_PTR(vstr), byte_size);
+ }
- memcpy(ptr, RSTRING_PTR(vstr), byte_size);
-
return vna;
}
/*
Returns a new 1-D array initialized from binary raw data in a string.
@@ -1331,11 +1422,10 @@
*/
static VALUE
nary_store_binary(int argc, VALUE *argv, VALUE self)
{
size_t size, str_len, byte_size, offset;
- char *ptr;
int narg;
VALUE vstr, voffset;
VALUE velmsz;
narray_t *na;
@@ -1361,12 +1451,17 @@
}
if (byte_size > str_len) {
rb_raise(rb_eArgError, "string is too short to store");
}
- ptr = na_get_pointer_for_write(self);
- memcpy(ptr, RSTRING_PTR(vstr)+offset, byte_size);
+ if (OBJ_FROZEN(vstr)) {
+ na_set_pointer(self, RSTRING_PTR(vstr)+offset, byte_size);
+ rb_ivar_set(self, id_source, vstr);
+ } else {
+ void *ptr = na_get_pointer_for_write(self);
+ memcpy(ptr, RSTRING_PTR(vstr)+offset, byte_size);
+ }
return SIZET2NUM(byte_size);
}
/*
@@ -1468,10 +1563,11 @@
rb_raise(rb_eArgError,"RObject content size mismatch");
}
ptr = na_get_pointer_for_write(self);
memcpy(ptr, RARRAY_PTR(v), NA_SIZE(na)*sizeof(VALUE));
} else {
+ rb_str_freeze(v);
nary_store_binary(1,&v,self);
if (TEST_BYTE_SWAPPED(self)) {
rb_funcall(na_inplace(self),id_to_host,0);
REVERSE_ENDIAN(self); // correct behavior??
}
@@ -2007,9 +2103,10 @@
id_eq = rb_intern("eq");
id_count_false = rb_intern("count_false");
id_axis = rb_intern("axis");
id_nan = rb_intern("nan");
id_keepdims = rb_intern("keepdims");
+ id_source = rb_intern("source");
sym_reduce = ID2SYM(rb_intern("reduce"));
sym_option = ID2SYM(rb_intern("option"));
sym_loop_opt = ID2SYM(rb_intern("loop_opt"));
sym_init = ID2SYM(rb_intern("init"));