lib/rumale/preprocessing/one_hot_encoder.rb in rumale-0.18.6 vs lib/rumale/preprocessing/one_hot_encoder.rb in rumale-0.18.7

- old
+ new

@@ -49,10 +49,11 @@ # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to fit one-hot-encoder. # @return [OneHotEncoder] def fit(x, _y = nil) x = Numo::Int32.cast(x) unless x.is_a?(Numo::Int32) raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any? + @n_values = x.max(0) + 1 @feature_indices = Numo::Int32.hstack([[0], @n_values]).cumsum @active_features = encode(x, @feature_indices).sum(0).ne(0).where self end @@ -65,19 +66,21 @@ # @return [Numo::DFloat] The one-hot-vectors. def fit_transform(x, _y = nil) x = Numo::Int32.cast(x) unless x.is_a?(Numo::Int32) raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any? raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any? + fit(x).transform(x) end # Encode samples into one-hot-vectors. # # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to encode into one-hot-vectors. # @return [Numo::DFloat] The one-hot-vectors. def transform(x) x = Numo::Int32.cast(x) unless x.is_a?(Numo::Int32) raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any? + codes = encode(x, @feature_indices) codes[true, @active_features].dup end private