lib/svmkit/clustering/k_means.rb in svmkit-0.5.1 vs lib/svmkit/clustering/k_means.rb in svmkit-0.5.2

- old
+ new

@@ -7,14 +7,15 @@ module SVMKit # This module consists of classes that implement cluster analysis methods. module Clustering # KMeans is a class that implements K-Means cluster analysis. + # The current implementation uses the Euclidean distance for analyzing the clusters. # # @example # analyzer = SVMKit::Clustering::KMeans.new(n_clusters: 10, max_iter: 50) - # cluster_ids = analyzer.fit_predict(samples) + # cluster_labels = analyzer.fit_predict(samples) # # *Reference* # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007. class KMeans include Base::BaseEstimator @@ -36,10 +37,11 @@ # @param max_iter [Integer] The maximum number of iterations. # @param tol [Float] The tolerance of termination criterion. # @param random_seed [Integer] The seed value using to initialize the random generator. def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil) check_params_integer(n_clusters: n_clusters, max_iter: max_iter) + check_params_float(tol: tol) check_params_string(init: init) check_params_type_or_nil(Integer, random_seed: random_seed) check_params_positive(n_clusters: n_clusters, max_iter: max_iter) @params = {} @params[:n_clusters] = n_clusters @@ -60,34 +62,34 @@ # @return [KMeans] The learned cluster analyzer itself. def fit(x, _y = nil) check_sample_array(x) init_cluster_centers(x) @params[:max_iter].times do |_t| - cluster_ids = assign_cluster(x) + cluster_labels = assign_cluster(x) old_centers = @cluster_centers.dup @params[:n_clusters].times do |n| - assigned_bits = cluster_ids.eq(n) + assigned_bits = cluster_labels.eq(n) @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count > 0 end error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean break if error <= @params[:tol] end self end - # Predict cluster indices for samples. + # Predict cluster labels for samples. # - # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster index. - # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster index per sample. + # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label. + # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample. def predict(x) check_sample_array(x) assign_cluster(x) end # Analysis clusters and assign samples to clusters. # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis. - # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster index per sample. + # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample. def fit_predict(x) check_sample_array(x) fit(x) predict(x) end