Sha256: 887e3bf85078d4859afe557cdf4cc187ec0134daa8e5b1d6476174bbbefa5b3e

Contents?: true

Size: 1.64 KB

Versions: 54

Compression:

Stored size: 1.64 KB

Contents

module Torch
  module NN
    # Gated recurrent unit layer. All weight/parameter bookkeeping lives in
    # RNNBase; this class only supplies the "GRU" mode and the forward path
    # that calls the native Torch.gru kernel.
    class GRU < RNNBase
      # Passes every positional and keyword option straight through to
      # RNNBase, prefixed with the "GRU" mode name.
      def initialize(*args, **options)
        super("GRU", *args, **options)
      end

      # Invokes Torch.gru with the layer's configuration.
      #
      # @param input the input sequence tensor
      # @param hx the initial hidden state
      # @param batch_sizes nil for a plain tensor input; otherwise forwarded as
      #   the second positional argument of Torch.gru (presumably the
      #   packed-sequence overload — confirm against the native binding)
      # @return whatever Torch.gru returns; forward_impl reads elements 0 and 1
      def run_impl(input, hx, batch_sizes)
        flat_weights = _get_flat_weights
        packed = !batch_sizes.nil?

        if packed
          # batch_sizes overload: no batch_first flag is passed.
          Torch.gru(
            input, batch_sizes, hx, flat_weights, @bias,
            @num_layers, @dropout, @training, @bidirectional
          )
        else
          # Plain tensor overload: @batch_first controls the input layout.
          Torch.gru(
            input, hx, flat_weights, @bias, @num_layers,
            @dropout, @training, @bidirectional, @batch_first
          )
        end
      end

      # Prepares the hidden state, validates arguments, and runs the kernel.
      #
      # When hx is nil, a zero tensor of shape
      # (num_layers * directions, max_batch_size, hidden_size) is created with
      # the same dtype and device as input. Otherwise the supplied hx is passed
      # through permute_hidden with sorted_indices.
      #
      # @return [Array] two elements: [output, hidden]
      def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
        hx =
          if hx.nil?
            directions = @bidirectional ? 2 : 1
            Torch.zeros(
              @num_layers * directions, max_batch_size, @hidden_size,
              dtype: input.dtype, device: input.device
            )
          else
            # Reorder the caller's hidden state so its batch entries line up
            # with the input order the caller expects (permute_hidden is not
            # shown here — presumably a no-op when sorted_indices is nil).
            permute_hidden(hx, sorted_indices)
          end

        check_forward_args(input, hx, batch_sizes)
        raw = run_impl(input, hx, batch_sizes)
        [raw[0], raw[1]]
      end

      # Forward pass for an ordinary (unpacked) tensor input. Such input
      # carries no batch_sizes and no sorted/unsorted indices, so nil is
      # passed for each of them.
      #
      # @param input the input tensor; its batch dimension is axis 0 when
      #   @batch_first is set, axis 1 otherwise
      # @param hx optional initial hidden state (zeros when nil)
      # @return [Array] [output, hidden]
      def forward_tensor(input, hx: nil)
        batch_axis = @batch_first ? 0 : 1
        max_batch_size = input.size(batch_axis)
        output, hidden = forward_impl(input, hx, nil, max_batch_size, nil)
        [output, permute_hidden(hidden, nil)]
      end

      # Public entry point; delegates to forward_tensor.
      def forward(input, hx: nil)
        forward_tensor(input, hx: hx)
      end
    end
  end
end

Version data entries

54 entries across 54 versions & 1 rubygem

Version Path
torch-rb-0.3.2 lib/torch/nn/gru.rb
torch-rb-0.3.1 lib/torch/nn/gru.rb
torch-rb-0.3.0 lib/torch/nn/gru.rb
torch-rb-0.2.7 lib/torch/nn/gru.rb
torch-rb-0.2.6 lib/torch/nn/gru.rb
torch-rb-0.2.5 lib/torch/nn/gru.rb
torch-rb-0.2.4 lib/torch/nn/gru.rb
torch-rb-0.2.3 lib/torch/nn/gru.rb
torch-rb-0.2.2 lib/torch/nn/gru.rb
torch-rb-0.2.1 lib/torch/nn/gru.rb
torch-rb-0.2.0 lib/torch/nn/gru.rb
torch-rb-0.1.8 lib/torch/nn/gru.rb
torch-rb-0.1.7 lib/torch/nn/gru.rb
torch-rb-0.1.6 lib/torch/nn/gru.rb