Sha256: 1afe0676c0c3aa8f3c4b5063648bbd0e740d661e61a91679c2342af326c0ff11

Contents?: true

Size: 1.91 KB

Versions: 32

Stored size: 1.91 KB

Contents

# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
module Torch
  module Optim
    class Rprop < Optimizer
      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1

        defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
        super(params, defaults)
      end

      def step(closure = nil)
        loss = nil
        if closure
          loss = closure.call
        end

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            if grad.sparse?
              raise Error, "Rprop does not support sparse gradients"
            end
            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:prev] = Torch.zeros_like(p.data)
              state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
            end

            etaminus, etaplus = group[:etas]
            step_size_min, step_size_max = group[:step_sizes]
            step_size = state[:step_size]

            state[:step] += 1

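            # compare current and previous gradient signs: agreement scales the
            # step size up by etaplus, a sign flip scales it down by etaminus,
            # and a zero product leaves it unchanged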
            sign = grad.mul(state[:prev]).sign
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1

            # scale each step size by its sign factor and clamp to [step_size_min, step_size_max]
            step_size.mul!(sign).clamp!(step_size_min, step_size_max)

            # where the gradient changed sign (factor == etaminus), zero the gradient
            # so no update is applied this step and the next sign comparison is neutral
            grad = grad.clone
            grad[sign.eq(etaminus)] = 0

            # update parameters
            p.data.addcmul!(grad.sign, step_size, value: -1)

            state[:prev].copy!(grad)
          end
        end

        loss
      end
    end
  end
end
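
For reference, a minimal usage sketch. It assumes torch-rb's Torch::NN::Linear and Torch::NN::Functional.mse_loss APIs; the model, tensor shapes, learning rate, and iteration count are illustrative and not part of this file:

require "torch"

# toy regression setup; shapes are arbitrary
model = Torch::NN::Linear.new(10, 1)
optimizer = Torch::Optim::Rprop.new(model.parameters, lr: 1e-2)

x = Torch.randn(8, 10)
y = Torch.randn(8, 1)

20.times do
  optimizer.zero_grad
  loss = Torch::NN::Functional.mse_loss(model.call(x), y)
  loss.backward
  optimizer.step
end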

Version data entries

32 entries across 32 versions & 1 rubygem

Version Path
torch-rb-0.18.0 lib/torch/optim/rprop.rb
torch-rb-0.17.1 lib/torch/optim/rprop.rb
torch-rb-0.17.0 lib/torch/optim/rprop.rb
torch-rb-0.16.0 lib/torch/optim/rprop.rb
torch-rb-0.15.0 lib/torch/optim/rprop.rb
torch-rb-0.14.1 lib/torch/optim/rprop.rb
torch-rb-0.14.0 lib/torch/optim/rprop.rb
torch-rb-0.13.2 lib/torch/optim/rprop.rb
torch-rb-0.13.1 lib/torch/optim/rprop.rb
torch-rb-0.13.0 lib/torch/optim/rprop.rb
torch-rb-0.12.2 lib/torch/optim/rprop.rb
torch-rb-0.12.1 lib/torch/optim/rprop.rb
torch-rb-0.12.0 lib/torch/optim/rprop.rb
torch-rb-0.11.2 lib/torch/optim/rprop.rb
torch-rb-0.11.1 lib/torch/optim/rprop.rb
torch-rb-0.11.0 lib/torch/optim/rprop.rb
torch-rb-0.10.2 lib/torch/optim/rprop.rb
torch-rb-0.10.1 lib/torch/optim/rprop.rb
torch-rb-0.10.0 lib/torch/optim/rprop.rb
torch-rb-0.9.2 lib/torch/optim/rprop.rb