lib/svmkit/preprocessing/l2_normalizer.rb in svmkit-0.2.7 vs lib/svmkit/preprocessing/l2_normalizer.rb in svmkit-0.2.8

- old
+ new

@@ -30,10 +30,11 @@ # @overload fit(x) -> L2Normalizer # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate L2-norms. # @return [L2Normalizer] def fit(x, _y = nil) + SVMKit::Validation.check_sample_array(x) @norm_vec = Numo::NMath.sqrt((x**2).sum(1)) self end # Calculate L2-norms of each sample, and then normalize samples to unit L2-norm. @@ -41,9 +42,10 @@ # @overload fit_transform(x) -> Numo::DFloat # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate L2-norms. # @return [Numo::DFloat] The normalized samples. def fit_transform(x, _y = nil) + SVMKit::Validation.check_sample_array(x) fit(x) x / @norm_vec.tile(x.shape[1], 1).transpose end end end