Machine Learning · Intermediate

Decision Tree Classifier in Python

Build a decision tree classifier in Python with scikit-learn. Train, visualize the actual tree, predict, and learn how to avoid overfitting — runnable in your browser.

Try it yourself

Run this code directly in your browser. Click "Open in full editor" to experiment further.


How it works

A decision tree is the most intuitive machine learning model that exists — it's literally a flowchart of yes/no questions that lead to a prediction. You can hand the trained tree to someone who's never heard of ML and they can read it.

How A Decision Tree Thinks

Given your data, the algorithm asks: "What's the single best yes/no question I can ask that splits the data into the cleanest groups?" It tries every feature at every possible threshold, picks the winner, then recurses on each side. It keeps going until either the groups are pure (all the same class), or some stopping rule kicks in.

For the iris dataset above, the very first question the tree learns is something like "is petal length less than 2.5cm?" — and that one question alone perfectly separates the setosa flowers from the other two species.
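You can check this yourself by fitting a tree and inspecting the root node directly. A minimal sketch (the exact threshold the tree picks can vary slightly with the scikit-learn version and random seed):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42)
clf.fit(iris.data, iris.target)

# Node 0 is the root; tree_.feature and tree_.threshold hold each node's split rule
root_feature = iris.feature_names[clf.tree_.feature[0]]
root_threshold = clf.tree_.threshold[0]
print(f"Root question: is {root_feature} <= {root_threshold:.2f}?")
```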

What "Best Split" Means

The tree needs a way to score how clean a split is. The two standard scores:

  • Gini impurity — "if I randomly picked a point from this group and randomly guessed its label using the group's class proportions, how often would I be wrong?" Lower is better.
  • Entropy — same idea, borrowed from information theory. Slightly slower to compute, almost always gives identical trees.

In practice, just use Gini. It's the default and it's fine.
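To make Gini impurity concrete, here is a minimal hand-rolled version (illustrative only; scikit-learn computes this internally):

```python
import numpy as np

def gini(labels):
    # Gini impurity: the chance of mislabeling a random point from this group
    # if we guess according to the group's class proportions.
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

print(gini([0, 0, 0, 0]))  # pure group        -> 0.0
print(gini([0, 0, 1, 1]))  # 50/50 split       -> 0.5
print(gini([0, 0, 0, 1]))  # mostly one class  -> 0.375
```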

The Single Most Important Knob: `max_depth`

A decision tree, left unrestrained, will keep splitting until every leaf contains exactly one training point. Training accuracy will be 100%. Test accuracy will be terrible. This is the textbook example of overfitting — the model memorized the training set instead of learning the underlying pattern.

The overfitting demo at the bottom of the snippet shows it cleanly: an unlimited tree gets 100% on training but worse test accuracy than a depth-3 tree. Always limit `max_depth`, or use one of the other regularizing parameters:

  • `max_depth` — hard cap on how many questions deep the tree can go
  • `min_samples_split` — don't split a node unless it has at least this many samples
  • `min_samples_leaf` — every leaf must have at least this many samples
  • `max_leaf_nodes` — total cap on the number of leaves
  • `ccp_alpha` — cost-complexity pruning, the principled way to trim a tree
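A minimal sketch of the overfitting comparison (the exact test numbers depend on the split, but the unrestrained tree always hits 100% on its training data):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0)

# None = no depth limit; 3 = a regularized tree
for depth in (None, 3):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
    clf.fit(X_train, y_train)
    print(f"max_depth={depth}: "
          f"train={clf.score(X_train, y_train):.2f}, "
          f"test={clf.score(X_test, y_test):.2f}")
```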

Reading The Visualized Tree

When you run `plot_tree`, each box shows:

  • The split rule at the top (e.g. `petal width <= 0.8`)
  • `gini` — impurity at this node (0 = perfectly pure)
  • `samples` — how many training points reached this node
  • `value` — the count of each class at this node
  • `class` — the majority class (this is what a leaf would predict)
  • The color — which class dominates, with intensity showing confidence

Going left means the answer was "yes"; going right means "no". Trace any flower through the tree by hand and you'll get the same prediction the model gives you.
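A minimal sketch of producing the visualization (assumes matplotlib is installed; the `Agg` backend line is only there so the script runs headless — drop it in a notebook):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove in an interactive session
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0)
clf.fit(iris.data, iris.target)

# filled=True colors each box by its majority class
fig, ax = plt.subplots(figsize=(12, 6))
plot_tree(clf, feature_names=iris.feature_names,
          class_names=list(iris.target_names),
          filled=True, ax=ax)
fig.savefig("iris_tree.png")  # or plt.show() in an interactive session
```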

Feature Importance — Free Insight

A trained tree can tell you which features actually mattered. `clf.feature_importances_` returns a number per feature (summing to 1) based on how much each feature reduced impurity across the whole tree. For iris, petal measurements dominate — sepal measurements are barely used.

This alone makes trees worth running even if you plan to deploy a different model: they're a fast way to see which features are pulling their weight.
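A sketch of pulling the importances out of a fitted tree:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# One score per feature; the scores sum to 1
for name, score in zip(iris.feature_names, clf.feature_importances_):
    print(f"{name:<20} {score:.3f}")
```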

Strengths

  • Interpretable — you can show the actual model to a non-technical stakeholder.
  • No scaling needed — splits are based on "is feature greater than threshold", which doesn't care if your data is in dollars or millimeters.
  • Handles mixed data — numeric features work natively; in scikit-learn, categorical features just need encoding first (e.g. one-hot).
  • Captures non-linear patterns without you having to do anything special.

Weaknesses

  • High variance — small changes in training data can produce a very different tree. This is exactly why random forests and gradient boosting exist: they average many trees together to smooth out the variance.
  • Axis-aligned splits only — every split is on one feature at a time. Diagonal decision boundaries are approximated by staircases.
  • Easily overfits — if you forget `max_depth`, you'll get a model that's perfect on training and useless in production.

When To Use A Single Tree vs. A Forest

Use a single decision tree when interpretability matters more than the last few percent of accuracy — medical decisions, regulatory contexts, or just for explaining the model to your team.

Use a Random Forest or Gradient Boosting when you just want the best accuracy possible. They're built on top of the exact same tree algorithm — they just average lots of them.
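As a quick sanity check, the two can be compared with cross-validation. A sketch (on an easy dataset like iris the gap may be small; it widens on harder problems):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# 5-fold cross-validated accuracy for each model
tree_acc = cross_val_score(
    DecisionTreeClassifier(max_depth=3, random_state=0), X, y, cv=5).mean()
forest_acc = cross_val_score(
    RandomForestClassifier(n_estimators=100, random_state=0), X, y, cv=5).mean()

print(f"single tree  : {tree_acc:.3f}")
print(f"random forest: {forest_acc:.3f}")
```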

Run the snippet above and you'll see a real flowchart of how the model decides what species a flower is, a feature importance chart that confirms petal size is the giveaway, and a side-by-side of how train/test accuracy diverges as you let the tree grow deeper.
