Sha256: 7c4829c134f2c9cd473a2e641957d057d91b49b22898bd8c59708bf0cd4da40d

Size: 590 Bytes

Versions: 4

Stored size: 590 Bytes

Contents

module Torch
  module NN
    class Sequential < Module
      def initialize(*args)
        @modules = {}
        # TODO support hash arg (named modules)
        args.each_with_index do |mod, idx|
          add_module(idx.to_s, mod)
        end
      end

      def add_module(name, mod)
        # TODO add checks
        @modules[name] = mod
      end

      def forward(input)
        @modules.values.each do |mod|
          input = mod.call(input)
        end
        input
      end

      def parameters
        @modules.flat_map { |_, mod| mod.parameters }
      end
    end
  end
end
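The class above chains its children. `forward` passes the input through each module in the order it was added, and `parameters` flattens every child's parameters into one array. The sketch below reproduces that behavior in plain Ruby so it runs without the torch-rb gem. `StubModule`, `Scale`, `Shift`, `SketchSequential`, and `names` are hypothetical stand-ins, not torch-rb classes. The assumption that `call` simply delegates to `forward` is likewise illustrative. The Hash branch in `initialize` is only one possible way to address the "support hash arg (named modules)" TODO and is not taken from torch-rb.

```ruby
# Hypothetical stand-in for Torch::NN::Module (not the torch-rb class).
class StubModule
  # Assumption: calling a module delegates to #forward, matching mod.call(input).
  def call(input)
    forward(input)
  end

  # Leaf modules own no parameters unless they override this.
  def parameters
    []
  end
end

# Hypothetical layer: multiplies its input by a stored weight.
class Scale < StubModule
  def initialize(weight)
    @weight = weight
  end

  def forward(input)
    input * @weight
  end

  def parameters
    [@weight]
  end
end

# Hypothetical layer: adds a stored bias to its input.
class Shift < StubModule
  def initialize(bias)
    @bias = bias
  end

  def forward(input)
    input + @bias
  end

  def parameters
    [@bias]
  end
end

# Same shape as Torch::NN::Sequential above. Hypothetical extension: a single
# Hash argument is read as name => module pairs; otherwise modules are named
# by position ("0", "1", ...), as in the original.
class SketchSequential < StubModule
  def initialize(*args)
    @modules = {}
    if args.length == 1 && args.first.is_a?(Hash)
      args.first.each { |name, mod| add_module(name.to_s, mod) }
    else
      args.each_with_index { |mod, idx| add_module(idx.to_s, mod) }
    end
  end

  def add_module(name, mod)
    @modules[name] = mod
  end

  # Ruby hashes keep insertion order, so modules run in the order added;
  # each module's output becomes the next module's input.
  def forward(input)
    @modules.values.reduce(input) { |acc, mod| mod.call(acc) }
  end

  # Concatenate every child's parameters into one flat array.
  def parameters
    @modules.values.flat_map(&:parameters)
  end

  # Hypothetical helper (not in torch-rb): the registered module names.
  def names
    @modules.keys
  end
end

model = SketchSequential.new(Scale.new(3), Shift.new(1))
p model.call(2)     # prints 7 (2 * 3 + 1)
p model.parameters  # prints [3, 1]
p model.names       # prints ["0", "1"]

named = SketchSequential.new({ scale: Scale.new(2), shift: Shift.new(-1) })
p named.call(5)     # prints 9 (5 * 2 - 1)
p named.names       # prints ["scale", "shift"]
```

Using `reduce` in `forward` is equivalent to the original's loop that reassigns `input`; it just makes the "output feeds the next input" step explicit.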

Version data entries

4 entries across 4 versions & 1 rubygem

Version         Path
torch-rb-0.1.3  lib/torch/nn/sequential.rb
torch-rb-0.1.2  lib/torch/nn/sequential.rb
torch-rb-0.1.1  lib/torch/nn/sequential.rb
torch-rb-0.1.0  lib/torch/nn/sequential.rb