lib/rumale/pairwise_metric.rb in rumale-0.17.3 vs lib/rumale/pairwise_metric.rb in rumale-0.18.0

- old
+ new

@@ -40,18 +40,19 @@ # # @param x [Numo::DFloat] (shape: [n_samples_x, n_features]) # @param y [Numo::DFloat] (shape: [n_samples_y, n_features]) # @return [Numo::DFloat] (shape: [n_samples_x, n_samples_x] or [n_samples_x, n_samples_y] if y is given) def squared_error(x, y = nil) - y = x if y.nil? + y_not_given = y.nil? + y = x if y_not_given x = Rumale::Validation.check_convert_sample_array(x) - y = Rumale::Validation.check_convert_sample_array(y) - n_features = x.shape[1] - one_vec = Numo::DFloat.ones(n_features).expand_dims(1) - sum_x_vec = (x**2).dot(one_vec) - sum_y_vec = (y**2).dot(one_vec).transpose - dot_xy_mat = x.dot(y.transpose) - dot_xy_mat * -2.0 + sum_x_vec + sum_y_vec + y = Rumale::Validation.check_convert_sample_array(y) unless y_not_given + sum_x_vec = (x**2).sum(1).expand_dims(1) + sum_y_vec = y_not_given ? sum_x_vec.transpose : (y**2).sum(1).expand_dims(1).transpose + err_mat = -2 * x.dot(y.transpose) + err_mat += sum_x_vec + err_mat += sum_y_vec + err_mat.class.maximum(err_mat, 0) end # Calculate the rbf kernel between x and y. # # @param x [Numo::DFloat] (shape: [n_samples_x, n_features])