Sha256: 7c4829c134f2c9cd473a2e641957d057d91b49b22898bd8c59708bf0cd4da40d
Size: 590 Bytes
Versions: 4
Stored size: 590 Bytes
Contents
module Torch
  module NN
    class Sequential < Module
      def initialize(*args)
        @modules = {}
        # TODO support hash arg (named modules)
        args.each_with_index do |mod, idx|
          add_module(idx.to_s, mod)
        end
      end

      def add_module(name, mod)
        # TODO add checks
        @modules[name] = mod
      end

      def forward(input)
        @modules.values.each do |mod|
          input = mod.call(input)
        end
        input
      end

      def parameters
        @modules.flat_map { |_, mod| mod.parameters }
      end
    end
  end
end
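The following dependency-free sketch illustrates how the container chains its children. `ToyModule`, `Affine`, `Clamp`, and `ToySequential` are hypothetical stand-ins, not torch-rb classes. The sketch assumes, as `forward` above implies, that a module's `call` delegates to its `forward`. It also shows one possible way to resolve the two TODOs: accepting a single hash of named modules and rejecting duplicate names. That behavior is an assumption for illustration and is not what torch-rb 0.1.x does.

```ruby
# Hypothetical stand-in for Torch::NN::Module: `call` delegates to `forward`
# (an assumption for this sketch, not torch-rb's actual implementation).
class ToyModule
  def call(input)
    forward(input)
  end

  def parameters
    []
  end
end

# Hypothetical layer computing x * scale + bias; scale and bias act as its "parameters".
class Affine < ToyModule
  def initialize(scale, bias)
    @scale = scale
    @bias = bias
  end

  def forward(x)
    x * @scale + @bias
  end

  def parameters
    [@scale, @bias]
  end
end

# Hypothetical parameter-free activation that maps negative inputs to 0.
class Clamp < ToyModule
  def forward(x)
    x.negative? ? 0 : x
  end
end

# Same pattern as Sequential above: children are stored in insertion order
# under string keys and applied one after another.
class ToySequential < ToyModule
  def initialize(*args)
    @modules = {}
    # Assumed handling of the "named modules" TODO: a single Hash of name => module.
    if args.length == 1 && args.first.is_a?(Hash)
      args.first.each { |name, mod| add_module(name.to_s, mod) }
    else
      args.each_with_index { |mod, idx| add_module(idx.to_s, mod) }
    end
  end

  # Assumed check for the "add checks" TODO: reject duplicate names.
  def add_module(name, mod)
    raise ArgumentError, "module #{name} already exists" if @modules.key?(name)
    @modules[name] = mod
  end

  # Feed each child's output into the next child.
  def forward(input)
    @modules.values.reduce(input) { |x, mod| mod.call(x) }
  end

  # Concatenate the parameter lists of all children.
  def parameters
    @modules.flat_map { |_, mod| mod.parameters }
  end

  def names
    @modules.keys
  end
end

net = ToySequential.new(Affine.new(2, -5), Clamp.new, Affine.new(3, 1))
puts net.call(4)   # 4*2-5 = 3, clamp -> 3, 3*3+1 -> prints 10
puts net.call(1)   # 1*2-5 = -3, clamp -> 0, 0*3+1 -> prints 1
p net.parameters   # prints [2, -5, 3, 1]

named = ToySequential.new(fc1: Affine.new(2, 0), act: Clamp.new)
p named.names      # prints ["fc1", "act"]
```

Storing children in a `Hash` keyed by index strings makes positional and named construction share the same storage. Ruby hashes preserve insertion order, so `forward` runs the layers in the order they were given.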
Version data entries
4 entries across 4 versions of 1 gem (torch-rb)
Version | Path |
---|---|
torch-rb-0.1.3 | lib/torch/nn/sequential.rb |
torch-rb-0.1.2 | lib/torch/nn/sequential.rb |
torch-rb-0.1.1 | lib/torch/nn/sequential.rb |
torch-rb-0.1.0 | lib/torch/nn/sequential.rb |