K-Means clustering is a widely used algorithm for partitioning a dataset into a pre-determined number of distinct, non-overlapping subgroups, or clusters. It's an iterative algorithm that aims to find local optima by minimizing the sum of squared distances from each data point to the centroid, or arithmetic mean, of its assigned cluster. This approach is effective for identifying underlying group structures in unlabeled data, which is a common task in unsupervised learning.
The K-Means algorithm operates through a straightforward iterative process to assign each data point to one of k clusters. The core idea is to refine cluster centroids and point assignments until a stable configuration is reached.
The main steps are:
Initialization: Choose k initial centroids, for example by picking k data points at random or by using a smarter seeding scheme such as k-means++.
Assignment: Assign each data point to the cluster whose centroid is closest, typically measured by squared Euclidean distance.
Update: Recompute each centroid as the mean of the points currently assigned to its cluster.
Repeat: Alternate the assignment and update steps until the assignments stop changing, the centroids move less than a tolerance, or a maximum number of iterations is reached.
Diagram: the K-Means algorithm iteratively assigns data points to clusters and updates cluster centroids until convergence.
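To make the assignment and update steps concrete, here is a minimal from-scratch sketch of the loop. It is illustrative only (naive random initialization, empty clusters simply keep their old centroid) and is not the library implementation used later in this section; the function name simple_kmeans is our own.

# A minimal, illustrative K-Means loop (not optimized; for exposition only)
function simple_kmeans(X::AbstractMatrix, k::Int; maxiter::Int=100)
    n = size(X, 2)                               # X: features as rows, observations as columns
    centers = float(X[:, rand(1:n, k)])          # naive initialization: k random data points
    assignments = zeros(Int, n)
    for _ in 1:maxiter
        # Assignment step: each point joins the cluster with the nearest centroid
        for i in 1:n
            assignments[i] = argmin([sum(abs2, X[:, i] .- centers[:, j]) for j in 1:k])
        end
        # Update step: each centroid becomes the mean of its assigned points
        new_centers = copy(centers)
        for j in 1:k
            members = findall(==(j), assignments)
            isempty(members) || (new_centers[:, j] = vec(sum(X[:, members], dims=2)) ./ length(members))
        end
        new_centers == centers && break          # converged: centroids stopped moving
        centers = new_centers
    end
    return centers, assignments
end
# Example: centers, labels = simple_kmeans(randn(2, 100), 3)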
The objective function that K-Means attempts to minimize is the Within-Cluster Sum of Squares (WCSS), also known as inertia:

$$\text{WCSS} = \sum_{j=1}^{k} \sum_{x_i \in S_j} \lVert x_i - c_j \rVert^2$$

where $k$ is the number of clusters, $S_j$ is the set of points in cluster $j$, and $c_j$ is the centroid of cluster $j$.
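To make the objective concrete, here is a small sketch that evaluates WCSS directly from the definition; the points, assignments, and centroids below are made up purely for illustration.

# Computing WCSS directly from the definition (toy values for illustration)
points = [1.0 2.0 8.0 9.0;          # 2 x 4 matrix: each column is a 2D point
          1.0 1.5 8.0 8.5]
assignments = [1, 1, 2, 2]           # cluster index for each point
centroids = [1.5 8.5;                # 2 x 2 matrix: each column is a cluster centroid
             1.25 8.25]
# Sum the squared distance of every point to the centroid of its assigned cluster
wcss = sum(sum(abs2, points[:, i] .- centroids[:, assignments[i]]) for i in 1:size(points, 2))
println("WCSS: ", wcss)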
In Julia, the Clustering.jl package provides an efficient implementation of K-Means and other clustering algorithms. If you haven't installed it yet, you can add it using Julia's package manager:
using Pkg
Pkg.add("Clustering")
Pkg.add("Plots") # For visualization
Pkg.add("Random") # For generating sample data
Let's walk through a basic example. First, we'll generate some synthetic 2D data that has a natural grouping.
using Clustering, Plots, Random
# Set a seed for reproducibility
Random.seed!(1234)
# Generate synthetic data with three distinct groups
# Group 1
X1 = randn(2, 50) .* 0.5 .+ [2.0; 2.0]
# Group 2
X2 = randn(2, 50) .* 0.5 .+ [4.0; 4.0]
# Group 3
X3 = randn(2, 50) .* 0.5 .+ [3.0; 0.0]
# Combine the groups into a single dataset
# X_combined is 2 x 150: features as rows, observations as columns,
# which is the layout Clustering.jl's kmeans expects
X_combined = hcat(X1, X2, X3)
data_points = permutedims(X_combined) # Transpose so observations are rows (handy for plotting later)
# Perform K-Means clustering
# We specify k=3 because we know our synthetic data has three groups
k = 3
result = kmeans(data_points', k; display=:iter) # data_points' puts observations back into columns for kmeans
# Accessing the results
assignments = result.assignments # Cluster assignment for each point
centroids = result.centers # Coordinates of the cluster centroids (features as rows)
total_cost = result.totalcost # WCSS for the final clustering
iterations = result.iterations # Number of iterations performed
converged = result.converged # Boolean indicating if the algorithm converged
println("Number of points: ", size(data_points, 1))
println("Cluster assignments (first 10): ", assignments[1:10])
println("Centroids:\n", permutedims(centroids)) # Display centroids with features as columns
println("Total WCSS: ", total_cost)
println("Iterations: ", iterations)
println("Converged: ", converged)
In the kmeans function, data_points' means we provide the data with observations as columns and features as rows, which is the layout Clustering.jl expects. The display=:iter option shows the progress of the algorithm. The result object is a KmeansResult struct containing detailed information about the clustering outcome.
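Beyond the fields printed above, the KmeansResult struct also exposes per-cluster counts and per-point costs, which can be useful for diagnostics. Continuing with the result object from the run above:

# Additional fields of the KmeansResult struct
println("Points per cluster: ", result.counts)      # number of observations assigned to each cluster
println("Cost of first point: ", result.costs[1])   # squared distance from point 1 to its centroid
# The per-point costs sum to the reported totalcost (WCSS)
println(isapprox(sum(result.costs), result.totalcost))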
One of the main challenges with K-Means is that you need to specify the number of clusters, k, beforehand. In many practical scenarios, the optimal k is unknown. Several methods can help guide this decision:
Elbow Method: Run K-Means for a range of k values, record the WCSS for each, and look for the "elbow" where adding more clusters stops producing large reductions. Let's calculate WCSS for different values of k using our synthetic data:
# data_points' is observations as columns, features as rows
data_for_kmeans = data_points' # from previous example
max_k = 10
wcss_values = Float64[]
for k_val in 2:max_k  # kmeans in Clustering.jl requires k >= 2
    r = kmeans(data_for_kmeans, k_val)
    push!(wcss_values, r.totalcost)
end
# Plot WCSS against k with Plots.jl to look for the elbow:
# plot(2:max_k, wcss_values, xlabel="Number of Clusters (k)", ylabel="WCSS", marker=:o, legend=false)
The WCSS typically decreases as k increases. The "elbow" point (around k=3 in this example) suggests a suitable number of clusters where adding more clusters yields diminishing returns.
Silhouette Analysis: This method measures how similar a data point is to its own cluster compared to other clusters. The silhouette score ranges from -1 to 1, and higher values indicate well-separated clusters. We will cover evaluation metrics like the silhouette score in more detail later in this chapter; a short preview sketch appears after this list.
Domain Knowledge: Often, prior knowledge about the data or the problem context can suggest a natural number of clusters.
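As a quick preview of the silhouette idea, the sketch below assumes the Distances.jl package is also installed and uses the silhouettes function from Clustering.jl on the k=3 result from earlier:

using Distances, Statistics
# Pairwise Euclidean distances between observations (columns of the d x n matrix)
dists = pairwise(Euclidean(), data_for_kmeans, dims=2)
sil = silhouettes(result, dists)                 # one score per point, in [-1, 1]
println("Mean silhouette score: ", mean(sil))    # closer to 1 suggests well-separated clusters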
Clustering.jl's kmeans function uses the "k-means++" initialization strategy (init=:kmpp) by default. This method generally leads to better and more consistent results than purely random initialization. You can control initialization through the init keyword argument, for example init=:rand for random seeding or init=:kmcen for centrality-based seeding, or by supplying your own initial centroids. Because K-Means only converges to a local optimum, it is also common to run the algorithm several times from different initializations and keep the solution with the lowest WCSS; the k-means++ default often makes this less necessary.
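As a sketch of these options (assuming the init and maxiter keyword arguments of kmeans and reusing data_for_kmeans from the elbow example):

# Explicitly choosing a seeding strategy
result_kmpp = kmeans(data_for_kmeans, 3; init=:kmpp)               # k-means++ (the default)
result_rand = kmeans(data_for_kmeans, 3; init=:rand, maxiter=200)  # purely random seeding
# A simple multi-restart strategy: keep the run with the lowest WCSS
runs = [kmeans(data_for_kmeans, 3) for _ in 1:10]
best = runs[argmin([r.totalcost for r in runs])]
println("Best WCSS over 10 runs: ", best.totalcost)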
Visualizing the clusters can provide valuable insights, especially for 2D or 3D data. Let's plot our 3-cluster result from the earlier example using Plots.jl.
# Using results from the k=3 example:
# data_points (observations as rows, features as columns)
# assignments (cluster id for each point)
# centroids (features as rows, clusters as columns)
# Prepare data for plotting
x_coords = data_points[:, 1]
y_coords = data_points[:, 2]
# Centroids (needs to be transposed for plotting if features are rows)
# result.centers has features as rows, clusters as columns.
# So, centroids_plot is clusters as rows, features (coords) as columns.
centroids_plot = permutedims(result.centers)
# Create a scatter plot
p = scatter(x_coords, y_coords, group=assignments,
xlabel="Feature 1", ylabel="Feature 2",
title="K-Means Clustering (k=3)",
legend=:outertopright, palette=:viridis) # :viridis is just one option
# Add centroids to the plot
scatter!(p, centroids_plot[:, 1], centroids_plot[:, 2],
markershape=:xcross, markersize=8, markercolor=:red,
label="Centroids", seriesalpha=1)
# To display the plot in a typical Julia environment:
# display(p)
# If using a notebook or environment that shows plots automatically, this might not be needed.
Example of K-Means clustering on 2D data. Points are colored by their assigned cluster, and cluster centroids are marked with red crosses. The Plots.jl library with a backend like GR or PlotlyJS can generate such visualizations.
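If you want to keep the figure, Plots.jl can also write it to a file; the filename here is arbitrary:

savefig(p, "kmeans_clusters.png")  # the output format is inferred from the file extension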
K-Means is popular for good reasons, but it's important to be aware of its characteristics:
Strengths:
Simplicity: The algorithm is easy to understand, implement, and explain.
Efficiency: It scales well to large datasets; each iteration is roughly linear in the number of points, features, and clusters.
Guaranteed convergence: The WCSS never increases between iterations, so the algorithm always settles into a (local) optimum.
Limitations:
Choice of k: The number of clusters must be specified in advance.
Sensitivity to initialization: Different starting centroids can lead to different local optima; k-means++ seeding and multiple restarts mitigate this.
Cluster shape assumptions: It favors compact, roughly spherical clusters of similar size and struggles with elongated or irregularly shaped groups.
Sensitivity to outliers and scale: Because it minimizes squared Euclidean distance to the mean, outliers and unscaled features can distort the centroids.
Despite its limitations, K-Means is often a good starting point for clustering tasks due to its simplicity and efficiency. Understanding its behavior and assumptions is important for applying it effectively and interpreting its results. Later in this chapter, we will explore DBSCAN, a density-based clustering algorithm that can address some of these limitations, such as discovering clusters of arbitrary shapes and not requiring k to be specified upfront.