Sha256: 33001319a3f2e3e5f8b2811e16ac5ffd0788494f3d8c6647d9add6589db2f44c

Contents?: true

Size: 1.2 KB

Versions: 5

Compression:

Stored size: 1.2 KB

Contents

module Rblearn
  module CrossValidation
    # x, y: Narray object
    # We slice a matrix by x[Array<Integer>, true]
    def self.train_test_split(x, y, test_size=0.33)
      doc_size = x.shape[0]
      random_indices = (0...doc_size).to_a.shuffle
      endpoint = (doc_size * test_size).to_i
      train_indices = random_indices[endpoint..-1]
      test_indices = random_indices[0...endpoint]

      return [x[train_indices, true], y[train_indices, true], x[test_indices, true], y[test_indices, true]]
    end

    class KFold
      # TODO: make indices and n_folds private

      def initialize(n, n_folds, shuffle)
        indices = (0...n).to_a
        indices.shuffle! if shuffle
        @indices = indices
        @n_folds = n_folds
      end

      def create
        groups_nfolds = @indices.each_slice((@indices.size.to_f / @n_folds).ceil).to_a
        groups = []

        @n_folds.times do |k|
          validation_set = []
          test_set = []

          @n_folds.times do |j|
            test_set += groups_nfolds[j] if k == j
            validation_set += groups_nfolds[j] unless k == j
          end
          groups << [validation_set, test_set]
        end

        return groups
      end
    end
  end
end

Version data entries

5 entries across 5 versions & 1 rubygems

Version Path
rblearn-0.3.1 lib/rblearn/CrossValidation.rb
rblearn-0.3.0 lib/rblearn/CrossValidation.rb
rblearn-0.2.2 lib/rblearn/CrossValidation.rb
rblearn-0.2.1 lib/rblearn/CrossValidation.rb
rblearn-0.2.0 lib/rblearn/CrossValidation.rb