lib/dnn/core/model.rb in ruby-dnn-0.6.3 vs lib/dnn/core/model.rb in ruby-dnn-0.6.4
- old
+ new
@@ -1,12 +1,13 @@
require "json"
module DNN
# This class deals with the model of the network.
class Model
- attr_accessor :layers
- attr_reader :optimizer
+ attr_accessor :layers # All layers possessed by the model
+ attr_accessor :trainable # Setting false prevents learning of parameters.
+ attr_reader :optimizer # Optimizer possessed by the model
def self.load(file_name)
Marshal.load(File.binread(file_name))
end
@@ -18,10 +19,11 @@
model
end
def initialize
@layers = []
+ @trainable = true
@optimizer = nil
@training = false
@compiled = false
end
@@ -138,10 +140,10 @@
def train_on_batch(x, y, &batch_proc)
x, y = batch_proc.call(x, y) if batch_proc
forward(x, true)
loss = @layers[-1].loss(y)
backward(y)
- @layers.each { |layer| layer.update if layer.respond_to?(:update) }
+ @layers.each { |layer| layer.update if @trainable && layer.is_a?(HasParamLayer) }
loss
end
def accurate(x, y, batch_size = 1, &batch_proc)
batch_size = batch_size >= x.shape[0] ? batch_size : x.shape[0]