– Dad, you're always on that computer. Maybe we should go for a walk as a family? – Jagódka looked out the window. – Or maybe not. It's not raining, but it's cold. Let's stay home and play board games!

– Yay!! Board games! – Otylka shouted with joy.

– Give me 5 minutes. I'll just finish creating the decision tree – I replied.

– A tree? You don't have any trees on your screen, Daddy.

– Hmm... A decision tree is a kind of plan that helps us make a decision. You've just created such a tree yourself – I smiled.

– Really? – Jagoda looked at Otylka in disbelief.

– Yes! You wanted to go for a walk. First you looked out the window and checked whether it was raining. Then you followed the appropriate branches, considering other factors, such as the temperature. And that gave you the answer: whether or not to go for a walk.

– So a decision tree is a plan that helps us make good decisions!

– Exactly, Jagódka! It's a tool that helps us understand what factors influence our decisions and what consequences may result from our choices.

A decision tree is a powerful tool in artificial intelligence and machine learning, used in many areas, from data classification to solving decision problems. I invite you on a fascinating journey through the world of decision trees: one of the simplest, yet, despite that simplicity, very useful tools in the machine learning arsenal.

What is a decision tree?

In simple terms, a decision tree is a kind of plan or map that helps us make decisions based on various options and conditions.

In other words, a decision tree is a graphical structure in which each node represents a decision and each branch represents a possible outcome of that decision.

Above, I have visualized the decision tree that was built in Jagódka's head based on her previous walks and her subjective experience. To make it easier to discuss decision trees in the context of machine learning, let's use another example – a more business-oriented one.

Our example

Let our task be to classify bank customers: will a given person repay their loan, or will they have problems repaying the money with interest?

We will treat each person in our data set as a separate observation, described by a number of features. Our goal is to separate the people who will repay the loan from those who will have a problem with it, by asking questions based on the data we have.

Let's say we have 100 people, of whom 50 have repaid their loan and 50 have not. We also have the following information:

a) whether the person has had repayment problems in the last 3 years (Yes/No),

b) average regular earnings over the last six months (e.g. 3,000, 5,000, etc.).
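
Just to make this concrete, here is a minimal sketch of how such a table of observations might look in code. The column names and values below are invented for illustration – they are not the actual data from our example.

 import pandas as pd

# A made-up fragment of the customer table: each row is one observation,
# the first two columns are features a) and b), the last one is the target
customers = pd.DataFrame({
    "repayment_problems_last_3y": ["Yes", "No", "No", "Yes"],
    "avg_earnings_last_6m": [2500, 5000, 12000, 3000],
    "repaid_loan": [0, 1, 1, 0],
})
print(customers)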

Then our tree could look like this:

Decision Tree Terminology

First, it is worth discussing some terminology used in decision trees:

So, according to the above terminology, our tree looks like this:

Decision Tree Creation Algorithm

The purpose of a decision tree is to segregate objects into classes (in our case: will repay / will not repay), and an ideal decision tree (i.e. a perfect prediction model, in this case) would be one whose final leaves each contain only one class.

Okay, but how does the decision tree know what feature to take and how to arrange the data into a decision tree?

It depends on the algorithm chosen for the decision tree creation process. The most popular algorithms are:

  1. ID3 (Iterative Dichotomiser 3) – one of the oldest algorithms for building decision trees, developed by Ross Quinlan in 1986. ID3 works by iteratively dividing a data set into subsets, selecting the attributes that best separate the data according to some measure of impurity, e.g. entropy.
  2. C4.5 – an improved version of the ID3 algorithm, also developed by Ross Quinlan. C4.5 introduces several improvements over ID3, such as handling missing data, handling numeric attributes, and generating decision rules.
  3. CART (Classification and Regression Trees) – proposed by Leo Breiman and colleagues in 1984. CART builds a binary tree (each node has at most two children) and works by splitting the data set into subsets that minimize the Gini impurity measure. Gini impurity is a nice alternative to entropy because it is cheaper to compute.
  4. CHAID (Chi-squared Automatic Interaction Detection) – developed by Gordon Kass and used primarily in statistical analysis. It works by dividing a data set into subsets, using a chi-squared test to determine the significance of the relationship between attributes and the target variable.

It is worth emphasizing that the CART approach is the one most often used in practice, e.g. in the sklearn or xgboost libraries.

How does the CART algorithm find the best split?

In fact, in CART we can use several methods to find the best splits for our data. The most well-known are Gini impurity and entropy. Both methods are used to evaluate the quality of the split of the data set. However, there are slight differences between them.

Gini impurity

Consider the situation of a credit analyst who is tasked with dividing 100 people into two groups: those who will repay a loan and those who may have a problem doing so. His goal is to divide them so that each group of people standing together (say, in the same room) ultimately contains only people who will repay the loan, or only people who will not. This way, we will invite one group to the bank and, unfortunately, turn away the other.

When sorting people into two groups, the analyst tries to minimize the mixing within each group. Gini impurity in this case measures how often two people drawn at random from the same group belong to different classes. If one group contains mostly people who will repay the loan and the other mostly people who might have a problem with it, the Gini impurity of each group will be low: the chance of drawing two people of the same class from a group is high (close to 1), and 1 minus something close to one gives a value close to zero.

However, if a group is mixed, meaning it contains both people who will and will not repay the loan, its Gini impurity will be high (1 minus a low probability of drawing two people of the same class). In other words, the chances of drawing a person who will repay the loan and one who will not from that group are similar.
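
To put a number on this intuition, here is a minimal sketch of the Gini impurity formula (1 minus the sum of squared class proportions). The group counts below are invented for illustration, not taken from our tree:

 def gini_impurity(counts):
    """Gini impurity of a single group given its class counts."""
    total = sum(counts)
    if total == 0:
        return 0.0
    return 1 - sum((c / total) ** 2 for c in counts)

# Almost pure group (48 will repay, 2 will not) -> impurity close to 0
print(gini_impurity([48, 2]))   # ~0.077

# Fully mixed group (25 and 25) -> maximum impurity for two classes, 0.5
print(gini_impurity([25, 25]))  # 0.5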

Gini impurity counting example

Let's look at the left side of our example tree. For the numbers we see there, we can calculate that:

Now let's do the same calculation for the final leaves on the left branch of our tree.

And this is where the magic happens when we provide a categorical or continuous variable for modeling: the algorithm itself looks for the optimal split point of that variable. So our question "regular earnings ≥ 3,000?" will be compared with splits at ≥ 5,000, ≥ 10,000, and so on.

And how do we compare candidate splits and find the best partition of our variable (or variables, if we have more)? We do this by calculating the Gini impurity of each of the resulting parts and summing them, weighted by the proportion of observations in each part. Then we choose the split that minimizes this weighted sum, i.e. the one that best segregates the data by class.

So for us it looks like this:

What if we assumed a different earnings value, e.g. 5,000?

For the split at earnings > 5,000, the weighted sum is higher than for the split at 3,000, so we keep the split at 3,000.
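
Below is a sketch of how such a comparison can be done in code. The class counts on each side of the two candidate splits are made up for illustration (they are not the numbers from the figures above); the point is only that the candidate with the lower weighted Gini impurity wins.

 def gini_impurity(counts):
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts) if total else 0.0

def weighted_gini(left_counts, right_counts):
    """Weighted sum of child impurities; weights are the shares of observations."""
    n_left, n_right = sum(left_counts), sum(right_counts)
    n = n_left + n_right
    return (n_left / n) * gini_impurity(left_counts) + (n_right / n) * gini_impurity(right_counts)

# Hypothetical candidate split at earnings >= 3,000 (counts: [will repay, will not])
split_3k = weighted_gini([10, 35], [40, 15])

# Hypothetical candidate split at earnings >= 5,000
split_5k = weighted_gini([20, 40], [30, 10])

print(f"weighted Gini for >= 3,000: {split_3k:.3f}")  # lower -> this split is kept
print(f"weighted Gini for >= 5,000: {split_5k:.3f}")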

Note: it is worth remembering that if the weighted Gini impurity of the two child nodes is not lower than the Gini impurity of the parent node, the algorithm will stop splitting that node.

And that's it – simple, right?

Entropy

Let's go back to the analyst who has to divide the customers. We have divided the customers into two groups and we start to wonder how mixed up they are.

If, after the split, one of the groups contains roughly the same number of people who will and will not repay the loan, then the entropy of that group will be high. Why? Because when we randomly pick one person from it, we are not sure which kind of person we will get (one who will repay or one who will not).

However, if most people in a group will repay the loan and only a few odd ones out will have problems, the entropy will be lower, because we are more certain that a randomly chosen person will be one who repays the loan.

Entropy therefore measures how mixed or disordered the data is. In the case of decision trees, the algorithm tries to find a partition of the data that minimizes entropy, meaning that groups of data are more homogeneous in terms of belonging to a given class (e.g. will repay or will not repay).

Entropy calculation example

To identify the best split, we need to perform slightly more complicated calculations.
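
As a sketch of what those calculations look like in code, here is the entropy of a group and the weighted entropy of a split. As before, the class counts are invented for illustration, not taken from the figures.

 import math

def entropy(counts):
    """Shannon entropy (base 2) of a group given its class counts."""
    total = sum(counts)
    result = 0.0
    for c in counts:
        if c > 0:
            p = c / total
            result -= p * math.log2(p)
    return result

print(entropy([25, 25]))  # 1.0  -> maximally mixed group
print(entropy([48, 2]))   # ~0.24 -> almost pure group

# Weighted entropy of a hypothetical split; the algorithm prefers lower values
left, right = [10, 35], [40, 15]
n = sum(left) + sum(right)
weighted = sum(left) / n * entropy(left) + sum(right) / n * entropy(right)
print(round(weighted, 3))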

Gini impurity vs entropy

Gini impurity and entropy are used for the same purpose, i.e. evaluating the quality of data partitioning in decision trees, but they use slightly different approaches to the same problem. Gini impurity focuses on minimizing classification errors, while entropy measures the degree of disorder in the data.

I can tell you from experience that they give very similar results. However, I prefer Gini because it is easier to calculate for computers – when there is a large amount of data to calculate, time is money.
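
If you are curious how close the two measures really are, the small sketch below plots both as a function of the share of one class in a binary node. The shapes are almost identical, which is why they usually lead to very similar trees.

 import numpy as np
import matplotlib.pyplot as plt

# Impurity of a two-class node as a function of the share p of one class
p = np.linspace(0.001, 0.999, 200)
gini = 1 - p**2 - (1 - p)**2                      # equivalently 2*p*(1-p)
ent = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))

plt.plot(p, gini, label="Gini impurity")
plt.plot(p, ent, label="Entropy")
plt.xlabel("Share of one class in the node (p)")
plt.ylabel("Impurity")
plt.legend()
plt.show()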

Advantages of Decision Trees

Here are, in my view, the most important benefits of decision trees:

  1. Ease of interpretation – decision trees generate simple decision rules that are easy to understand and interpret, even for people without specialist knowledge of data analysis. We can visualize them very easily, and this is, in my opinion, their greatest advantage.
  2. Computing time – compared to other machine learning algorithms, such as boosting or neural networks, decision trees require little computational effort during training and classification.
  3. Robustness to outliers – Decision trees are relatively robust to outliers, meaning that single anomalous data points do not significantly impact their performance.
  4. Little data preparation required – decision trees work directly on raw variable values, which eliminates the need to normalize the data and, in many implementations, to create dummy variables.
  5. Robustness to collinearity – decision trees can effectively deal with data in which some features are strongly correlated.
  6. Automatic modeling of nonlinear dependencies and interactions between variables – this is one of their main advantages over simpler models such as linear regression, which assumes linear relationships between variables. Decision trees divide the feature space into rectangular regions by successive partitions based on feature values. These partitions allow for modeling of complex, nonlinear relationships between input variables and the target variable.
  7. Variable selection – trees perform variable selection themselves, so they can be used even when there are more variables than observations. This is because the algorithm iteratively selects the features that best divide the data into smaller subsets. Only one variable is used at each node of the tree, which means the tree does not have to use all available variables, either at a given node or in the entire tree.

It is also worth remembering that decision trees can be used for both classification problems (e.g. identifying a fraud transaction) and regression problems (e.g. forecasting real estate prices).
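
For the regression case, a minimal sketch on synthetic data could look like the one below (the dataset and parameters are arbitrary, chosen just to show that only the estimator class changes):

 from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor

# Synthetic regression data, used only for illustration
X_reg, y_reg = make_regression(n_samples=500, n_features=5, noise=10.0,
                               random_state=2024)

# The regression tree predicts a numeric value (the mean of the leaf) instead of a class
reg = DecisionTreeRegressor(max_depth=3, random_state=2024)
reg.fit(X_reg, y_reg)
print(reg.predict(X_reg[:3]))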

These advantages make decision trees a popular tool in data analysis and machine learning, especially in situations where ease of interpretation and speed of training are important.

Disadvantages of decision trees

Despite their numerous advantages, decision trees also have some disadvantages that are worth mentioning.

  1. Too simple a model – because a single tree is a simple algorithm and most of the problems we face are not so obvious, it is often unable to capture enough of the complexity in the data. That is why boosting algorithms or random forests very often give better results: they capture more of the data's complexity by aggregating many weak classifiers.
  2. Sensitivity to data changes – decision trees are very sensitive to even small changes in the training data, which can lead to significant changes in the tree structure and final predictions. This can cause model instability, and we don't want such uncertainty in the output.
  3. Tendency to overfit – decision trees can easily adapt to noise in the training data, which leads to overfitting: an excellent fit to the training data but poor generalization to new data. Luckily, we have methods to deal with this!
  4. Only the order of values matters – trees pay attention only to the order of values, not to the distances between them. We could replace every variable with its rank (e.g. salary sorted from lowest to highest and assigned a rank) and get exactly the same tree. And sometimes the distances between numbers matter!

Thanks to Piotr Szulc for adding some pros and cons (I recommend his LinkedIn – he publishes many interesting posts!).

Python example

Note! The code below is only meant to show briefly how easy it is to build a decision tree in Python. I am not doing any variable selection, parameter optimization, etc.

Let's load the necessary packages:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

Let's take the easy way out and generate the data set ourselves.

# Create a balanced random dataset
X, y = make_classification(n_samples=1000, 
                           n_features=10, 
                           n_classes=2, 
                           weights=[0.5, 0.5], 
                           random_state=2024)

The make_classification function generates a dataset containing a feature matrix (X) and a class label vector (y), which are used to train and test classification models.

Now let's split the data into a training and a test set:

# Division into training and test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=2024)

Building the model

Now we need to be careful. In sklearn, the decision tree's max_depth parameter is set to None by default. This means that the tree will keep growing for as long as the splitting criterion (Gini impurity by default) keeps finding further splits, which can result in overfitting or an overly complex tree.

# default model
model = DecisionTreeClassifier(random_state=2024)

model.get_params()

Let's see what model will be built with the default parameters:

# training the model
model.fit(X_train, y_train)

# calculating AUC on test data
y_pred_proba = model.predict_proba(X_test)[:, 1]
auc = round(roc_auc_score(y_test, y_pred_proba),3)
print(f"max_depth={model.tree_.max_depth}; AUC: {auc}")

We can see that a decision tree has been built to a depth of 8. That means up to 2^8 = 256 leaves, i.e. 256 different paths from root to leaf! Quite a few possibilities.

Therefore, let's check what AUC metrics we would get depending on the depth of the tree:

def decision_tree_calc(max_depth=None):
    model = DecisionTreeClassifier(random_state=2024, 
                                   max_depth=max_depth)
    model.fit(X_train, y_train)
    

    y_pred_proba = model.predict_proba(X_test)[:, 1]
    auc = round(roc_auc_score(y_test, y_pred_proba),3)
    print(f"max_depth={model.tree_.max_depth}; AUC: {auc}")
    
    return auc, model

depth_dict = {}

for max_depth in range(1, 11):
    auc, model = decision_tree_calc(max_depth=max_depth)
    depth_dict[max_depth] = auc

Let's draw one more chart to help us justify to the business why we would suggest a tree depth of two.

# Extract keys and values from the dictionary
keys = list(depth_dict.keys())
values = list(depth_dict.values())

# Create bar chart
plt.figure(figsize=(10, 5))

# Define color for each bar, set crimson for the highest value bar
colors = ['gray' if v != max(values) else 'crimson' for v in values]

bars = plt.bar(keys, values, color=colors)

# Add values on top of each bar
for bar, value in zip(bars, values):
    plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, f'{value:.3f}', ha='center', va='bottom')

# Add title and labels
plt.title('Values of AUC based on decision tree depth')
plt.xlabel('Max depth')
plt.ylabel('AUC on test')

# Set x-axis ticks to show all values
plt.xticks(keys)

# Show plot
plt.show()

Decision Tree - Visualization

As I mentioned earlier, the biggest advantage of decision trees for me is their interpretability. So let's see what such a tree looks like. This will allow decision makers or domain experts to verify for themselves whether the divisions in the tree make sense.

To do this, you must first install Graphviz (you can download the program here: https://graphviz.org/download/) together with the graphviz Python package. Then we can use it in our code:

from sklearn.tree import export_graphviz
import graphviz

print(f"graphviz: {graphviz.__version__}")

Let's write a function to save tree graphs:

def decision_tree_graph(model, file_to_save_name):
    # Export decision tree to DOT format
    dot_data = export_graphviz(model, out_file=None,
                               filled=True, rounded=True,
                               special_characters=True)

    # Visualize decision tree
    graph = graphviz.Source(dot_data)

    # Save tree as PNG file
    graph.render('./img/' + file_to_save_name, format='png')

    # Open the rendered tree in the default image viewer
    graph.view()

Now let's visualize a tree with depth 2:

max_depth = 2
auc, model = decision_tree_calc(max_depth=max_depth)
decision_tree_graph(model, f'graph_max_depth_{str(max_depth)}')

and with a depth of 4:

max_depth = 4
auc, model = decision_tree_calc(max_depth=max_depth)
decision_tree_graph(model, f'graph_max_depth_{str(max_depth)}')

Below you can see how nicely a tree is visualized on the very famous iris dataset:

from sklearn.datasets import load_iris

# Load iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# Train decision tree classifier
iris_model = DecisionTreeClassifier(max_depth=3)
iris_model.fit(X, y)

# Graph visualization
decision_tree_graph(iris_model, 'iris_dataset_max_depth_3')

Free Trick With Decision Trees

You've built a great model, but your bosses want to understand better how it works? You can build a decision tree in which the target variable is the probability produced by your model (e.g. < 0.2 becomes 0 and ≥ 0.2 becomes 1). Then use simple visualization techniques to show which variables drive it and to explain more easily to business people what the model is based on.
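
Here is a minimal sketch of that trick, assuming a random forest as the stand-in for your complex model and 0.2 as the example cut-off (both are arbitrary choices for illustration, not a recommendation):

 from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

# Synthetic data and a random forest standing in for "your great model"
X, y = make_classification(n_samples=1000, n_features=10, random_state=2024)
complex_model = RandomForestClassifier(random_state=2024).fit(X, y)

# Binarize the complex model's predicted probability and use it as the target
proba = complex_model.predict_proba(X)[:, 1]
surrogate_target = (proba >= 0.2).astype(int)

# A shallow tree trained on this target approximates the model's decision logic
surrogate = DecisionTreeClassifier(max_depth=2, random_state=2024)
surrogate.fit(X, surrogate_target)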

Summary

I hope that decision trees now hold no secrets for you. As you could see, they are a simple but very useful tool in the arsenal of machine learning algorithms. I hope you will reach for them from time to time – especially for models where explainability of decisions plays a key role!

Best regards,