Sha256: 2efda824a1b8fac98af01774b48d5ffe348d771d39a7b84b8a87cda9459eb76e

Contents?: true

Size: 1.88 KB

Versions: 2

Compression:

Stored size: 1.88 KB

Contents

class GLM::Base
  # Coefficient vector: an Array with one entry per column of the design
  # matrix, updated in place by #sto_update.
  attr_reader :theta

  # Creates a model to be fitted with stochastic gradient descent.
  # =Arguments:
  #   x: design Matrix, one sample per row (read as x[i, j])
  #   y: response Matrix; the response of sample i is read as y[i, 0]
  #   alpha: learning rate (default 0.1)
  def initialize(x, y, alpha = 0.1)
    @x = x
    @y = y
    # Per-instance learning rate. A class variable alone is shared by
    # GLM::Base and all of its subclasses, so building a second model would
    # silently change the learning rate of every existing model.
    @alpha = alpha
    # Still written for backward compatibility with code that reads @@alpha;
    # this class itself prefers @alpha.
    @@alpha = alpha
    # Every coefficient starts at 1.
    @theta = Array.new(x.column_size, 1)
  end

  # Log partition function <b>a(eta)</b>, intended to be overridden.
  # =Raises:
  #   RuntimeError unless a subclass provides it
  def a
    raise 'Log partition function a(eta) undefined'
  end

  # Intended to be overridden (presumably the base measure b(y) of the
  # exponential-family density -- confirm against subclasses).
  # =Raises:
  #   RuntimeError unless a subclass provides it
  def b
    raise 'b undefined'
  end

  # Routes one sample or a batch of samples to #output.
  # =Arguments:
  #   x: an Array of features (one sample), an Array of Arrays (one sample
  #      per inner Array), or a Matrix (one sample per row)
  # =Returns:
  #   #output for the single sample, or an Array of #output results
  def format(x)
    if x.is_a?(Array)
      x[0].is_a?(Array) ? x.map { |sample| output(sample) } : output(x)
    else
      # Assumes x is a Matrix; each row is passed on as a 1 x n row vector.
      x.row_vectors.map { |row| output(Matrix.row_vector(row)) }
    end
  end

  # Estimator
  # =Arguments:
  #   x: an Array of features, an Array of such Arrays, or a Matrix with
  #      one sample per row (see #format)
  # =Returns:
  #   Estimation (one value, or an Array of values for a batch)
  def est(x)
    format(x)
  end

  # Output estimation from E(y|theta,x).
  # Needs overriding, except for plain linear regression.
  # =Arguments:
  #   x: one sample, either as a 1 x n row-vector Matrix or as a flat Array
  # =Returns:
  #   h evaluated on the sample as a column vector
  def output(x)
    # Array has no #t, so flat feature Arrays (as passed by #format) are
    # turned into a row-vector Matrix before transposing.
    x = Matrix.row_vector(x) if x.is_a?(Array)
    h(x.t)
  end

  # Natural parameter eta = theta^T x.
  # =Arguments:
  #   x: a column-vector Matrix (n x 1) of features
  # =Returns:
  #   the scalar at position [0, 0] of theta^T * x
  def eta(x)
    (Matrix.row_vector(@theta) * x)[0, 0]
  end

  # Sufficient statistic <b>T</b>; here the responses themselves.
  def T
    @y
  end

  # Canonical response function g(eta), intended to be overridden.
  # =Raises:
  #   RuntimeError unless a subclass provides it
  def self.g(eta)
    raise 'Canonical reponse function g(eta) undefined'
  end

  # Gradient component on one sample.
  # =Arguments:
  #   x: the feature value x_ij of the component being updated
  #   y: the response of the sample
  #   v: the whole sample as a column-vector Matrix
  # =Returns:
  #   (y - h(v)) * x
  def gradient(x, y, v)
    (y - h(v)) * x
  end

  # Hypothesis function, outputs E(y|theta, x), mean of y given x parameterized by theta
  # =Parameters:
  #   x: a feature vector as a column-vector Matrix (n x 1)
  # =Returns:
  #   E(y|theta, x), i.e. g(eta(x)) using the class-level g
  def h(x)
    self.class.g(eta(x))
  end

  # A step based on one sample in stochastic gradient descent.
  # Placeholder: currently does nothing and returns nil; #sto_update does
  # not call it.
  def single_update
  end

  # One complete pass of stochastic gradient descent over every row of x.
  # Mutates @theta in place.
  def sto_update
    # Fall back to the class variable for subclasses that set @@alpha in a
    # custom initialize without calling super.
    rate = @alpha || @@alpha
    (0...@x.row_size).each do |i|
      sample = Matrix.column_vector(@x.row(i))
      # Evaluate every component against the same theta before changing any
      # of them (simultaneous update). Updating theta[j] in place while later
      # components are still being computed would re-evaluate h(sample) with
      # a partially updated theta and make the step depend on column order.
      steps = (0...@x.column_size).map { |j| gradient(@x[i, j], @y[i, 0], sample) }
      steps.each_with_index { |step, j| @theta[j] += rate * step }
    end
  end
end

Version data entries

2 entries across 2 versions & 1 rubygems

Version Path
glm-0.0.1 lib/glm/base.rb
glm-0.0.0 lib/glm/base.rb