examples/cifar10_example.rb in ruby-dnn-0.3.2 vs examples/cifar10_example.rb in ruby-dnn-0.4.0

- old
+ new

@@ -10,12 +10,12 @@ CIFAR10 = DNN::CIFAR10 x_train, y_train = CIFAR10.load_train x_test, y_test = CIFAR10.load_test -x_train = SFloat.cast(x_train).transpose(0, 2, 3, 1) -x_test = SFloat.cast(x_test).transpose(0, 2, 3, 1) +x_train = SFloat.cast(x_train) +x_test = SFloat.cast(x_test) x_train /= 255 x_test /= 255 y_train = DNN::Util.to_categorical(y_train, 10) @@ -38,9 +38,19 @@ model << Conv2D.new(32, 5, padding: true) model << BatchNormalization.new model << ReLU.new model << Conv2D.new(32, 5, padding: true) +model << BatchNormalization.new +model << ReLU.new + +model << MaxPool2D.new(2) + +model << Conv2D.new(64, 5, padding: true) +model << BatchNormalization.new +model << ReLU.new + +model << Conv2D.new(64, 5, padding: true) model << BatchNormalization.new model << ReLU.new model << Flatten.new