lib/torch/optim/optimizer.rb in torch-rb-0.1.2 vs lib/torch/optim/optimizer.rb in torch-rb-0.1.3

- old
+ new

@@ -1,6 +1,62 @@ +# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py module Torch module Optim class Optimizer + attr_reader :param_groups + + def initialize(params, defaults) + @defaults = defaults + @state = Hash.new { |hash, key| hash[key] = {} } + @param_groups = [] + + param_groups = params + if param_groups.empty? + raise ArgumentError, "optimizer got an empty parameter list" + end + if !param_groups[0].is_a?(Hash) + param_groups = [{params: param_groups}] + end + + param_groups.each do |param_group| + add_param_group(param_group) + end + end + + def add_param_group(param_group) + # TODO more advanced logic + @param_groups << @defaults.merge(param_group) + end + + def load_state_dict(state_dict) + raise NotImplementedYet + end + + def state_dict + pack_group = lambda do |group| + packed = group.select { |k, _| k != :params }.to_h + packed[:params] = group[:params].map { |p| p.object_id } + packed + end + + param_groups = @param_groups.map { |g| pack_group.call(g) } + packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h + + { + state: packed_state, + param_groups: param_groups + } + end + + def zero_grad + @param_groups.each do |group| + group[:params].each do |p| + if p.grad + p.grad.detach! + p.grad.zero! + end + end + end + end end end end