lib/rumale/preprocessing/one_hot_encoder.rb in rumale-0.18.6 vs lib/rumale/preprocessing/one_hot_encoder.rb in rumale-0.18.7
- old
+ new
@@ -49,10 +49,11 @@
# @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to fit one-hot-encoder.
# @return [OneHotEncoder]
def fit(x, _y = nil)
x = Numo::Int32.cast(x) unless x.is_a?(Numo::Int32)
raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any?
+
@n_values = x.max(0) + 1
@feature_indices = Numo::Int32.hstack([[0], @n_values]).cumsum
@active_features = encode(x, @feature_indices).sum(0).ne(0).where
self
end
@@ -65,19 +66,21 @@
# @return [Numo::DFloat] The one-hot-vectors.
def fit_transform(x, _y = nil)
x = Numo::Int32.cast(x) unless x.is_a?(Numo::Int32)
raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any?
raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any?
+
fit(x).transform(x)
end
# Encode samples into one-hot-vectors.
#
# @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to encode into one-hot-vectors.
# @return [Numo::DFloat] The one-hot-vectors.
def transform(x)
x = Numo::Int32.cast(x) unless x.is_a?(Numo::Int32)
raise ArgumentError, 'Expected the input samples only consists of non-negative integer values.' if x.lt(0).any?
+
codes = encode(x, @feature_indices)
codes[true, @active_features].dup
end
private