lib/rumale/pairwise_metric.rb in rumale-0.18.2 vs lib/rumale/pairwise_metric.rb in rumale-0.18.3

- old
+ new

@@ -59,15 +59,16 @@ # @param x [Numo::DFloat] (shape: [n_samples_x, n_features]) # @param y [Numo::DFloat] (shape: [n_samples_y, n_features]) # @param gamma [Float] The parameter of rbf kernel, if nil it is 1 / n_features. # @return [Numo::DFloat] (shape: [n_samples_x, n_samples_x] or [n_samples_x, n_samples_y] if y is given) def rbf_kernel(x, y = nil, gamma = nil) - y = x if y.nil? - gamma ||= 1.0 / x.shape[1] + y_not_given = y.nil? + y = x if y_not_given x = Rumale::Validation.check_convert_sample_array(x) - y = Rumale::Validation.check_convert_sample_array(y) + y = Rumale::Validation.check_convert_sample_array(y) unless y_not_given + gamma ||= 1.0 / x.shape[1] Rumale::Validation.check_params_numeric(gamma: gamma) - Numo::NMath.exp(-gamma * squared_error(x, y).abs) + Numo::NMath.exp(-gamma * squared_error(x, y)) end # Calculate the linear kernel between x and y. # # @param x [Numo::DFloat] (shape: [n_samples_x, n_features])