Decision Tree Classifier in Python
Build a decision tree classifier in Python with scikit-learn. Train, visualize the actual tree, predict, and learn how to avoid overfitting — runnable in your browser.
How it works
A decision tree is the most intuitive machine learning model that exists — it's literally a flowchart of yes/no questions that lead to a prediction. You can hand the trained tree to someone who's never heard of ML and they can read it.
How A Decision Tree Thinks
Given your data, the algorithm asks: "What's the single best yes/no question I can ask that splits the data into the cleanest groups?" It tries every feature at every possible threshold, picks the winner, then recurses on each side. It keeps going until either the groups are pure (all the same class), or some stopping rule kicks in.
For the iris dataset above, the very first question the tree learns is something like "is petal length less than 2.5cm?" — and that one question alone perfectly separates the setosa flowers from the other two species.
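You can inspect that first question directly. A minimal sketch, assuming the standard iris dataset (the fitted tree exposes its splits through the low-level `tree_` attribute; feature index 2 is petal length, 3 is petal width):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(iris.data, iris.target)

# Node 0 is the root: which feature and threshold did the tree pick first?
root_feature = iris.feature_names[clf.tree_.feature[0]]
root_threshold = clf.tree_.threshold[0]
print(f"Root split: {root_feature} <= {root_threshold:.2f}")
```

The exact threshold depends on the data, but the root split lands on a petal measurement because that one question alone isolates setosa.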
What "Best Split" Means
The tree needs a way to score how clean a split is. The two standard scores:

- **Gini impurity** (`criterion="gini"`, the default): the probability that a sample drawn at random from the node would be mislabeled if you labeled it according to the node's class distribution. Zero means the node is pure.
- **Entropy** (`criterion="entropy"`): the information-theoretic measure of disorder; splits are chosen to maximize information gain.
In practice, just use Gini. It's the default and it's fine.
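As a quick sketch of what "clean" means numerically, here is Gini impurity computed by hand: 1 minus the sum of squared class proportions, so a pure node scores 0 and a 50/50 node scores 0.5. The helper `gini` below is just for illustration, not part of scikit-learn's API:

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - float(np.sum(p ** 2))

print(gini([0, 0, 0, 0]))  # 0.0  (pure node)
print(gini([0, 0, 1, 1]))  # 0.5  (maximally mixed two-class node)
```

The tree's job at each node is to find the split whose children have the lowest (sample-weighted) impurity.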
The Single Most Important Knob: `max_depth`
A decision tree, left unrestrained, will keep splitting until every leaf contains exactly one training point. Training accuracy will be 100%. Test accuracy will be terrible. This is the textbook example of overfitting — the model memorized the training set instead of learning the underlying pattern.
The overfitting demo at the bottom of the snippet shows it cleanly: an unlimited tree gets 100% on training but worse test accuracy than a depth-3 tree. Always limit `max_depth`, or use one of the other regularizing parameters:
| Parameter | What it does |
|---|---|
| `max_depth` | Hard cap on how many questions deep the tree can go |
| `min_samples_split` | Don't split a node unless it has at least this many samples |
| `min_samples_leaf` | Every leaf must have at least this many samples |
| `max_leaf_nodes` | Total cap on leaf count |
| `ccp_alpha` | Cost-complexity pruning, the principled way |
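The overfitting demo boils down to a few lines. A minimal sketch, assuming a 70/30 train/test split on iris (the exact accuracies vary with the split, but the pattern holds):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=42)

scores = {}
for depth in (None, 3):  # None = grow until every leaf is pure
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42).fit(X_tr, y_tr)
    scores[depth] = (clf.score(X_tr, y_tr), clf.score(X_te, y_te))
    print(f"max_depth={depth}: train={scores[depth][0]:.2f}, "
          f"test={scores[depth][1]:.2f}")
```

The unlimited tree hits 100% on training by construction; the depth-3 tree trades a little training accuracy for a model that actually generalizes.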
Reading The Visualized Tree
When you run `plot_tree`, each box shows:

- the split rule (e.g. `petal width <= 0.8`)
- `gini`: the impurity at that node
- `samples`: how many training rows reached the node
- `value`: the per-class counts among those samples
- `class`: the majority class at that node

Going left means the answer was "yes", going right means "no". Trace any flower through the tree by hand and you'll get the same prediction the model gives you.
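If you're not in a plotting environment, `export_text` prints the same flowchart as indented text. A minimal sketch on iris:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)

# Same tree plot_tree would draw, rendered as indented if/else rules.
report = export_text(clf, feature_names=iris.feature_names)
print(report)
```

Each level of indentation is one question deeper; leaves show the predicted class index.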
Feature Importance — Free Insight
A trained tree can tell you which features actually mattered. `clf.feature_importances_` returns a number per feature (summing to 1) based on how much each feature reduced impurity across the whole tree. For iris, petal measurements dominate — sepal measurements are barely used.
This alone makes trees worth running even if you plan to deploy a different model: they're a fast way to see which features are pulling their weight.
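Ranking the features takes three lines once the tree is fit. A minimal sketch on iris:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

# Importances sum to 1; sort descending to see which features did the work.
ranked = sorted(zip(iris.feature_names, clf.feature_importances_),
                key=lambda pair: -pair[1])
for name, score in ranked:
    print(f"{name}: {score:.3f}")
```

On iris the two petal features take nearly all of the importance, which matches what the visualized tree shows: sepal measurements rarely get asked about.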
Strengths

- Fully interpretable — the trained model is a flowchart anyone can read
- No feature scaling or normalization needed; splits only compare against thresholds
- Handles nonlinear decision boundaries and mixed feature types
- Fast to train and fast to predict

Weaknesses

- Unstable — small changes in the training data can produce a very different tree
- Splits are axis-aligned, so diagonal boundaries need many stacked splits
- Usually less accurate than an ensemble of trees
- Overfits by default — forget to set `max_depth`, and you'll get a model that's perfect on training and useless in production

When To Use A Single Tree vs. A Forest
Use a single decision tree when interpretability matters more than the last few percent of accuracy — medical decisions, regulatory contexts, or just for explaining the model to your team.
Use a Random Forest or Gradient Boosting when you just want the best accuracy possible. They're built on top of the exact same tree algorithm — they just average lots of them.
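As a sketch of that trade-off, here is a single depth-limited tree next to a 100-tree Random Forest on iris, scored with 5-fold cross-validation (the exact numbers vary; iris is easy enough that the gap is small):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
forest = RandomForestClassifier(n_estimators=100, random_state=42)

tree_acc = cross_val_score(tree, X, y, cv=5).mean()
forest_acc = cross_val_score(forest, X, y, cv=5).mean()
print(f"single tree:   {tree_acc:.3f}")
print(f"random forest: {forest_acc:.3f}")
```

The forest is harder to interpret — you can't draw 100 trees as one flowchart — which is exactly the trade the previous paragraphs describe.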
Run the snippet above and you'll see a real flowchart of how the model decides what species a flower is, a feature importance chart that confirms petal size is the giveaway, and a side-by-side of how train/test accuracy diverges as you let the tree grow deeper.
Related examples
Logistic Regression in Python
Learn logistic regression in Python with scikit-learn. Binary classification, decision boundary, probabilities, and ROC curve — all explained and runnable in your browser.
Confusion Matrix & Classification Metrics in Python
Understand the confusion matrix in Python with scikit-learn. Precision, recall, F1, and accuracy on a plotted heatmap — runnable in your browser, no setup.
K-Means Clustering in Python
Learn K-Means clustering in Python with scikit-learn. Visualize clusters forming, pick the right K with the elbow method, and run it all in your browser.