Sha256: 7497cc74005538a03c85f77487d25b60d125cd2f23834370489027ab1b6f9ef9

Contents?: true

Size: 1.57 KB

Versions: 14

Compression:

Stored size: 1.57 KB

Contents

# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
module Torch
  module Optim
    class Optimizer
      attr_reader :param_groups

      # params is either an array of tensors or an array of param group hashes;
      # defaults holds the optimizer's default hyperparameters.
      def initialize(params, defaults)
        @defaults = defaults
        @state = Hash.new { |hash, key| hash[key] = {} }
        @param_groups = []

        param_groups = params
        if param_groups.empty?
          raise ArgumentError, "optimizer got an empty parameter list"
        end
        if !param_groups[0].is_a?(Hash)
          param_groups = [{params: param_groups}]
        end

        param_groups.each do |param_group|
          add_param_group(param_group)
        end
      end

      # Registers a group of parameters, filling in any missing
      # hyperparameters from the optimizer defaults.
      def add_param_group(param_group)
        # TODO more advanced logic
        @param_groups << @defaults.merge(param_group)
      end

      def load_state_dict(state_dict)
        raise NotImplementedYet
      end

      # Returns the optimizer state as a hash, with parameter tensors
      # replaced by their object ids so the structure can be serialized.
      def state_dict
        pack_group = lambda do |group|
          packed = group.select { |k, _| k != :params }.to_h
          packed[:params] = group[:params].map { |p| p.object_id }
          packed
        end

        param_groups = @param_groups.map { |g| pack_group.call(g) }
        packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h

        {
          state: packed_state,
          param_groups: param_groups
        }
      end

      # Clears the gradients of all optimized parameters.
      def zero_grad
        @param_groups.each do |group|
          group[:params].each do |p|
            if p.grad
              p.grad.detach!
              p.grad.zero!
            end
          end
        end
      end
    end
  end
end
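For context, a minimal sketch of how this base class is exercised through one of its subclasses, here Torch::Optim::SGD from the same gem; the toy tensor and loss below are illustrative only, not part of the file above.

require "torch"

# A parameter to optimize (toy example).
w = Torch.ones(2, 2, requires_grad: true)

# SGD subclasses Optimizer, so the array of tensors is wrapped into a
# param group and merged with the defaults by add_param_group.
optimizer = Torch::Optim::SGD.new([w], lr: 0.1)

# One training step: compute a loss, backpropagate, update, reset grads.
loss = (w * w).sum
loss.backward
optimizer.step
optimizer.zero_grad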

Version data entries

14 entries across 14 versions & 1 rubygem

Version Path
torch-rb-0.2.7 lib/torch/optim/optimizer.rb
torch-rb-0.2.6 lib/torch/optim/optimizer.rb
torch-rb-0.2.5 lib/torch/optim/optimizer.rb
torch-rb-0.2.4 lib/torch/optim/optimizer.rb
torch-rb-0.2.3 lib/torch/optim/optimizer.rb
torch-rb-0.2.2 lib/torch/optim/optimizer.rb
torch-rb-0.2.1 lib/torch/optim/optimizer.rb
torch-rb-0.2.0 lib/torch/optim/optimizer.rb
torch-rb-0.1.8 lib/torch/optim/optimizer.rb
torch-rb-0.1.7 lib/torch/optim/optimizer.rb
torch-rb-0.1.6 lib/torch/optim/optimizer.rb
torch-rb-0.1.5 lib/torch/optim/optimizer.rb
torch-rb-0.1.4 lib/torch/optim/optimizer.rb
torch-rb-0.1.3 lib/torch/optim/optimizer.rb