Sha256: fd643720208bae0c51cc4c337c97bf1c91ec4774781073695c4e6e638a39e401

Contents?: true

Size: 1.23 KB

Versions: 2

Compression:

Stored size: 1.23 KB

Contents

module Torch
  module NN
    class Module
      def inspect
        str = String.new
        str << "#{self.class.name}(\n"
        modules.each do |name, mod|
          str << "  (#{name}): #{mod.inspect}\n"
        end
        str << ")"
      end

      def call(*input)
        forward(*input)
      end

      def parameters
        params = []
        instance_variables.each do |name|
          param = instance_variable_get(name)
          params << param if param.is_a?(Parameter)
        end
        params + modules.flat_map { |_, mod| mod.parameters }
      end

      def zero_grad
        parameters.each do |param|
          if param.grad
            raise Error, "Not supported yet"
            param.grad.detach!
            param.grad.zero!
          end
        end
      end

      def method_missing(method, *args, &block)
        modules[method.to_s] || super
      end

      def respond_to?(method, include_private = false)
        modules.key?(method.to_s) || super
      end

      private

      def modules
        modules = {}
        instance_variables.each do |name|
          mod = instance_variable_get(name)
          modules[name[1..-1]] = mod if mod.is_a?(Module)
        end
        modules
      end
    end
  end
end

Version data entries

2 entries across 2 versions & 1 rubygems

Version Path
torch-rb-0.1.1 lib/torch/nn/module.rb
torch-rb-0.1.0 lib/torch/nn/module.rb