lib/dnn/core/model.rb in ruby-dnn-0.10.2 vs lib/dnn/core/model.rb in ruby-dnn-0.10.3

- old
+ new

@@ -118,11 +118,10 @@ @compiled = true layers_check @optimizer = optimizer @loss_func = loss_func build - layers_shape_check end # Set optimizer and loss_func to model and recompile. But does not build layers. # @param [DNN::Optimizers::Optimizer] optimizer Optimizer to use for learning. # @param [DNN::Losses::Loss] loss_func Loss function to use for learning. @@ -135,11 +134,10 @@ end @compiled = true layers_check @optimizer = optimizer @loss_func = loss_func - layers_shape_check end def build(super_model = nil) @super_model = super_model shape = if super_model @@ -423,29 +421,9 @@ unless @layers.first.input_shape == x.shape[1..-1] raise DNN_ShapeError.new("The shape of x does not match the input shape. x shape is #{x.shape[1..-1]}, but input shape is #{@layers.first.input_shape}.") end if y && @layers.last.output_shape != y.shape[1..-1] raise DNN_ShapeError.new("The shape of y does not match the input shape. y shape is #{y.shape[1..-1]}, but output shape is #{@layers.last.output_shape}.") - end - end - - def layers_shape_check - @layers.each.with_index do |layer, i| - prev_shape = layer.input_shape - if layer.is_a?(Layers::Dense) - if prev_shape.length != 1 - raise DNN_ShapeError.new("layer index(#{i}) Dense: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 1 dimensional.") - end - elsif layer.is_a?(Layers::Conv2D) || layer.is_a?(Layers::MaxPool2D) - if prev_shape.length != 3 - raise DNN_ShapeError.new("layer index(#{i}) Conv2D: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.") - end - elsif layer.is_a?(Layers::RNN) - if prev_shape.length != 2 - layer_name = layer.class.name.match("\:\:(.+)$")[1] - raise DNN_ShapeError.new("layer index(#{i}) #{layer_name}: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.") - end - end end end def check_xy_type(x, y = nil) unless x.is_a?(Xumo::SFloat)