Sha256: adaa1469c8b49256d52d529d28c0a3011bf489b85d0ac052d49826c8ec91c6b6

Contents?: true

Size: 1.89 KB

Versions: 1

Compression:

Stored size: 1.89 KB

Contents

class GLM::Base
  # Default starting value for every component of theta.
  @@initial_weight = 1

  # Current parameter vector (one entry per feature column of the design matrix).
  attr_reader :theta

  # =Arguments:
  #   x: design matrix (GSL::Matrix), one sample per row, one feature per column
  #   y: observed responses; y[i] is paired with row i of x
  #   alpha: learning rate used by #sto_update
  def initialize(x, y, alpha = 0.05)
    @x = x
    @y = y
    # Per-instance learning rate. A class variable (@@alpha) is shared by
    # GLM::Base and all of its subclasses, so building a second model used to
    # silently change the learning rate of every existing model.
    @alpha = alpha
    # Retained only for backward compatibility with subclasses that may still
    # read @@alpha; this class itself no longer consults it.
    @@alpha = alpha
    @theta = GSL::Vector.alloc(Array.new(x.size2, @@initial_weight))
  end

  # Log partition function <b>a(eta)</b>; intended to be overridden.
  # =Raises:
  #   RuntimeError unless a subclass overrides it
  def a
    raise 'Log partition function a(eta) undefined'
  end

  # Function <b>b</b> of the exponential-family form (presumably the base
  # measure b(y) — TODO confirm); intended to be overridden.
  # =Raises:
  #   RuntimeError unless a subclass overrides it
  def b
    raise 'b undefined'
  end

  # Applies #output to one sample, or to every row of a matrix of samples.
  # NOTE: inside GLM::Base and its subclasses this name shadows Kernel#format.
  # =Arguments:
  #   x: a GSL::Vector (one sample), an Array of feature values (converted to
  #      a GSL::Vector), or a GSL::Matrix (one sample per row)
  # =Returns:
  #   a single estimate for a vector/Array, or a GSL::Vector holding one
  #   estimate per matrix row
  # =Raises:
  #   TypeError for any other argument (this used to silently return nil)
  def format(x)
    case x
    when GSL::Vector
      output(x)
    when GSL::Matrix
      estimates = GSL::Vector.alloc(x.size1)
      x.size1.times { |i| estimates[i] = output(x.row(i)) }
      estimates
    when Array
      output(GSL::Vector.alloc(x))
    else
      raise TypeError, "expected GSL::Vector, GSL::Matrix or Array, got #{x.class}"
    end
  end

  # Estimator
  # =Arguments:
  #   x: a feature vector (GSL::Vector or Array) or a GSL::Matrix of samples
  # =Returns:
  #   Estimation (see #format)
  def est(x)
    format(x)
  end

  # Output estimation from E(y|theta,x).
  # Subclasses override this when their output differs from the mean h(x);
  # plain linear regression uses it as is.
  def output(x)
    h(x)
  end

  # Natural parameter eta: theta multiplied by the transpose of x (the inner
  # product theta . x). Assumes x is a GSL::Vector with theta.size entries —
  # TODO confirm for callers outside this file.
  def eta(x)
    @theta * x.transpose
  end

  # Sufficient statistic <b>T</b>; this base implementation returns y unchanged.
  # Because the name is capitalised, call it as obj.T or T() — a bare `T` is
  # parsed as a constant reference.
  def T
    @y
  end

  # Canonical response function g(eta); intended to be overridden as a class
  # method, since #h dispatches through self.class.g.
  # =Raises:
  #   RuntimeError unless a subclass overrides it
  def self.g(eta)
    raise 'Canonical reponse function g(eta) undefined'
  end

  # Gradient contribution of one sample along one feature.
  # =Arguments:
  #   x: value of the feature being updated, for this sample
  #   y: observed response for this sample
  #   v: the sample's full feature vector
  # =Returns:
  #   (y - h(v)) * x
  def gradient(x, y, v)
    (y - h(v)) * x
  end

  # Hypothesis function: outputs E(y|theta, x), the mean of y given x,
  # parameterized by theta.
  # =Parameters:
  #   x: a feature vector
  # =Returns:
  #   E(y|theta, x), computed as self.class.g(eta(x))
  def h(x)
    self.class.g(eta(x))
  end

  # A step based on one sample in stochastic gradient descent.
  # Not implemented: currently a no-op that returns nil.
  def single_update
  end

  # One complete pass of stochastic gradient descent over every sample.
  # =Arguments:
  #   verbose: when true, pretty-print theta after the pass (this used to
  #            happen unconditionally)
  # =Returns:
  #   the updated theta vector
  def sto_update(verbose = false)
    n_features = @x.size2
    @x.size1.times do |i|
      row = @x.row(i)
      # Evaluate every feature's gradient against the same theta before
      # applying any of them. Updating in place inside the feature loop made
      # h(row) see a half-updated theta for all but the first feature.
      steps = (0...n_features).map { |j| gradient(@x[i, j], @y[i], row) }
      steps.each_with_index { |step, j| @theta[j] = @theta[j] + @alpha * step }
    end
    pp @theta if verbose
    @theta
  end
end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
glm-0.0.2 lib/glm/base.rb