lib/dnn/keras-model-convertor.rb in ruby-dnn-1.1.0 vs lib/dnn/keras-model-convertor.rb in ruby-dnn-1.1.1
- old
+ new
@@ -13,10 +13,20 @@
pyimport :keras
pyfrom :"keras.models", import: :Sequential
pyfrom :"keras.layers", import: [:Dense, :Dropout, :Conv2D, :Activation, :MaxPooling2D, :Flatten]
pyfrom :"keras.layers.normalization", import: :BatchNormalization
+module DNN
+  module Layers
+    class Softmax < Layer
+      def forward(x)
+        Exp.(x) / Sum.(Exp.(x), axis: 1)
+      end
+    end
+  end
+end
+
class DNNKerasModelConvertError < DNN::DNNError; end

class KerasModelConvertor
  pyfrom :"keras.models", import: :load_model
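For reference, the DNN::Layers::Softmax added in this hunk expresses a plain row-wise softmax through the gem's Exp and Sum nodes. Below is a minimal standalone sketch of the same computation using Numo (the NArray library ruby-dnn builds on); the softmax method and the sample values are illustrative only, not part of the gem.

require "numo/narray"

# Row-wise softmax, mirroring Exp.(x) / Sum.(Exp.(x), axis: 1) above:
# exponentiate, then normalize each row so it sums to 1.
def softmax(x)
  e = Numo::NMath.exp(x)
  e / e.sum(axis: 1).reshape(x.shape[0], 1)
end

scores = Numo::SFloat[[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]
p softmax(scores)                # probabilities per row
p softmax(scores).sum(axis: 1)   # each row sums to ~1.0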
@@ -32,22 +42,25 @@
  def convert
    unless @k_model.__class__.__name__ == "Sequential"
      raise DNNKerasModelConvertError.new("#{@k_model.__class__.__name__} models do not support convert.")
    end
-    dnn_model = DNN::Models::Sequential.new
-    @k_model.layers.each do |k_layer|
-      dnn_layer = layer_convert(k_layer)
-      dnn_model << dnn_layer if dnn_layer
-    end
+    layers = convert_layers(@k_model.layers)
    input_shape = @k_model.layers[0].input_shape.to_a[1..-1]
    input_layer = DNN::Layers::InputLayer.new(input_shape)
    input_layer.build(input_shape)
-    dnn_model.insert(0, input_layer)
+    layers.unshift(input_layer)
+    dnn_model = DNN::Models::Sequential.new(layers)
    dnn_model
  end

+  def convert_layers(k_layers)
+    k_layers.map do |k_layer|
+      layer_convert(k_layer)
+    end
+  end
+
  private

  def layer_convert(k_layer)
    k_layer_name = k_layer.__class__.__name__
    method_name = "convert_" + k_layer_name
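With this change, convert now builds the complete layer list first (convert_layers), prepends the input layer with unshift, and hands everything to DNN::Models::Sequential.new in one call, instead of appending to an empty model and inserting the input layer afterwards. A hedged usage sketch follows; the constructor argument, the file name, and the input shape are assumptions for illustration and are not shown in this diff.

require "dnn"
require "dnn/keras-model-convertor"

# Assumption: the convertor wraps a Keras Sequential model saved as HDF5.
convertor = KerasModelConvertor.new("mnist_cnn.h5")
model = convertor.convert                  # a DNN::Models::Sequential

x = Numo::SFloat.new(1, 28, 28, 1).rand    # one dummy MNIST-shaped sample
p model.predict(x)                         # run the converted ruby-dnn model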
@@ -78,18 +91,18 @@
    dense
  end

  def convert_Activation(k_activation)
    activation_name = k_activation.get_config[:activation].to_s
-    case k_activation.get_config[:activation].to_s
+    activation = case k_activation.get_config[:activation].to_s
    when "sigmoid"
-      activation = DNN::Layers::Sigmoid.new
+      DNN::Layers::Sigmoid.new
    when "tanh"
-      activation = DNN::Layers::Tanh.new
+      DNN::Layers::Tanh.new
    when "relu"
-      activation = DNN::Layers::ReLU.new
+      DNN::Layers::ReLU.new
    when "softmax"
-      return nil
+      DNN::Layers::Softmax.new
    else
      raise DNNKerasModelConvertError.new("#{activation_name} activation do not support convert.")
    end
    build_dnn_layer(k_activation, activation)
    activation
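Two things change in this last hunk: the case is used as an expression whose result is assigned to activation once, and a Keras "softmax" Activation is now converted to the new DNN::Layers::Softmax instead of being dropped via return nil. A small sketch of the case-as-expression idiom, with a hypothetical helper name that is not part of the gem:

# Hypothetical helper: every branch yields a layer object,
# so the assignment happens once rather than inside each when branch.
def activation_for(name)
  case name
  when "sigmoid" then DNN::Layers::Sigmoid.new
  when "tanh"    then DNN::Layers::Tanh.new
  when "relu"    then DNN::Layers::ReLU.new
  when "softmax" then DNN::Layers::Softmax.new
  else raise ArgumentError, "unsupported activation: #{name}"
  end
end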