Machine Learning · Intermediate

K-Means Clustering in Python

Learn K-Means clustering in Python with scikit-learn. Visualize clusters forming, pick the right K with the elbow method, and run it all in your browser.

Try it yourself

Run this code directly in your browser. Click "Open in full editor" to experiment further.


How it works

K-Means is the most popular clustering algorithm in the world, and for good reason: the idea is simple, it's blazingly fast, and the results are easy to explain to anyone — even people who don't know what machine learning is.

What "Clustering" Even Means

Clustering is unsupervised learning — you give the algorithm a pile of points and it tries to figure out how they naturally group together, without anyone telling it what the right answer looks like. There are no labels, no "correct" outputs to learn from. Just the geometry of the data.

Use cases you've definitely encountered:

  • Customer segmentation ("these 10,000 buyers fall into 5 spending patterns")
  • Image color quantization (reducing a photo to its 16 dominant colors)
  • Document grouping ("all news articles about the same event")
  • Anomaly detection (anything that doesn't fit any cluster is suspicious)

How K-Means Actually Works

The algorithm is so simple you can explain it on a napkin:

1. Pick K — decide how many clusters you want.

2. Drop K random points somewhere in your data — these are your initial "centroids".

3. Assign every data point to its nearest centroid.

4. Move each centroid to the average position of the points assigned to it.

5. Repeat steps 3 and 4 until the centroids stop moving.

That's it. No gradients, no probabilities, no neural networks. Just "find center, assign points, move center, repeat".
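The napkin version translates almost line-for-line into NumPy. This is a toy sketch for illustration (scikit-learn's real implementation adds smarter initialization and other optimizations):

```python
import numpy as np

def kmeans(X, k, n_iters=100, seed=42):
    """Toy K-Means: the napkin algorithm, literally."""
    rng = np.random.default_rng(seed)
    # Step 2: use k random data points as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(n_iters):
        # Step 3: assign every point to its nearest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 4: move each centroid to the mean of its assigned points
        new_centroids = centroids.copy()
        for j in range(k):
            members = X[labels == j]
            if len(members):              # keep the old centroid if a cluster empties
                new_centroids[j] = members.mean(axis=0)
        # Step 5: stop once the centroids stop moving
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids
```

On two well-separated blobs this converges in a handful of iterations; in practice you'd use scikit-learn's `KMeans`, which does the same loop with better starting points.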

The Hardest Part: Picking K

K-Means won't tell you how many clusters are in your data — you have to choose. The standard trick is the elbow method:

  • Run K-Means with K = 1, 2, 3, ... 10.
  • For each one, record the inertia — the total squared distance from each point to its centroid. Lower is tighter.
  • Plot inertia vs. K. The curve always goes down (more clusters = tighter fit), but at some point the improvement flattens out. That bend is the "elbow".
  • Pick K at the elbow.

In the snippet above, the elbow lands cleanly at K=4 — which matches how the data was actually generated.

For harder cases there's also the silhouette score (sklearn.metrics.silhouette_score), which gives every choice of K a single number — pick the K with the highest score.
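Both methods fit in one loop. A sketch using fabricated blob data (the sample counts, cluster_std, and seed here are arbitrary demo choices):

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic data with 4 planted clusters (fabricated for this demo)
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.8, random_state=42)

inertias, silhouettes = {}, {}
for k in range(2, 11):                    # silhouette needs at least 2 clusters
    km = KMeans(n_clusters=k, n_init=10, random_state=42).fit(X)
    inertias[k] = km.inertia_             # total squared distance to centroids
    silhouettes[k] = silhouette_score(X, km.labels_)

# The silhouette peak gives a single-number answer for K
best_k = max(silhouettes, key=silhouettes.get)
print(best_k)
```

Plot `inertias` to eyeball the elbow; `best_k` is the silhouette method's automatic answer, which on well-separated blobs like these typically agrees with the planted cluster count.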

Things K-Means Is Bad At

It's a great default, but it has well-known weaknesses:

  • Non-spherical clusters — K-Means assumes clusters are blobs. Long, snake-like, or ring-shaped clusters confuse it. Try DBSCAN instead.
  • Wildly different cluster sizes — it tends to chop big clusters and merge small ones.
  • Different scales — if one feature ranges 0–1 and another ranges 0–10000, the second feature dominates the distance calculation. Always run StandardScaler first if your features are on different scales.
  • Outliers pull centroids around — a few extreme points can drag a centroid far from where it should be. Consider KMedoids (from the scikit-learn-extra package) if outliers are a problem.
  • Random initialization can give bad results — that's why we set n_init=10. It runs the whole thing 10 times with different random starts and keeps the best one. Always do this.

The API You Actually Need

    from sklearn.cluster import KMeans
    
    kmeans = KMeans(n_clusters=4, n_init=10, random_state=42)
    kmeans.fit(X)
    
    kmeans.labels_           # cluster index for each training point
    kmeans.cluster_centers_  # coordinates of each centroid
    kmeans.inertia_          # tightness score (lower is better)
    kmeans.predict(X_new)    # which cluster does a NEW point belong to?
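Those four attributes in action, end to end, on fabricated blob data (the sample counts and seed are arbitrary demo choices):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Fabricated training data: 4 blobs in 2-D
X, _ = make_blobs(n_samples=200, centers=4, random_state=42)

kmeans = KMeans(n_clusters=4, n_init=10, random_state=42)
kmeans.fit(X)

print(kmeans.labels_.shape)        # one cluster index per training point: (200,)
print(kmeans.cluster_centers_)     # 4 centroids, one row each
print(kmeans.inertia_)             # lower = tighter clusters

# Production-style use: route new points to existing clusters
X_new = np.array([[0.0, 0.0], [-5.0, 5.0]])
print(kmeans.predict(X_new))       # cluster index for each new point
```

Note that `predict` never moves the centroids — it just measures which existing one each new point is closest to.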

A Few Pro Tips

  • Always set random_state so your results are reproducible.
  • Always set n_init=10 (in newer scikit-learn it defaults to 'auto' but explicit is better).
  • Standardize your features before clustering: StandardScaler().fit_transform(X).
  • For very large datasets (>10k points), use MiniBatchKMeans — same idea, much faster.
  • The predict() method is what makes K-Means useful in production: train on historical data, then assign new incoming points to existing clusters in real time.

Run the snippet above and you'll see four clean clusters get discovered automatically, an elbow chart pointing at the right K, and the model assigning brand-new points to the cluster they obviously belong to.
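The scaling and MiniBatchKMeans tips combine naturally in a Pipeline. A sketch with fabricated data whose two features are deliberately on mismatched scales (the cluster count and seed are arbitrary):

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
# Fabricated data: two features on wildly different scales
X = np.column_stack([
    rng.uniform(0, 1, 1_000),        # feature A: 0-1
    rng.uniform(0, 10_000, 1_000),   # feature B: would dominate raw distances
])

# Scale first so both features count equally, then cluster
model = make_pipeline(
    StandardScaler(),
    MiniBatchKMeans(n_clusters=3, n_init=10, random_state=42),
)
model.fit(X)

# New points are scaled with the SAME fitted scaler before prediction
labels = model.predict(X)
```

The pipeline is also what you'd ship: calling `predict` on incoming points applies the training-time scaling automatically, so production data is never compared against centroids in the wrong units.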
