lib/dnn/keras-model-convertor.rb in ruby-dnn-1.1.0 vs lib/dnn/keras-model-convertor.rb in ruby-dnn-1.1.1

- old
+ new

@@ -13,10 +13,20 @@ pyimport :keras pyfrom :"keras.models", import: :Sequential pyfrom :"keras.layers", import: [:Dense, :Dropout, :Conv2D, :Activation, :MaxPooling2D, :Flatten] pyfrom :"keras.layers.normalization", import: :BatchNormalization +module DNN + module Layers + class Softmax < Layer + def forward(x) + Exp.(x) / Sum.(Exp.(x), axis: 1) + end + end + end +end + class DNNKerasModelConvertError < DNN::DNNError; end class KerasModelConvertor pyfrom :"keras.models", import: :load_model @@ -32,22 +42,25 @@ def convert unless @k_model.__class__.__name__ == "Sequential" raise DNNKerasModelConvertError.new("#{@k_model.__class__.__name__} models do not support convert.") end - dnn_model = DNN::Models::Sequential.new - @k_model.layers.each do |k_layer| - dnn_layer = layer_convert(k_layer) - dnn_model << dnn_layer if dnn_layer - end + layers = convert_layers(@k_model.layers) input_shape = @k_model.layers[0].input_shape.to_a[1..-1] input_layer = DNN::Layers::InputLayer.new(input_shape) input_layer.build(input_shape) - dnn_model.insert(0, input_layer) + layers.unshift(input_layer) + dnn_model = DNN::Models::Sequential.new(layers) dnn_model end + def convert_layers(k_layers) + k_layers.map do |k_layer| + layer_convert(k_layer) + end + end + private def layer_convert(k_layer) k_layer_name = k_layer.__class__.__name__ method_name = "convert_" + k_layer_name @@ -78,18 +91,18 @@ dense end def convert_Activation(k_activation) activation_name = k_activation.get_config[:activation].to_s - case k_activation.get_config[:activation].to_s + activation = case k_activation.get_config[:activation].to_s when "sigmoid" - activation = DNN::Layers::Sigmoid.new + DNN::Layers::Sigmoid.new when "tanh" - activation = DNN::Layers::Tanh.new + DNN::Layers::Tanh.new when "relu" - activation = DNN::Layers::ReLU.new + DNN::Layers::ReLU.new when "softmax" - return nil + DNN::Layers::Softmax.new else raise DNNKerasModelConvertError.new("#{activation_name} activation do not support convert.") end build_dnn_layer(k_activation, activation) activation