lib/torch/nn/module.rb in torch-rb-0.1.2 vs lib/torch/nn/module.rb in torch-rb-0.1.3
- old
+ new
@@ -1,18 +1,48 @@
module Torch
module NN
class Module
+ def initialize
+ @training = true
+ end
+
def inspect
str = String.new
str << "#{self.class.name}(\n"
modules.each do |name, mod|
str << " (#{name}): #{mod.inspect}\n"
end
str << ")"
end
+ def train(mode = true)
+ @training = mode
+
+ modules.each do |_, mod|
+ mod.train(mode)
+ end
+ end
+
+ def eval
+ train(false)
+ end
+
def call(*input)
forward(*input)
+ end
+
+ # modifies in-place
+ def to(device)
+ instance_variables.each do |name|
+ param = instance_variable_get(name)
+ if param.is_a?(Parameter)
+ instance_variable_set(name, Parameter.new(param.to(device)))
+ end
+ end
+ modules.each do |_, mod|
+ mod.to(device)
+ end
+ self
end
def parameters
params = []
instance_variables.each do |name|