lib/dnn/core/model.rb in ruby-dnn-0.6.3 vs lib/dnn/core/model.rb in ruby-dnn-0.6.4

- old
+ new

@@ -1,12 +1,13 @@ require "json" module DNN # This class deals with the model of the network. class Model - attr_accessor :layers - attr_reader :optimizer + attr_accessor :layers # All layers possessed by the model + attr_accessor :trainable # Setting false prevents learning of parameters. + attr_reader :optimizer # Optimizer possessed by the model def self.load(file_name) Marshal.load(File.binread(file_name)) end @@ -18,10 +19,11 @@ model end def initialize @layers = [] + @trainable = true @optimizer = nil @training = false @compiled = false end @@ -138,10 +140,10 @@ def train_on_batch(x, y, &batch_proc) x, y = batch_proc.call(x, y) if batch_proc forward(x, true) loss = @layers[-1].loss(y) backward(y) - @layers.each { |layer| layer.update if layer.respond_to?(:update) } + @layers.each { |layer| layer.update if @trainable && layer.is_a?(HasParamLayer) } loss end def accurate(x, y, batch_size = 1, &batch_proc) batch_size = batch_size >= x.shape[0] ? batch_size : x.shape[0]