– Dad, you're always on that computer. Maybe we should go for a walk as a family? – Jagódka looked out the window. – Or maybe not. It's not raining, but it's cold. Let's stay home and play board games!
– Yay!! Board games! – Otylka shouted with joy.
– Give me 5 minutes. I'll just finish creating the decision tree – I replied.
– A tree? You don't have any trees on your screen, Daddy.
– Hmm... A decision tree is a kind of plan that helps us make a decision. You've just created such a tree yourself – I smiled.
– Really? – Jagoda looked in disbelief at Otylka.
– Yes! You wanted to go for a walk. First you looked out the window and checked whether it was raining. Then you followed the appropriate branches, analyzing other factors, such as the temperature. And you got your answer: whether or not to go for a walk.
– So a decision tree is a plan that helps us make good decisions!
– Exactly, Jagódka! It's a tool that helps us understand what factors influence our decisions and what consequences may result from our choices.
A decision tree is a powerful tool in artificial intelligence and machine learning, used in many areas, from data classification to solving decision problems. I invite you on a fascinating journey through the world of decision trees – one of the simplest tools in the machine learning arsenal, yet, despite that simplicity, a very useful one.
In simple words, a decision tree is a type of plan or map that helps in making simple decisions based on various options and conditions.
In other words, a decision tree is a graphical structure in which each node represents a decision and each branch represents a possible outcome of that decision.
Above I have visualized the decision tree that was built in Jagódka's head based on her previous walks and her subjective experience. To make it easier to discuss decision trees in the context of machine learning, let's use another, more business-oriented example.
Let our task be to classify bank customers: will a given person repay their loan, or will they have problems paying back the money with interest?
We will treat each person in our dataset as a separate observation, described by various characteristics. Our goal is to separate the people who will repay the loan from those who will have a problem with it, by asking questions based on the data we have.
Let's say we have 100 people, of whom 50 have repaid their loan and 50 have not. We also have the following information:
a) whether the person has had problems with repayment in the last 3 years (Yes/No),
b) average regular earnings over the last six months (e.g. 3,000, 5,000, etc.).
Then our tree could look like this:
First, it is worth discussing some terminology used in decision trees:
So, according to the above terminology, our tree looks like this:
The purpose of a decision tree is to segregate objects into classes (in our case: will repay / will not repay). An ideal decision tree (i.e., in this case, an ideal prediction model) would be one whose final leaves each contain only one class.
Okay, but how does the decision tree know what feature to take and how to arrange the data into a decision tree?
It depends on the algorithm chosen for the decision tree creation process. The most popular algorithms are:
It is worth emphasizing that the CART algorithm is the implementation you will most often encounter in practice, e.g. in the sklearn and xgboost libraries.
In fact, in CART we can use several methods to find the best splits for our data. The most well-known are Gini impurity and entropy. Both methods are used to evaluate the quality of the split of the data set. However, there are slight differences between them.
Consider a credit analyst who is tasked with dividing 100 people into two groups: those who will repay a loan and those who may have a problem doing so. His goal is to divide them so that each group (say, the people standing together in the same room) ultimately contains only people who will repay the loan or only people who will not. This way, we will invite one group to the bank and, unfortunately, turn the other one away.
When sorting people into the two groups, the analyst tries to minimize the mixing within each group. Gini impurity measures how often, when we draw two people at random from the same group, they turn out to belong to different classes. If one group contains mostly people who will repay the loan and the other mostly people who might have a problem with it, the Gini impurity of each group will be low: the probability of drawing two people from the same class is high (close to 1), and 1 minus something close to one gives a value close to zero.
However, if the groups are mixed, meaning that a single group contains both people who will and will not repay the loan, the Gini impurity will be high (1 minus a much smaller probability). The chances of drawing a person who will repay the loan and a person who will not from such a group are similar.
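To make this more concrete, here is a minimal sketch in Python of how the Gini impurity of a single group could be computed. The class counts are made up purely for illustration and do not come from the example above.

# Minimal sketch: Gini impurity of a single group (node).
# The class counts are illustrative only.

def gini_impurity(class_counts):
    # Gini impurity = 1 - sum of squared class proportions
    total = sum(class_counts)
    return 1 - sum((count / total) ** 2 for count in class_counts)

print(gini_impurity([45, 5]))   # nearly pure group -> 0.18 (low impurity)
print(gini_impurity([25, 25]))  # perfectly mixed group -> 0.5 (maximum for two classes)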
Let's look at the left side of the split in our example. For the numbers we see there, we can calculate that:
Now let's compute the same for the final leaves on the left branch of our tree.
And this is where the magic happens when we feed a categorical or continuous variable into the model, because the algorithm itself searches for the optimal split point of that variable. So our question "regular earnings ≥ 3,000?" will be compared with earnings ≥ 5,000, ≥ 10,000, and so on.
And how do we compare these candidates and find the best partition of our variable (or variables, if we have more)? We calculate the Gini impurity of each resulting part and sum them, weighting each by the proportion of elements it contains. Then we choose the partition that minimizes this weighted sum, i.e. the one that best segregates the data by class.
So for us it looks like this:
What if we assumed a different earnings value, e.g. 5,000?
For the split at earnings ≥ 5,000, the weighted sum is higher than for the split at 3,000, so we keep the split at 3,000.
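For the curious, here is a rough sketch of how such a comparison could look in code. The group counts are hypothetical (the actual numbers come from the figures above, which I do not repeat here); the point is only that the split with the lower weighted Gini impurity wins.

# Sketch: comparing two candidate splits by their weighted Gini impurity.
# The group counts [will repay, will not repay] are hypothetical.

def gini_impurity(class_counts):
    total = sum(class_counts)
    return 1 - sum((count / total) ** 2 for count in class_counts)

def weighted_gini(groups):
    # Weight each group's impurity by its share of all observations
    total = sum(sum(group) for group in groups)
    return sum(sum(group) / total * gini_impurity(group) for group in groups)

split_3k = [[40, 10], [10, 40]]  # hypothetical split at earnings >= 3,000
split_5k = [[30, 20], [20, 30]]  # hypothetical split at earnings >= 5,000

print(round(weighted_gini(split_3k), 3))  # 0.32
print(round(weighted_gini(split_5k), 3))  # 0.48 -> higher, so we keep the 3,000 split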
Note: it is worth remembering that if the weighted Gini impurity of the two child nodes is not lower than the Gini impurity of the parent node, the algorithm will stop looking for further splits there.
And that's it – simple, right?
Let's go back to the analyst who has to divide the customers. We have divided the customers into two groups and we start to wonder how mixed up they are.
If, after the split, one of the groups contains roughly the same number of people who will repay the loan and people who will not, the entropy of that group will be high. Why? Because when we randomly pick one person from it, we cannot be sure which kind of person we will get (one who will repay or one who will not).
However, if most people in a group will repay the loan and only a few exceptions have problems, the entropy will be lower, because we are more certain that a randomly chosen person will be one who repays the loan.
Entropy therefore measures how mixed or disordered the data is. In the case of decision trees, the algorithm tries to find a partition of the data that minimizes entropy, meaning that groups of data are more homogeneous in terms of belonging to a given class (e.g. will repay or will not repay).
To identify the best split, we need to perform slightly more complicated calculations.
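As a rough sketch (again with made-up counts), the entropy of a single group could be computed like this:

# Sketch: entropy of a single group; class counts are illustrative only.
import math

def entropy(class_counts):
    # Entropy = -sum(p * log2(p)) over the class proportions
    total = sum(class_counts)
    result = 0.0
    for count in class_counts:
        if count > 0:
            p = count / total
            result -= p * math.log2(p)
    return result

print(round(entropy([25, 25]), 3))  # perfectly mixed group -> 1.0 (maximum)
print(round(entropy([45, 5]), 3))   # mostly one class -> about 0.469
print(round(entropy([50, 0]), 3))   # pure group -> 0.0

To find the best split, we would then weight the entropies of the child nodes by their sizes, just as we did with Gini impurity (this is the idea behind information gain).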
Gini impurity and entropy are used for the same purpose, i.e. evaluating the quality of data partitioning in decision trees, but they use slightly different approaches to the same problem. Gini impurity focuses on minimizing classification errors, while entropy measures the degree of disorder in the data.
I can tell you from experience that they give very similar results. However, I prefer Gini because it is cheaper to compute – and when there is a large amount of data, time is money.
Here are my five most important benefits of decision trees:
It is also worth remembering that decision trees can be used for both classification problems (e.g. identifying a fraud transaction) and regression problems (e.g. forecasting real estate prices).
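As a tiny illustrative sketch of the regression variant (with toy data made up on the spot), the only change is that we use a DecisionTreeRegressor to predict a numeric value instead of a class:

# Sketch: decision tree for regression on toy data (values are made up)
from sklearn.tree import DecisionTreeRegressor

# Toy features (e.g. flat area in square meters) and prices
areas = [[30], [45], [60], [80], [100]]
prices = [300_000, 400_000, 520_000, 650_000, 800_000]

reg = DecisionTreeRegressor(max_depth=2, random_state=2024)
reg.fit(areas, prices)
print(reg.predict([[70]]))  # predicted price for a 70 m^2 flat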
These advantages make decision trees a popular tool in data analysis and machine learning, especially in situations where ease of interpretation and speed of training are important.
Despite their numerous advantages, decision trees also have some disadvantages that are worth mentioning.
Thanks to Piotr Szulc for adding some pros and cons (I recommend his LinkedIn – he publishes many interesting posts!).
Note! The code below is only meant to show briefly how easy it is to build a decision tree in Python. I am not doing any variable selection, parameter optimization, etc.
Let's load the necessary packages:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
Let's take the easy way out and generate the data set ourselves.
# Create a balanced random dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
                           weights=[0.5, 0.5], random_state=2024)
The make_classification function generates a dataset containing a feature matrix (X) and a class label vector (y), which are used to train and test classification models.
Now let's prepare the division of the set:
# Division into training and test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=2024)
Now we need to be careful. In sklearn, the decision tree's max_depth parameter defaults to None. This means that the tree will keep growing as long as the algorithm (using the default Gini impurity criterion) can find another split. This can result in overfitting or an overly complicated tree.
# default model
model = DecisionTreeClassifier(random_state=2024)
model.get_params()
Let's see what model will be built with the default parameters:
# training the model
model.fit(X_train, y_train)

# calculating AUC on test data
y_pred_proba = model.predict_proba(X_test)[:, 1]
auc = round(roc_auc_score(y_test, y_pred_proba), 3)
print(f"max_depth={model.tree_.max_depth}; AUC: {auc}")
We can see that a decision tree has been built to a depth of 8. That gives a maximum of 2^8 = 256 leaves, i.e. 256 different paths through the tree! Quite a few possibilities.
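If you want to check how many leaves the tree actually ended up with (usually far fewer than the theoretical maximum), sklearn exposes this directly; a quick check might look like this:

# Optional check: actual depth and number of leaves of the fitted tree
print(f"depth: {model.get_depth()}, leaves: {model.get_n_leaves()}")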
Therefore, let's check what AUC metrics we would get depending on the depth of the tree:
def decision_tree_calc(max_depth=None):
    model = DecisionTreeClassifier(random_state=2024, max_depth=max_depth)
    model.fit(X_train, y_train)
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    auc = round(roc_auc_score(y_test, y_pred_proba), 3)
    print(f"max_depth={model.tree_.max_depth}; AUC: {auc}")
    return auc, model

depth_dict = {}
for max_depth in range(1, 11):
    auc, model = decision_tree_calc(max_depth=max_depth)
    depth_dict[max_depth] = auc
Let's draw one more chart that will help us justify to the business why we would suggest a tree depth of two.
# Extract keys and values from the dictionary
keys = list(depth_dict.keys())
values = list(depth_dict.values())

# Create bar chart
plt.figure(figsize=(10, 5))

# Define color for each bar, set crimson for the highest value bar
colors = ['gray' if v != max(values) else 'crimson' for v in values]
bars = plt.bar(keys, values, color=colors)

# Add values on top of each bar
for bar, value in zip(bars, values):
    plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
             f'{value:.3f}', ha='center', va='bottom')

# Add title and labels
plt.title('Values of AUC based on decision tree depth')
plt.xlabel('Max depth')
plt.ylabel('AUC on test')

# Set x-axis ticks to show all values
plt.xticks(keys)

# Show plot
plt.show()
As I mentioned earlier, the biggest advantage of decision trees for me is their interpretability. So let's see what such a tree looks like. This will allow decision makers or domain experts to verify for themselves whether the divisions in the tree make sense.
To do this, you must first install Graphviz (along with the graphviz Python package). You can download the program here: https://graphviz.org/download/. Then we can use it in our code:
from sklearn.tree import export_graphviz
import graphviz

print(f"graphviz: {graphviz.__version__}")
Let's write a function to save tree graphs:
def decision_tree_graph(model, file_to_save_name):
    # Export decision tree to DOT format
    dot_data = export_graphviz(model, out_file=None, filled=True,
                               rounded=True, special_characters=True)

    # Visualize decision tree
    graph = graphviz.Source(dot_data)

    # Save tree as PNG file
    graph.render('./img/' + file_to_save_name, format='png')

    # Open tree in default PDF viewer
    graph.view()
Now let's visualize a tree with depth 2:
max_depth = 2
auc, model = decision_tree_calc(max_depth=max_depth)
decision_tree_graph(model, f'graph_max_depth_{str(max_depth)}')
and with a depth of 4:
max_depth = 4
auc, model = decision_tree_calc(max_depth=max_depth)
decision_tree_graph(model, f'graph_max_depth_{str(max_depth)}')
Below you can see how nicely a tree can be visualized on the famous iris dataset:
from sklearn.datasets import load_iris

# Load iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# Train decision tree classifier
iris_model = DecisionTreeClassifier(max_depth=3)
iris_model.fit(X, y)

# graph visualization
decision_tree_graph(iris_model, 'iris_dataset_max_depth_3')
You've built a great model, but your bosses want to understand better how it works? You can then build a decision tree that takes the binarized probability output of your model as the target variable (e.g. < 0.2 becomes 0 and ≥ 0.2 becomes 1). Then use simple visualization techniques to show which variables drive it and explain more easily to business people what the model is based on.
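Here is a rough sketch of that surrogate-tree idea, under a few assumptions: the 0.2 threshold follows the text, GradientBoostingClassifier is only a stand-in for "your great model", and I reuse the X_train/y_train split and the decision_tree_graph helper from earlier.

# Sketch of a surrogate tree explaining a more complex model.
# GradientBoostingClassifier is only a stand-in for "your great model".
from sklearn.ensemble import GradientBoostingClassifier

complex_model = GradientBoostingClassifier(random_state=2024)
complex_model.fit(X_train, y_train)

# Binarize the complex model's predicted probabilities into a new target
proba = complex_model.predict_proba(X_train)[:, 1]
surrogate_target = (proba >= 0.2).astype(int)

# Fit a shallow, easy-to-explain tree on the same features
surrogate_tree = DecisionTreeClassifier(max_depth=3, random_state=2024)
surrogate_tree.fit(X_train, surrogate_target)

# Visualize it with the helper defined earlier
decision_tree_graph(surrogate_tree, 'surrogate_tree_max_depth_3')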
I hope that decision trees now hold no secrets for you. As you have seen, they are a simple but very useful tool in the machine learning arsenal. I hope you will reach for them from time to time – especially for models where explainability of decisions plays a key role!
Best regards,