Sha256: c0be1fd2307e885580916e616e33ef0c716bbee625b57da031523a44a45f3d84

Contents?: true

Size: 1.41 KB

Versions: 7

Compression:

Stored size: 1.41 KB

Contents

#include <ruby.h>
#include <stdlib.h>
#include "numo/narray.h"

#define CIFAR10_WIDTH 32
#define CIFAR10_HEIGHT 32
#define CIFAR10_CHANNEL 3
#define CIFAR10_CLASSES 10

VALUE cifar10_load(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
  char* bin = StringValuePtr(rb_bin);
  int num_datas = FIX2INT(rb_num_datas);
  char script[64];
  VALUE rb_na_x;
  VALUE rb_na_y;
  VALUE rb_xy;
  narray_data_t* na_data_x;
  narray_data_t* na_data_y;
  int i;
  int j = 0;
  int k = 0;
  int size = CIFAR10_WIDTH * CIFAR10_HEIGHT * CIFAR10_CHANNEL;

  sprintf(script, "Numo::UInt8.zeros(%d, %d, %d, %d)", num_datas, CIFAR10_WIDTH, CIFAR10_HEIGHT, CIFAR10_CHANNEL);
  rb_na_x = rb_eval_string(&script[0]);
  na_data_x = RNARRAY_DATA(rb_na_x);
  for(i = 0; i < 64; i++) {
    script[i] = 0;
  }
  sprintf(script, "Numo::UInt8.zeros(%d, %d)", num_datas, CIFAR10_CLASSES);
  rb_na_y = rb_eval_string(&script[0]);
  na_data_y = RNARRAY_DATA(rb_na_y);

  for (i = 0; i < num_datas; i++) {
    na_data_y->ptr[i] = bin[j];
    j++;
    memcpy(&na_data_x->ptr[k], &bin[j], size);
    j += size;
    k += size;
  }

  rb_xy = rb_ary_new3(2, rb_na_x, rb_na_y);
  return rb_xy;
}

void Init_cifar10_ext() {
  VALUE rb_dnn;
  VALUE rb_cifar10;
  rb_dnn = rb_define_module("DNN");
  rb_cifar10 = rb_define_module_under(rb_dnn, "CIFAR10");
  rb_define_singleton_method(rb_cifar10, "_cifar10_load", cifar10_load, 2);
}

Version data entries

7 entries across 7 versions & 1 rubygems

Version Path
ruby-dnn-0.1.6 lib/dnn/ext/cifar10/cifar10_ext.c
ruby-dnn-0.1.5 lib/dnn/ext/cifar10/cifar10_ext.c
ruby-dnn-0.1.4 lib/dnn/ext/cifar10/cifar10_ext.c
ruby-dnn-0.1.3 lib/dnn/ext/cifar10/cifar10_ext.c
ruby-dnn-0.1.2 lib/dnn/ext/cifar10/cifar10_ext.c
ruby-dnn-0.1.1 lib/dnn/ext/cifar10/cifar10_ext.c
ruby-dnn-0.1.0 lib/dnn/ext/cifar10/cifar10_ext.c