lib/svmkit/preprocessing/standard_scaler.rb in svmkit-0.2.7 vs lib/svmkit/preprocessing/standard_scaler.rb in svmkit-0.2.8

- old
+ new

@@ -37,10 +37,11 @@ # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) # The samples to calculate the mean values and standard deviations. # @return [StandardScaler] def fit(x, _y = nil) + SVMKit::Validation.check_sample_array(x) @mean_vec = x.mean(0) @std_vec = x.stddev(0) self end @@ -50,17 +51,19 @@ # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) # The samples to calculate the mean values and standard deviations. # @return [Numo::DFloat] The scaled samples. def fit_transform(x, _y = nil) + SVMKit::Validation.check_sample_array(x) fit(x).transform(x) end # Perform standardization the given samples. # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be scaled. # @return [Numo::DFloat] The scaled samples. def transform(x) + SVMKit::Validation.check_sample_array(x) n_samples, = x.shape (x - @mean_vec.tile(n_samples, 1)) / @std_vec.tile(n_samples, 1) end # Dump marshal data.