lib/rumale/pairwise_metric.rb in rumale-0.18.2 vs lib/rumale/pairwise_metric.rb in rumale-0.18.3
- old
+ new
@@ -59,15 +59,16 @@
# @param x [Numo::DFloat] (shape: [n_samples_x, n_features])
# @param y [Numo::DFloat] (shape: [n_samples_y, n_features])
# @param gamma [Float] The parameter of rbf kernel, if nil it is 1 / n_features.
# @return [Numo::DFloat] (shape: [n_samples_x, n_samples_x] or [n_samples_x, n_samples_y] if y is given)
def rbf_kernel(x, y = nil, gamma = nil)
- y = x if y.nil?
- gamma ||= 1.0 / x.shape[1]
+ y_not_given = y.nil?
+ y = x if y_not_given
x = Rumale::Validation.check_convert_sample_array(x)
- y = Rumale::Validation.check_convert_sample_array(y)
+ y = Rumale::Validation.check_convert_sample_array(y) unless y_not_given
+ gamma ||= 1.0 / x.shape[1]
Rumale::Validation.check_params_numeric(gamma: gamma)
- Numo::NMath.exp(-gamma * squared_error(x, y).abs)
+ Numo::NMath.exp(-gamma * squared_error(x, y))
end
# Calculate the linear kernel between x and y.
#
# @param x [Numo::DFloat] (shape: [n_samples_x, n_features])