Implementing a Decision Tree Classifier from Scratch in Python

Lori Schlatter
5 min read · Jul 31, 2020
Photo by Luke Richardson on Unsplash: a single barren tree in the distance under a blue sky

Decision trees and random forests are among the first types of models a budding data scientist may learn how to implement. They’re such a useful tool in part because of how interpretable their results are, but to take this understanding to a deeper level, I chose to implement a decision tree classifier ‘by hand’ in Python for a project. The Python code is located in this GitHub repo.

First, a brief explanation of decision trees in case you’re not familiar with them. The easiest way to think of a decision tree is to think of a flow chart. One question is asked, and then based on a “yes” or “no” answer, a different question is asked, and so on. Decision trees use this method to sort and find patterns in a given dataset in an effort to learn how to classify individual observations. Then, after learning which question should be asked at which point (or node), a tree can use this pattern to make predictions on new data. See below for a simple example of the structure of a decision tree.

Simple example of the structure of a decision tree.

In order to start this implementation, I looked at scikit-learn’s source code for its DecisionTreeClassifier, which is a class built on top of both the ClassifierMixin and BaseDecisionTree classes. In the spirit of standing on the shoulders of giants, I also looked around for other implementations and found a Machine Learning Mastery article by Jason Brownlee, which ultimately helped guide me through a baseline implementation to adapt.

The main questions I learned need to be asked to implement a decision tree are:

  • What is the best feature to split the data on?
  • Within that feature, what is the best split point?
  • And repeat…

As other developers might guess, the “repeat” step in that last bullet makes a decision tree a great candidate for recursion. But how do we decide which features to split on?

The key metric or criterion used in most decision trees (and hard coded into my implementation) to determine which split points are “best” is called Gini impurity. An important piece of the project for me was gaining a better understanding of how this metric works.

Essentially, Gini impurity is a number (a float between 0 and 1) that gives the probability of an incorrect classification, so a lower number is better. Think about it this way: if we build a decision tree such that each of the lowermost leaves/nodes contains only one class, then the tree will always correctly predict the class of any observation that lands in that leaf. Here is an example calculation for Gini impurity:

Gini impurity of a given leaf, in pseudocode, given two classes:

    probability(class) = # of class in leaf / total # of observations in leaf
    gini = 1 - (probability(class1))^2 - (probability(class2))^2

Then, to weight the score appropriately based on the number of observations at the given split point:

    weighted_gini = gini * (observations in given split / total observations at node)
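For illustration, here is a minimal Python sketch of that weighted calculation, assuming the data layout used throughout this project (each row is a list with its class label as the final element); the function and variable names are placeholders rather than the exact code in the repo.

    def gini_impurity(groups, classes):
        # groups: the left and right lists of rows produced by a candidate split
        total = sum(len(group) for group in groups)  # total observations at the node
        weighted = 0.0
        for group in groups:
            if len(group) == 0:
                continue  # skip empty groups to avoid dividing by zero
            score = 1.0
            for class_value in classes:
                proportion = [row[-1] for row in group].count(class_value) / len(group)
                score -= proportion ** 2  # gini = 1 - sum of squared class probabilities
            weighted += score * (len(group) / total)  # weight by relative group size
        return weighted

A pure group (all rows in one class) contributes 0, while a 50/50 two-class group contributes 0.5 of its weight, which is why lower scores mark better splits.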

Of course, having leaves with only one classification each isn’t always the ideal outcome, because we want to be careful not to overfit our tree to the training data and learn only how to classify those specific examples. We want the tree to also be able to generalize to new data. This is also why the max_depth and min_samples_leaf parameters are included in this implementation. max_depth is a key parameter for decision trees: it stops the tree from growing beyond a given depth of nodes, which can lead to overfitting. min_samples_leaf prevents a leaf from becoming a parent node and splitting again if its current number of observations is below that threshold. Both of these parameters come into play in the split method of this implementation.
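To make those stopping conditions concrete, here is a simplified sketch of how they might appear inside a recursive split method. The dictionary-based node structure is an assumption for illustration, and the to_leaf and evaluate_split helpers it calls are described in the method list below; the actual code in the repo may differ.

    def split(node, max_depth, min_samples_leaf, depth):
        # node is assumed to be a dict whose 'groups' entry holds the
        # left and right lists of rows produced by the best split so far
        left, right = node['groups']
        del node['groups']
        if not left or not right:
            # the split put every row on one side, so make a single leaf
            node['left'] = node['right'] = to_leaf(left + right)
            return
        if depth >= max_depth:
            # cap the depth of the tree to help avoid overfitting
            node['left'], node['right'] = to_leaf(left), to_leaf(right)
            return
        for side, group in (('left', left), ('right', right)):
            if len(group) < min_samples_leaf:
                # too few observations to justify splitting again
                node[side] = to_leaf(group)
            else:
                node[side] = evaluate_split(group)  # find the next best split point
                split(node[side], max_depth, min_samples_leaf, depth + 1)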

So how does this all come together?

The methods I implemented within the class are described as follows:

  • gini_impurity: calculates the Gini impurity of a given potential data split
  • to_leaf: a function to change a node into a leaf (i.e. no further splitting)
  • find_split: sorts data at a given split point based on a proposed split value. Returns the left and right node values.
  • evaluate_split: uses the gini_impurity function and the find_split function to calculate the best split point, by evaluating every possible one. Returns a dictionary with the best values.
  • split: takes in a node and uses to_leaf and recursion to continually split data both left and right of the initial node.
  • fit: this is the first method that is directly used. It takes in a dataset as a list of lists of numeric features (with a numeric or string classification positioned as the final value of each row), and uses the evaluate_split and split functions to fit and create the tree. It returns the root node of the tree, which is linked to all the subsequent child nodes (left and right).
  • predict: this function takes in the trained root node and an individual row, and uses recursion to follow the tree until the correct leaf is found, predicting where that row would land. It must be run iteratively (once per row) to make multiple predictions.
  • accuracy: this function uses the predict function and a by-hand calculation to find the accuracy of predictions against the actual values provided. A hypothetical usage sketch follows this list.
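To show how these pieces fit together, here is a hypothetical end-to-end usage sketch; the class name, constructor parameters, and exact method signatures are assumptions based on the descriptions above, so the real code in the repo may differ slightly.

    # a tiny two-feature dataset, with the class label as the final value of each row
    train = [[2.7, 1.0, 0], [1.3, 0.8, 0], [3.1, 2.9, 1], [3.8, 3.2, 1]]
    test = [[2.9, 1.1, 0], [3.5, 3.0, 1]]

    tree = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
    root = tree.fit(train)  # fit returns the root node of the trained tree

    # predict handles one row at a time, so loop over the test rows
    predictions = [tree.predict(root, row) for row in test]

    # the accuracy method is described as doing this same by-hand comparison
    actual = [row[-1] for row in test]
    print(sum(p == a for p, a in zip(predictions, actual)) / len(actual))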

To see the actual code, documentation, and a step-by-step example implementation, check out my GitHub repo here.

When compared to the sklearn implementation on the classic UCI Iris Dataset, this classifier actually performed very similarly: both achieved an average test accuracy of around 95%, with max_depth=5 and min_samples_split=4.
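For reference, the scikit-learn side of that kind of comparison might look like the sketch below; the specific train/test split settings here are illustrative choices rather than the exact ones behind the numbers above.

    from sklearn.datasets import load_iris
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=42
    )

    clf = DecisionTreeClassifier(max_depth=5, min_samples_split=4)
    clf.fit(X_train, y_train)
    print(accuracy_score(y_test, clf.predict(X_test)))  # typically lands around 0.95 on Iris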

Essentially, over the week I worked on this exercise, I learned how to better read source code, found another implementation to work from, and built it into my own custom Decision Tree Classifier class. I adapted the code so that it was object-oriented and the parameters could be passed when instantiating the class rather than to its methods, added the accuracy method, added documentation, and ran the classifier on an actual dataset in order to compare it to sklearn’s implementation.

To further build out the implementation, the next things I would add would be:

  • adaptation for different types of data
  • adaptation for a DecisionTreeRegressor
  • option for additional parameters to match sklearn implementation

And I know that next time I type from sklearn.tree import DecisionTreeClassifier it will be with a newfound appreciation for its complexity.
