lib/torch/nn/rnn_base.rb in torch-rb-0.1.5 vs lib/torch/nn/rnn_base.rb in torch-rb-0.1.6
- old
+ new
@@ -88,16 +88,17 @@
           Init.uniform!(weight, a: -stdv, b: stdv)
         end
       end
 
       def permute_hidden(hx, permutation)
+        if permutation.nil?
+          return hx
+        end
         raise NotImplementedYet
       end
 
       def forward(input, hx: nil)
-        raise NotImplementedYet
-
         is_packed = false # TODO isinstance(input, PackedSequence)
         if is_packed
           input, batch_sizes, sorted_indices, unsorted_indices = input
           max_batch_size = batch_sizes[0]
           max_batch_size = max_batch_size.to_i
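
The 0.1.6 side of this hunk removes the raise NotImplementedYet guard at the top of forward, so the method now runs, and permute_hidden passes the hidden state through when no permutation is given. Packed-sequence input is still stubbed out (is_packed = false), so sorted_indices stays nil on this path and the new nil check is what lets forward reach the dispatch below. A minimal usage sketch under those assumptions (the Torch::NN::RNN constructor, the tensor-creation calls, and the [output, hidden] return shape are assumed from torch-rb's PyTorch-style API and are not part of this diff):

    require "torch"

    # Hypothetical example: input_size 10, hidden_size 20, single layer
    rnn = Torch::NN::RNN.new(10, 20)
    input = Torch.randn(5, 3, 10)  # [seq_len, batch, input_size]
    h0 = Torch.zeros(1, 3, 20)     # [num_layers * num_directions, batch, hidden_size]
    output, hidden = rnn.forward(input, hx: h0)
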
@@ -118,12 +119,12 @@
           hx = permute_hidden(hx, sorted_indices)
         end
 
         check_forward_args(input, hx, batch_sizes)
         _rnn_impls = {
-          "RNN_TANH" => Torch.method(:_rnn_tanh),
-          "RNN_RELU" => Torch.method(:_rnn_relu)
+          "RNN_TANH" => Torch.method(:rnn_tanh),
+          "RNN_RELU" => Torch.method(:rnn_relu)
         }
         _impl = _rnn_impls[@mode]
         if batch_sizes.nil?
           result = _impl.call(input, hx, _get_flat_weights, @bias, @num_layers,
             @dropout, @training, @bidirectional, @batch_first)
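
In 0.1.5 this table pointed at Torch.method(:_rnn_tanh) and Torch.method(:_rnn_relu); 0.1.6 switches to rnn_tanh/rnn_relu, presumably tracking the renamed native bindings on the Torch module. The lookup itself is ordinary Ruby Method-object dispatch keyed by @mode (set elsewhere to "RNN_TANH" or "RNN_RELU"). A standalone illustration of the pattern, in plain Ruby rather than torch-rb:

    impls = {
      "RNN_TANH" => Math.method(:tanh),      # a bound Method object
      "RNN_RELU" => ->(x) { [x, 0.0].max }   # any callable dispatches the same way
    }
    impls["RNN_TANH"].call(0.5)   # => 0.4621...
    impls["RNN_RELU"].call(-2.0)  # => 0.0
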
@@ -146,9 +147,52 @@
         s = String.new("%{input_size}, %{hidden_size}")
         if @num_layers != 1
           s += ", num_layers: %{num_layers}"
         end
         format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
+      end
+
+      private
+
+      def _flat_weights
+        @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
+      end
+
+      def _get_flat_weights
+        _flat_weights
+      end
+
+      def check_input(input, batch_sizes)
+        expected_input_dim = !batch_sizes.nil? ? 2 : 3
+        if input.dim != expected_input_dim
+          raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{input.dim}"
+        end
+        if @input_size != input.size(-1)
+          raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{input.size(-1)}"
+        end
+      end
+
+      def get_expected_hidden_size(input, batch_sizes)
+        if !batch_sizes.nil?
+          mini_batch = batch_sizes[0]
+          mini_batch = mini_batch.to_i
+        else
+          mini_batch = @batch_first ? input.size(0) : input.size(1)
+        end
+        num_directions = @bidirectional ? 2 : 1
+        [@num_layers * num_directions, mini_batch, @hidden_size]
+      end
+
+      def check_hidden_size(hx, expected_hidden_size)
+        if hx.size != expected_hidden_size
+          raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{hx.size.inspect}"
+        end
+      end
+
+      def check_forward_args(input, hidden, batch_sizes)
+        check_input(input, batch_sizes)
+        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+        check_hidden_size(hidden, expected_hidden_size)
       end
     end
   end
 end
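
The remaining additions port PyTorch's shape validation as private helpers: check_forward_args calls check_input (dimension count plus trailing input_size check) and then check_hidden_size against the shape computed by get_expected_hidden_size, while _flat_weights gathers the weight instance variables named in @all_weights. A worked example of the expected-hidden-size arithmetic, with illustrative values that are not taken from the diff:

    # Non-packed input of shape [seq_len, batch, input_size] = [5, 3, 10],
    # batch_first: false, so mini_batch = input.size(1) = 3
    num_layers     = 2
    num_directions = 2   # bidirectional: true
    mini_batch     = 3
    hidden_size    = 20

    expected = [num_layers * num_directions, mini_batch, hidden_size]
    # => [4, 3, 20]
    # An hx whose hx.size differs from this raises ArgumentError in check_hidden_size.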