lib/rumale/pairwise_metric.rb in rumale-0.17.3 vs lib/rumale/pairwise_metric.rb in rumale-0.18.0
- old
+ new
@@ -40,18 +40,19 @@
#
# @param x [Numo::DFloat] (shape: [n_samples_x, n_features])
# @param y [Numo::DFloat] (shape: [n_samples_y, n_features])
# @return [Numo::DFloat] (shape: [n_samples_x, n_samples_x] or [n_samples_x, n_samples_y] if y is given)
def squared_error(x, y = nil)
- y = x if y.nil?
+ y_not_given = y.nil?
+ y = x if y_not_given
x = Rumale::Validation.check_convert_sample_array(x)
- y = Rumale::Validation.check_convert_sample_array(y)
- n_features = x.shape[1]
- one_vec = Numo::DFloat.ones(n_features).expand_dims(1)
- sum_x_vec = (x**2).dot(one_vec)
- sum_y_vec = (y**2).dot(one_vec).transpose
- dot_xy_mat = x.dot(y.transpose)
- dot_xy_mat * -2.0 + sum_x_vec + sum_y_vec
+ y = Rumale::Validation.check_convert_sample_array(y) unless y_not_given
+ sum_x_vec = (x**2).sum(1).expand_dims(1)
+ sum_y_vec = y_not_given ? sum_x_vec.transpose : (y**2).sum(1).expand_dims(1).transpose
+ err_mat = -2 * x.dot(y.transpose)
+ err_mat += sum_x_vec
+ err_mat += sum_y_vec
+ err_mat.class.maximum(err_mat, 0)
end
# Calculate the rbf kernel between x and y.
#
# @param x [Numo::DFloat] (shape: [n_samples_x, n_features])