lib/dnn/core/model.rb in ruby-dnn-0.10.2 vs lib/dnn/core/model.rb in ruby-dnn-0.10.3
- old
+ new
@@ -118,11 +118,10 @@
@compiled = true
layers_check
@optimizer = optimizer
@loss_func = loss_func
build
- layers_shape_check
end
# Set optimizer and loss_func to model and recompile. But does not build layers.
# @param [DNN::Optimizers::Optimizer] optimizer Optimizer to use for learning.
# @param [DNN::Losses::Loss] loss_func Loss function to use for learning.
@@ -135,11 +134,10 @@
end
@compiled = true
layers_check
@optimizer = optimizer
@loss_func = loss_func
- layers_shape_check
end
def build(super_model = nil)
@super_model = super_model
shape = if super_model
@@ -423,29 +421,9 @@
unless @layers.first.input_shape == x.shape[1..-1]
raise DNN_ShapeError.new("The shape of x does not match the input shape. x shape is #{x.shape[1..-1]}, but input shape is #{@layers.first.input_shape}.")
end
if y && @layers.last.output_shape != y.shape[1..-1]
raise DNN_ShapeError.new("The shape of y does not match the input shape. y shape is #{y.shape[1..-1]}, but output shape is #{@layers.last.output_shape}.")
- end
- end
-
- def layers_shape_check
- @layers.each.with_index do |layer, i|
- prev_shape = layer.input_shape
- if layer.is_a?(Layers::Dense)
- if prev_shape.length != 1
- raise DNN_ShapeError.new("layer index(#{i}) Dense: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 1 dimensional.")
- end
- elsif layer.is_a?(Layers::Conv2D) || layer.is_a?(Layers::MaxPool2D)
- if prev_shape.length != 3
- raise DNN_ShapeError.new("layer index(#{i}) Conv2D: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
- end
- elsif layer.is_a?(Layers::RNN)
- if prev_shape.length != 2
- layer_name = layer.class.name.match("\:\:(.+)$")[1]
- raise DNN_ShapeError.new("layer index(#{i}) #{layer_name}: The shape of the previous layer is #{prev_shape}. The shape of the previous layer must be 3 dimensional.")
- end
- end
end
end
def check_xy_type(x, y = nil)
unless x.is_a?(Xumo::SFloat)