Sha256: b397322de83a27159ee8e34accc97afa57d1b44328bacd3116f136d8a7ed3fec

Contents?: true

Size: 1.2 KB

Versions: 6

Compression:

Stored size: 1.2 KB

Contents

# frozen_string_literal: true

require 'rumale/base/estimator'

module Rumale
  # This module consists of the classes that implement generalized linear models.
  module LinearModel
    # BaseEstimator is an abstract class for implementation of linear model. This class is used internally.
    #
    # The private helpers below read +@params[:fit_bias]+ and +@params[:bias_scale]+.
    # This class does not set +@params+ itself; it is presumably populated by subclasses
    # (or the base estimator) — confirm against the concrete estimators.
    class BaseEstimator < Rumale::Base::Estimator
      # Return the weight vector.
      # @return [Numo::DFloat] (shape: [n_outputs/n_classes, n_features])
      attr_reader :weight_vec

      # Return the bias term (a.k.a. intercept).
      # @return [Numo::DFloat] (shape: [n_outputs/n_classes])
      attr_reader :bias_term

      private

      # Append a constant column to the sample matrix so that the bias term can be
      # learned as one extra weight.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples.
      # @return [Numo::NArray] (shape: [n_samples, n_features + 1]) The samples followed by
      #   a trailing column in which every entry equals +@params[:bias_scale]+.
      def expand_feature(x)
        n_samples = x.shape[0]
        Numo::NArray.hstack([x, Numo::DFloat.ones([n_samples, 1]) * @params[:bias_scale]])
      end

      # Separate a learned weight array into the feature weights and the bias term.
      #
      # When +fit_bias?+ is true, the last element (1-D +w+) or the last column
      # (otherwise) of +w+ is taken as the bias. When it is false, no bias was learned
      # and a zero bias is returned.
      #
      # @param w [Numo::DFloat] The learned weights, either 1-D or with one row per output/class.
      # @return [Array] A two-element array +[weight, bias]+:
      #   - 1-D +w+ with bias:    [copy of w without its last element, last element as a scalar]
      #   - 1-D +w+ without bias: [w itself, 0.0]
      #   - otherwise with bias:    [copy of w without its last column, copy of the last column]
      #   - otherwise without bias: [w itself, zero vector of length w.shape[0]]
      def split_weight(w)
        if w.ndim == 1
          if fit_bias?
            # Slicing returns a view into +w+; dup so the weights own their storage.
            [w[0...-1].dup, w[-1]]
          else
            # No copy: the returned weight shares storage with +w+.
            [w, 0.0]
          end
        elsif fit_bias?
          [w[true, 0...-1].dup, w[true, -1].dup]
        else
          # No copy: the returned weight shares storage with +w+.
          [w, Numo::DFloat.zeros(w.shape[0])]
        end
      end

      # Whether the bias term should be fitted.
      #
      # This is a strict check: only the literal +true+ enables the bias term.
      # Truthy non-boolean values (e.g. +1+) are treated as false.
      #
      # @return [Boolean]
      def fit_bias?
        @params[:fit_bias] == true
      end
    end
  end
end

Version data entries

6 entries across 6 versions & 1 rubygems

Version Path
rumale-linear_model-0.29.0 lib/rumale/linear_model/base_estimator.rb
rumale-linear_model-0.28.1 lib/rumale/linear_model/base_estimator.rb
rumale-linear_model-0.28.0 lib/rumale/linear_model/base_estimator.rb
rumale-linear_model-0.27.0 lib/rumale/linear_model/base_estimator.rb
rumale-linear_model-0.26.0 lib/rumale/linear_model/base_estimator.rb
rumale-linear_model-0.25.0 lib/rumale/linear_model/base_estimator.rb