lib/svmkit/clustering/k_means.rb in svmkit-0.5.1 vs lib/svmkit/clustering/k_means.rb in svmkit-0.5.2
- old
+ new
@@ -7,14 +7,15 @@
module SVMKit
# This module consists of classes that implement cluster analysis methods.
module Clustering
# KMeans is a class that implements K-Means cluster analysis.
+ # The current implementation uses the Euclidean distance for analyzing the clusters.
#
# @example
# analyzer = SVMKit::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
- # cluster_ids = analyzer.fit_predict(samples)
+ # cluster_labels = analyzer.fit_predict(samples)
#
# *Reference*
# - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
class KMeans
include Base::BaseEstimator
@@ -36,10 +37,11 @@
# @param max_iter [Integer] The maximum number of iterations.
# @param tol [Float] The tolerance of termination criterion.
# @param random_seed [Integer] The seed value using to initialize the random generator.
def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
+ check_params_float(tol: tol)
check_params_string(init: init)
check_params_type_or_nil(Integer, random_seed: random_seed)
check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
@params = {}
@params[:n_clusters] = n_clusters
@@ -60,34 +62,34 @@
# @return [KMeans] The learned cluster analyzer itself.
def fit(x, _y = nil)
check_sample_array(x)
init_cluster_centers(x)
@params[:max_iter].times do |_t|
- cluster_ids = assign_cluster(x)
+ cluster_labels = assign_cluster(x)
old_centers = @cluster_centers.dup
@params[:n_clusters].times do |n|
- assigned_bits = cluster_ids.eq(n)
+ assigned_bits = cluster_labels.eq(n)
@cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count > 0
end
error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
break if error <= @params[:tol]
end
self
end
- # Predict cluster indices for samples.
+ # Predict cluster labels for samples.
#
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster index.
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster index per sample.
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
def predict(x)
check_sample_array(x)
assign_cluster(x)
end
# Analysis clusters and assign samples to clusters.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster index per sample.
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
def fit_predict(x)
check_sample_array(x)
fit(x)
predict(x)
end