lib/torch/nn/module.rb in torch-rb-0.1.2 vs lib/torch/nn/module.rb in torch-rb-0.1.3

- old
+ new

@@ -1,18 +1,48 @@ module Torch module NN class Module + def initialize + @training = true + end + def inspect str = String.new str << "#{self.class.name}(\n" modules.each do |name, mod| str << " (#{name}): #{mod.inspect}\n" end str << ")" end + def train(mode = true) + @training = mode + + modules.each do |_, mod| + mod.train(mode) + end + end + + def eval + train(false) + end + def call(*input) forward(*input) + end + + # modifies in-place + def to(device) + instance_variables.each do |name| + param = instance_variable_get(name) + if param.is_a?(Parameter) + instance_variable_set(name, Parameter.new(param.to(device))) + end + end + modules.each do |_, mod| + mod.to(device) + end + self end def parameters params = [] instance_variables.each do |name|