examples/cifar10_example.rb in ruby-dnn-0.3.2 vs examples/cifar10_example.rb in ruby-dnn-0.4.0
- old
+ new
@@ -10,12 +10,12 @@
CIFAR10 = DNN::CIFAR10
x_train, y_train = CIFAR10.load_train
x_test, y_test = CIFAR10.load_test
-x_train = SFloat.cast(x_train).transpose(0, 2, 3, 1)
-x_test = SFloat.cast(x_test).transpose(0, 2, 3, 1)
+x_train = SFloat.cast(x_train)
+x_test = SFloat.cast(x_test)
x_train /= 255
x_test /= 255
y_train = DNN::Util.to_categorical(y_train, 10)
@@ -38,9 +38,19 @@
model << Conv2D.new(32, 5, padding: true)
model << BatchNormalization.new
model << ReLU.new
model << Conv2D.new(32, 5, padding: true)
+model << BatchNormalization.new
+model << ReLU.new
+
+model << MaxPool2D.new(2)
+
+model << Conv2D.new(64, 5, padding: true)
+model << BatchNormalization.new
+model << ReLU.new
+
+model << Conv2D.new(64, 5, padding: true)
model << BatchNormalization.new
model << ReLU.new
model << Flatten.new