lib/torch/nn/module.rb in torch-rb-0.1.5 vs lib/torch/nn/module.rb in torch-rb-0.1.6
- old
+ new
@@ -1,8 +1,10 @@
module Torch
module NN
class Module
+ include Utils
+
def initialize
@training = true
@parameters = {}
@buffers = {}
@modules = {}
@@ -13,10 +15,11 @@
end
def register_buffer(name, tensor)
# TODO add checks
@buffers[name] = tensor
+ instance_variable_set("@#{name}", tensor)
end
def register_parameter(name, param)
# TODO add checks
@parameters[name] = param
@@ -213,9 +216,13 @@
vars.map(&:inspect)
else
options.map { |k, v| [k, v.inspect] }.to_h
end
str % vars
+ end
+
+ def dict
+ instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
end
end
end
end