Sha256: 44bba37c5b0577b9fe0581ae8e66d13aa0d19a9bea673ede686b211bc0748163

Contents?: true

Size: 1.64 KB

Versions: 43

Compression:

Stored size: 1.64 KB

Contents

# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
module Torch
  module Optim
    # Base optimizer: keeps a list of parameter groups, each one a Hash made of
    # the shared defaults overlaid with the group's own settings.
    class Optimizer
      attr_reader :param_groups

      # params   - non-empty list; either parameters, or Hashes that each
      #            describe one group (with the parameters under :params)
      # defaults - Hash of option defaults merged underneath every group
      #
      # Raises ArgumentError when params is empty.
      def initialize(params, defaults)
        @defaults = defaults
        # per-key state; a missing key is lazily initialised to an empty Hash
        @state = Hash.new { |store, key| store[key] = {} }
        @param_groups = []

        raise ArgumentError, "optimizer got an empty parameter list" if params.empty?

        # a plain list of parameters is treated as one single group
        groups = params[0].is_a?(Hash) ? params : [{params: params}]
        groups.each { |group| add_param_group(group) }
      end

      # Appends one group; keys given in param_group win over the defaults.
      def add_param_group(param_group)
        # TODO more advanced logic
        merged = @defaults.merge(param_group)
        @param_groups.push(merged)
      end

      # Not available yet: always raises NotImplementedYet.
      def load_state_dict(state_dict)
        raise NotImplementedYet
      end

      # Not available yet: always raises NotImplementedYet.
      def state_dict
        raise NotImplementedYet

        # NOTE: nothing below runs while the raise above is in place; it is the
        # draft of the eventual implementation.
        serialize_group = lambda do |group|
          entry = group.each_with_object({}) do |(key, value), acc|
            acc[key.to_s] = value unless key == :params
          end
          entry["params"] = group[:params].map(&:object_id)
          entry
        end

        groups = @param_groups.map(&serialize_group)
        state = @state.each_with_object({}) do |(key, value), acc|
          acc[key.is_a?(Tensor) ? key.object_id : key] = value
        end

        {"state" => state, "param_groups" => groups}
      end

      # For every parameter that currently has a gradient, detaches that
      # gradient and zeroes it in place. Parameters without one are skipped.
      def zero_grad
        @param_groups.each do |group|
          group[:params].each do |param|
            next unless param.grad

            param.grad.detach!
            param.grad.zero!
          end
        end
      end
    end
  end
end

Version data entries

43 entries across 43 versions & 1 rubygems

Version Path
torch-rb-0.9.1 lib/torch/optim/optimizer.rb
torch-rb-0.9.0 lib/torch/optim/optimizer.rb
torch-rb-0.8.3 lib/torch/optim/optimizer.rb
torch-rb-0.8.2 lib/torch/optim/optimizer.rb
torch-rb-0.8.1 lib/torch/optim/optimizer.rb
torch-rb-0.8.0 lib/torch/optim/optimizer.rb
torch-rb-0.7.0 lib/torch/optim/optimizer.rb
torch-rb-0.6.0 lib/torch/optim/optimizer.rb
torch-rb-0.5.3 lib/torch/optim/optimizer.rb
torch-rb-0.5.2 lib/torch/optim/optimizer.rb
torch-rb-0.5.1 lib/torch/optim/optimizer.rb
torch-rb-0.5.0 lib/torch/optim/optimizer.rb
torch-rb-0.4.2 lib/torch/optim/optimizer.rb
torch-rb-0.4.1 lib/torch/optim/optimizer.rb
torch-rb-0.4.0 lib/torch/optim/optimizer.rb
torch-rb-0.3.7 lib/torch/optim/optimizer.rb
torch-rb-0.3.6 lib/torch/optim/optimizer.rb
torch-rb-0.3.5 lib/torch/optim/optimizer.rb
torch-rb-0.3.4 lib/torch/optim/optimizer.rb
torch-rb-0.3.3 lib/torch/optim/optimizer.rb