lib/torch/nn/module.rb in torch-rb-0.1.6 vs lib/torch/nn/module.rb in torch-rb-0.1.7
- old
+ new
@@ -77,14 +77,22 @@
end
_apply(convert)
end
- def call(*input)
- forward(*input)
+ def call(*input, **kwargs)
+ forward(*input, **kwargs)
end
- def state_dict
+ def state_dict(destination: nil)
+ destination ||= {}
+ named_parameters.each do |k, v|
+ destination[k] = v
+ end
+ destination
+ end
+
+ def load_state_dict(state_dict)
raise NotImplementedYet
end
def parameters
named_parameters.values