lib/torch/nn/module.rb in torch-rb-0.1.6 vs lib/torch/nn/module.rb in torch-rb-0.1.7

- old
+ new

@@ -77,14 +77,22 @@ end _apply(convert) end - def call(*input) - forward(*input) + def call(*input, **kwargs) + forward(*input, **kwargs) end - def state_dict + def state_dict(destination: nil) + destination ||= {} + named_parameters.each do |k, v| + destination[k] = v + end + destination + end + + def load_state_dict(state_dict) raise NotImplementedYet end def parameters named_parameters.values