lib/torch/optim/optimizer.rb in torch-rb-0.2.7 vs lib/torch/optim/optimizer.rb in torch-rb-0.3.0
- old
+ new
@@ -30,21 +30,23 @@
def load_state_dict(state_dict)
raise NotImplementedYet
end
def state_dict
+ raise NotImplementedYet
+
pack_group = lambda do |group|
- packed = group.select { |k, _| k != :params }.to_h
- packed[:params] = group[:params].map { |p| p.object_id }
+ packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
+ packed["params"] = group[:params].map { |p| p.object_id }
packed
end
param_groups = @param_groups.map { |g| pack_group.call(g) }
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
{
- state: packed_state,
- param_groups: param_groups
+ "state" => packed_state,
+ "param_groups" => param_groups
}
end
def zero_grad
@param_groups.each do |group|