lib/rumale/clustering/single_linkage.rb in rumale-0.18.6 vs lib/rumale/clustering/single_linkage.rb in rumale-0.18.7
- old
+ new
@@ -52,10 +52,11 @@
# If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
# @return [SingleLinkage] The learned cluster analyzer itself.
def fit(x, _y = nil)
x = check_convert_sample_array(x)
raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
+
fit_predict(x)
self
end
# Analysis clusters and assign samples to clusters.
@@ -64,9 +65,10 @@
# If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
def fit_predict(x)
x = check_convert_sample_array(x)
raise ArgumentError, 'Expect the input distance matrix to be square.' if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
+
distance_mat = @params[:metric] == 'precomputed' ? x : Rumale::PairwiseMetric.euclidean_distance(x)
@labels = partial_fit(distance_mat)
end
private