lib/torch/nn/module.rb in torch-rb-0.1.3 vs lib/torch/nn/module.rb in torch-rb-0.1.4

- old
+ new

@@ -1,85 +1,174 @@ module Torch module NN class Module def initialize @training = true + @parameters = {} + @buffers = {} + @modules = {} end - def inspect - str = String.new - str << "#{self.class.name}(\n" - modules.each do |name, mod| - str << " (#{name}): #{mod.inspect}\n" - end - str << ")" + def forward + raise NotImplementedError end - def train(mode = true) - @training = mode + def register_buffer(name, tensor) + # TODO add checks + @buffers[name] = tensor + end - modules.each do |_, mod| - mod.train(mode) + def register_parameter(name, param) + # TODO add checks + @parameters[name] = param + end + + def add_module(name, mod) + # TODO add checks + @modules[name] = mod + end + + def _apply(fn) + children.each do |mod| + mod._apply(fn) end + # TODO apply to more objects + self end - def eval - train(false) + def apply(fn) + children.each do |mod| + mod.apply(fn) + end + fn.call(self) + self end - def call(*input) - forward(*input) + def cuda(device: nil) + _apply ->(t) { t.cuda(device) } end + def cpu + _apply ->(t) { t.cpu } + end + + def type(dst_type) + _apply ->(t) { t.type(dst_type) } + end + + def float + _apply ->(t) { t.floating_point? ? t.float : t } + end + + def double + _apply ->(t) { t.floating_point? ? t.double : t } + end + + def half + _apply ->(t) { t.floating_point? ? t.half : t } + end + # modifies in-place def to(device) - instance_variables.each do |name| - param = instance_variable_get(name) - if param.is_a?(Parameter) - instance_variable_set(name, Parameter.new(param.to(device))) - end + convert = lambda do |t| + t.to(device) end - modules.each do |_, mod| - mod.to(device) - end - self + + _apply(convert) end + def call(*input) + forward(*input) + end + + def state_dict + raise NotImplementedYet + end + def parameters params = [] instance_variables.each do |name| param = instance_variable_get(name) params << param if param.is_a?(Parameter) end params + modules.flat_map { |_, mod| mod.parameters } end + def children + @modules.values + end + + def modules + modules = {} + instance_variables.each do |name| + mod = instance_variable_get(name) + modules[name[1..-1]] = mod if mod.is_a?(Module) + end + @modules.merge(modules) + end + + def train(mode = true) + @training = mode + children.each do |mod| + mod.train(mode) + end + self + end + + def eval + train(false) + end + + def requires_grad!(requires_grad: true) + parameters.each do |p| + p.requires_grad!(requires_grad) + end + self + end + def zero_grad parameters.each do |param| if param.grad param.grad.detach! param.grad.zero! end end end + def share_memory + _apply ->(t) { t.share_memory! } + end + + def inspect + name = self.class.name.split("::").last + if modules.empty? + "#{name}(#{extra_inspect})" + else + str = String.new + str << "#{name}(\n" + modules.each do |name, mod| + str << " (#{name}): #{mod.inspect}\n" + end + str << ")" + end + end + def method_missing(method, *args, &block) modules[method.to_s] || super end def respond_to?(method, include_private = false) modules.key?(method.to_s) || super end private - def modules - modules = {} - instance_variables.each do |name| - mod = instance_variable_get(name) - modules[name[1..-1]] = mod if mod.is_a?(Module) - end - modules + def extra_inspect + nil + end + + def format(str, *vars) + str % vars.map(&:inspect) end end end end