Sha256: fa0723eebf2b4171601a7699bce7dce8756f7d1ae81492bd97bbeb11548371c7

Contents?: true

Size: 1.89 KB

Versions: 7

Compression:

Stored size: 1.89 KB

Contents

# frozen_string_literal: true

require 'rumale/base/estimator'
require 'rumale/base/transformer'
require 'rumale/validation'

module Rumale
  module Preprocessing
    # Transformer that turns every feature value into 0 or 1 by comparing it with a threshold.
    # Values strictly greater than the threshold become 1; all remaining values become 0.
    #
    # @example
    #   require 'rumale/preprocessing/binarizer'
    #
    #   binarizer = Rumale::Preprocessing::Binarizer.new
    #   x = Numo::DFloat[[-1.2, 3.2], [2.4, -0.5], [4.5, 0.8]]
    #   b = binarizer.transform(x)
    #   p b
    #
    #   # Numo::DFloat#shape=[3, 2]
    #   # [[0, 1],
    #   #  [1, 0],
    #   #  [1, 1]]
    class Binarizer < ::Rumale::Base::Estimator
      include ::Rumale::Base::Transformer

      # Build a binarizer with the given cut-off value.
      #
      # @param threshold [Float] The cut-off used by {#transform}; stored under the :threshold key of the parameters.
      def initialize(threshold: 0.0)
        super()
        @params = { threshold: threshold }
      end

      # No-op: binarization has nothing to learn from data, so the arguments are ignored.
      # The method is provided only so that the class can be used like other transformers.
      #
      # @overload fit() -> Binarizer
      #
      # @return [Binarizer] The receiver itself.
      def fit(_x = nil, _y = nil)
        self
      end

      # Convert the given samples into a 0/1 array.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be binarized.
      # @return [Numo::DFloat] The binarized samples.
      def transform(x)
        samples = ::Rumale::Validation.check_convert_sample_array(x)

        # Strict comparison: an element equal to the threshold is mapped to 0.
        above_threshold = samples.gt(@params[:threshold])
        # Cast the comparison mask back to the array class of the validated samples.
        samples.class.cast(above_threshold)
      end

      # Call fit and then transform on the same samples; the result equals that of {#transform}.
      # The method is provided only so that the class can be used like other transformers.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be binarized.
      # @return [Numo::DFloat] The binarized samples.
      def fit_transform(x, _y = nil)
        samples = ::Rumale::Validation.check_convert_sample_array(x)

        fitted = fit(samples)
        fitted.transform(samples)
      end
    end
  end
end

Version data entries

7 entries across 7 versions & 1 rubygem

Version Path
rumale-preprocessing-0.29.0 lib/rumale/preprocessing/binarizer.rb
rumale-preprocessing-0.28.1 lib/rumale/preprocessing/binarizer.rb
rumale-preprocessing-0.28.0 lib/rumale/preprocessing/binarizer.rb
rumale-preprocessing-0.27.0 lib/rumale/preprocessing/binarizer.rb
rumale-preprocessing-0.26.0 lib/rumale/preprocessing/binarizer.rb
rumale-preprocessing-0.25.0 lib/rumale/preprocessing/binarizer.rb
rumale-preprocessing-0.24.0 lib/rumale/preprocessing/binarizer.rb