lib/torch/optim/optimizer.rb in torch-rb-0.2.7 vs lib/torch/optim/optimizer.rb in torch-rb-0.3.0

- old
+ new

@@ -30,21 +30,23 @@ def load_state_dict(state_dict) raise NotImplementedYet end def state_dict + raise NotImplementedYet + pack_group = lambda do |group| - packed = group.select { |k, _| k != :params }.to_h - packed[:params] = group[:params].map { |p| p.object_id } + packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h + packed["params"] = group[:params].map { |p| p.object_id } packed end param_groups = @param_groups.map { |g| pack_group.call(g) } packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h { - state: packed_state, - param_groups: param_groups + "state" => packed_state, + "param_groups" => param_groups } end def zero_grad @param_groups.each do |group|