lib/svmkit/preprocessing/standard_scaler.rb in svmkit-0.2.7 vs lib/svmkit/preprocessing/standard_scaler.rb in svmkit-0.2.8
- old
+ new
@@ -37,10 +37,11 @@
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
# The samples to calculate the mean values and standard deviations.
# @return [StandardScaler]
def fit(x, _y = nil)
+ SVMKit::Validation.check_sample_array(x)
@mean_vec = x.mean(0)
@std_vec = x.stddev(0)
self
end
@@ -50,17 +51,19 @@
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
# The samples to calculate the mean values and standard deviations.
# @return [Numo::DFloat] The scaled samples.
def fit_transform(x, _y = nil)
+ SVMKit::Validation.check_sample_array(x)
fit(x).transform(x)
end
# Perform standardization the given samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be scaled.
# @return [Numo::DFloat] The scaled samples.
def transform(x)
+ SVMKit::Validation.check_sample_array(x)
n_samples, = x.shape
(x - @mean_vec.tile(n_samples, 1)) / @std_vec.tile(n_samples, 1)
end
# Dump marshal data.