lib/torch/nn/rnn_base.rb in torch-rb-0.1.5 vs lib/torch/nn/rnn_base.rb in torch-rb-0.1.6

- old
+ new

@@ -88,16 +88,17 @@
           Init.uniform!(weight, a: -stdv, b: stdv)
         end
       end
 
       def permute_hidden(hx, permutation)
+        if permutation.nil?
+          return hx
+        end
         raise NotImplementedYet
       end
 
       def forward(input, hx: nil)
-        raise NotImplementedYet
-
         is_packed = false # TODO isinstance(input, PackedSequence)
         if is_packed
           input, batch_sizes, sorted_indices, unsorted_indices = input
           max_batch_size = batch_sizes[0]
           max_batch_size = max_batch_size.to_i
@@ -118,12 +119,12 @@
           hx = permute_hidden(hx, sorted_indices)
         end
 
         check_forward_args(input, hx, batch_sizes)
         _rnn_impls = {
-          "RNN_TANH" => Torch.method(:_rnn_tanh),
-          "RNN_RELU" => Torch.method(:_rnn_relu)
+          "RNN_TANH" => Torch.method(:rnn_tanh),
+          "RNN_RELU" => Torch.method(:rnn_relu)
         }
         _impl = _rnn_impls[@mode]
         if batch_sizes.nil?
           result = _impl.call(input, hx, _get_flat_weights, @bias, @num_layers,
             @dropout, @training, @bidirectional, @batch_first)
@@ -146,9 +147,52 @@
         s = String.new("%{input_size}, %{hidden_size}")
         if @num_layers != 1
           s += ", num_layers: %{num_layers}"
         end
         format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
+      end
+
+      private
+
+      def _flat_weights
+        @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
+      end
+
+      def _get_flat_weights
+        _flat_weights
+      end
+
+      def check_input(input, batch_sizes)
+        expected_input_dim = !batch_sizes.nil? ? 2 : 3
+        if input.dim != expected_input_dim
+          raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{input.dim}"
+        end
+        if @input_size != input.size(-1)
+          raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{input.size(-1)}"
+        end
+      end
+
+      def get_expected_hidden_size(input, batch_sizes)
+        if !batch_sizes.nil?
+          mini_batch = batch_sizes[0]
+          mini_batch = mini_batch.to_i
+        else
+          mini_batch = @batch_first ? input.size(0) : input.size(1)
+        end
+        num_directions = @bidirectional ? 2 : 1
+        [@num_layers * num_directions, mini_batch, @hidden_size]
+      end
+
+      def check_hidden_size(hx, expected_hidden_size)
+        if hx.size != expected_hidden_size
+          raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{hx.size.inspect}"
+        end
+      end
+
+      def check_forward_args(input, hidden, batch_sizes)
+        check_input(input, batch_sizes)
+        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+        check_hidden_size(hidden, expected_hidden_size)
       end
     end
   end
 end
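
Note: the new nil guard in permute_hidden is what un-stubs forward for plain tensors. On the unpacked path (is_packed is false) sorted_indices and unsorted_indices are nil, so permute_hidden now passes the hidden state through unchanged; a real packed-sequence permutation would still reach raise NotImplementedYet. A minimal sketch of the nil path:

    hx = permute_hidden(hx, nil)
    # => hx, returned unchanged by the new early return;
    #    a non-nil permutation still raises NotImplementedYet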
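
The private helpers added in the third hunk port PyTorch's shape validation. A worked example of get_expected_hidden_size and check_hidden_size on the unpacked path (batch_sizes is nil), with illustrative values that are not taken from the diff:

    # Assume @num_layers = 2, @bidirectional = true, @hidden_size = 20,
    # @batch_first = false, so mini_batch = input.size(1), e.g. 3
    num_directions = 2                      # @bidirectional ? 2 : 1
    expected = [2 * num_directions, 3, 20]  # => [4, 3, 20]
    # check_hidden_size(hx, expected) raises ArgumentError
    # unless hx.size == [4, 3, 20]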
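
Taken together, 0.1.6 makes RNNBase#forward callable for unpacked input. A hypothetical usage sketch, assuming a concrete subclass such as Torch::NN::RNN that sets @mode to "RNN_TANH" and accepts PyTorch-style constructor arguments; neither the subclass nor the return value of forward appears in this diff:

    require "torch"

    rnn = Torch::NN::RNN.new(10, 20, num_layers: 2)  # assumed constructor
    input = Torch.randn(5, 3, 10)  # [seq_len, batch, input_size]
    h0 = Torch.zeros(2, 3, 20)     # [num_layers * num_directions, batch, hidden_size]
    # forward's signature is shown in the diff; per the PyTorch API it is
    # expected to return the output sequence and the final hidden state
    output, hn = rnn.forward(input, hx: h0)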