Decision Trees and Gini Impurity: A Fun Dive into Data Science
Hello, my fellow data enthusiasts! Buckle up, because today we’re venturing into the magical world of decision trees and Gini impurity. Don’t worry, this isn’t some dry, soul-sucking math lecture. Nope! We’re going to make this as fun as a coding party (yes, those exist). By the end of this article, you’ll know how to calculate Gini impurity, build a decision tree, and maybe even impress your friends at your next board game night. Let’s get started!
What’s a Decision Tree?
Think of a decision tree as a “Choose Your Own Adventure” book for data. At each step, the tree asks questions to split the data into groups that are as pure as possible. Pure? What does that mean? Imagine sorting socks. If one pile has only black socks and another has only white socks, those piles are pure. But if both piles are a chaotic mix of colors, that’s impure. And nobody likes impure sock piles.
Decision trees aim to minimize this messiness at each split. To measure the messiness, we use something fancy called Gini impurity. Let’s break it down.
Gini Impurity: The Sock Sorting Guru
Gini impurity measures how “messy” a group is. It tells us the probability of misclassifying a randomly picked item if we labeled it at random according to the group’s class proportions. The formula is:

Gini = 1 - Σᵢ pᵢ²

where pᵢ is the proportion of items in class i.
- If all items belong to one class, Gini impurity is 0 (pure bliss!).
- If items are evenly split between two classes, Gini impurity is 0.5 (utter chaos!).
For example:
- A basket with only apples → Gini = 0 (pure).
- A basket with 50% apples and 50% oranges → Gini = 0.5 (messy).
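To make that concrete, here’s a tiny, hypothetical helper (not part of the main code below) that plugs class proportions straight into the formula:

def gini(proportions):
    # Gini impurity: 1 minus the sum of squared class proportions.
    return 1 - sum(p ** 2 for p in proportions)

print(gini([1.0]))       # only apples -> 0.0 (pure)
print(gini([0.5, 0.5]))  # half apples, half oranges -> 0.5 (messy)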
Now that we’ve got the theory down, let’s get coding!
The Code: Building Blocks for Decision Trees
Here’s the Python code that calculates Gini impurity and builds a decision tree step by step. Don’t worry, I’ll explain everything along the way.
Step 1: Weighted Average Function
When splitting data into groups, we calculate the weighted average of their impurities. Why? Because larger groups have more influence on the overall impurity.
import numpy as np  # used by the rest of the code in this article

def calc_wighted_average(im1, imp1_multiplier, im2, imp2_multiplier):
    # Weight each group's impurity by its size, then divide by the total size.
    return round(((im1 * imp1_multiplier) + (im2 * imp2_multiplier)) / (imp1_multiplier + imp2_multiplier), 3)
This function combines two groups’ impurities based on their sizes. Think of it like splitting pizza: bigger slices matter more.
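For example, with two made-up groups, one of 4 items with impurity 0.3 and one of 6 items with impurity 0.5, the larger group pulls the result toward its own impurity:

print(calc_wighted_average(0.3, 4, 0.5, 6))  # (0.3*4 + 0.5*6) / 10 = 0.42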
Step 2: Calculating Gini Impurity
Here’s the function that calculates Gini impurity for both continuous (e.g., age) and discrete (e.g., diabetes status) features.
def calc_impurity(data):
    # data has two columns: the feature in column 0 and the target in column 1.
    if len(np.unique(data[:, 0])) > 2:  # Continuous feature
        sorted_data = data[data[:, 0].argsort()]
        main_dict = {}
        # Try every midpoint between consecutive sorted values as a split point.
        for i in range(1, len(sorted_data)):
            first_number = sorted_data[i - 1, 0]
            second_number = sorted_data[i, 0]
            avg = (first_number + second_number) / 2
            # Gini impurity of the group below the candidate threshold.
            true_xs = data[data[:, 0] < avg]
            count_true_xs = len(true_xs)
            true_xs_true_ys = len(true_xs[true_xs[:, 1] == True])
            true_xs_false_ys = len(true_xs[true_xs[:, 1] == False])
            imp1 = round(1 - ((true_xs_true_ys / count_true_xs) ** 2)
                           - ((true_xs_false_ys / count_true_xs) ** 2), 3)
            # Gini impurity of the group above the candidate threshold.
            false_xs = data[data[:, 0] > avg]
            count_false_xs = len(false_xs)
            false_xs_true_ys = len(false_xs[false_xs[:, 1] == True])
            false_xs_false_ys = len(false_xs[false_xs[:, 1] == False])
            imp2 = round(1 - ((false_xs_true_ys / count_false_xs) ** 2)
                           - ((false_xs_false_ys / count_false_xs) ** 2), 3)
            main_dict[str(avg)] = calc_wighted_average(imp1, count_true_xs, imp2, count_false_xs)
        # Return the best threshold and its weighted impurity as a one-entry dict.
        best_split = min(main_dict, key=main_dict.get)
        return {best_split: main_dict[best_split]}
    else:  # Discrete feature
        # Gini impurity of the group where the feature is 1 (True).
        true_xs = data[data[:, 0] == True]
        count_true_xs = len(true_xs)
        if count_true_xs == 0:
            imp1 = 0
        else:
            true_xs_true_ys = len(true_xs[true_xs[:, 1] == True])
            true_xs_false_ys = len(true_xs[true_xs[:, 1] == False])
            imp1 = round(1 - ((true_xs_true_ys / count_true_xs) ** 2)
                           - ((true_xs_false_ys / count_true_xs) ** 2), 3)
        # Gini impurity of the group where the feature is 0 (False).
        false_xs = data[data[:, 0] == False]
        count_false_xs = len(false_xs)
        if count_false_xs == 0:
            imp2 = 0
        else:
            false_xs_true_ys = len(false_xs[false_xs[:, 1] == True])
            false_xs_false_ys = len(false_xs[false_xs[:, 1] == False])
            imp2 = round(1 - ((false_xs_true_ys / count_false_xs) ** 2)
                           - ((false_xs_false_ys / count_false_xs) ** 2), 3)
        return calc_wighted_average(imp1, count_true_xs, imp2, count_false_xs)
This function:
- Handles continuous features by trying every midpoint between consecutive sorted values as a split point, returning the best threshold and its weighted impurity as a one-entry dictionary.
- Handles discrete features by directly calculating the weighted Gini impurity of the two groups, returning it as a plain number (see the tiny illustration below).
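Here’s a quick illustration of those two return shapes on made-up two-column arrays (feature, label), with the expected outputs noted as comments:

toy_discrete = np.array([[1, 1], [1, 1], [0, 0], [0, 1]])
toy_continuous = np.array([[10, 1], [20, 1], [30, 0], [40, 0]])

print(calc_impurity(toy_discrete))    # 0.25 (a plain weighted impurity)
print(calc_impurity(toy_continuous))  # {'25.0': 0.0} (best threshold and its impurity)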
Step 3: Meet Our Dataset
Let’s introduce our dataset. It has four columns:
- Pressure: High blood pressure (1) or not (0).
- Diabetes: Diabetes (1) or not (0).
- Age: The person’s age.
- Stroke: Whether they had a stroke (1) or not (0). This is our target variable!
Here’s what it looks like:
# Columns: pressure, diabetes, age, stroke (target)
df = np.array([
    [1, 0, 18, 1],
    [1, 1, 15, 1],
    [0, 1, 65, 0],
    [0, 0, 33, 0],
    [1, 0, 37, 1],
    [0, 1, 45, 1],
    [0, 1, 50, 0],
    [1, 0, 75, 0],
    [1, 0, 67, 1],
    [1, 1, 60, 1],
    [0, 1, 55, 1],
    [0, 0, 69, 0],
    [0, 0, 80, 0],
    [0, 1, 87, 1],
    [1, 0, 38, 1],
])
Our goal is to figure out which feature best predicts strokes. Let’s do this!
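Before splitting anything, it helps to know the baseline. Out of 15 people, 9 had a stroke and 6 didn’t, so the Gini impurity of the whole dataset is 1 - (9/15)² - (6/15)² = 0.48. Any split worth making should push the weighted impurity below that. A quick throwaway check (not part of the main code):

strokes = df[:, -1]
p_stroke = strokes.mean()  # 9/15 = 0.6
print(round(1 - p_stroke ** 2 - (1 - p_stroke) ** 2, 3))  # 0.48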
Step-by-Step: Building the Decision Tree
Step 1: Calculate Impurities for All Features
We calculate Gini impurity for each feature:
print(calc_impurity(df[:, [0, -1]]))  # Pressure
print(calc_impurity(df[:, [1, -1]]))  # Diabetes
print(calc_impurity(df[:, [2, -1]]))  # Age
- Pressure: Impurity = X
- Diabetes: Impurity = Y
- Age: Impurity = Z
Guess what? Age has the lowest impurity! So we split on age first.
Step 2: Split on Age
The code finds that splitting at age 68 gives us the purest groups. So we divide our data into two subsets:
- True Side: People younger than 68.
- False Side: People aged 68 or older.
true_side_df = df[df[:, 2] < 68]
false_side_df = df[df[:, 2] >= 68]
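A quick sanity check on the split sizes never hurts (the expected counts in the comment come from scanning the ages in df):

print(len(true_side_df), len(false_side_df))  # 11 4 (eleven people under 68, four people 68 or older)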
Step 3: Recalculate Impurities for Each Subset
Now we calculate impurities again for each subset to decide the next split.
For People Younger than 68:
print(calc_impurity(true_side_df[:, [0, -1]]))  # Pressure
print(calc_impurity(true_side_df[:, [1, -1]]))  # Diabetes
- Diabetes has the lowest impurity here!
For People Aged 68 or Older:
print(calc_impurity(false_side_df[:, [0, -1]]))  # Pressure
print(calc_impurity(false_side_df[:, [1, -1]]))  # Diabetes
- Pressure wins this round!
The Final Decision Tree
Here’s what our decision tree looks like so far:
- Root: split on Age (younger than 68 vs. 68 or older).
- Age < 68 branch: split on Diabetes next.
- Age ≥ 68 branch: split on Pressure next.
Okay, it’s not exactly an oak tree, it’s more like a bonsai, but you get the idea!
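To actually use this bonsai for prediction, each leaf simply votes with the majority stroke label of the training rows that land in it. Here’s a minimal, hypothetical predict helper that hard-codes the two splits found above; it’s a sketch for illustration, not part of the original code, and it assumes every leaf contains at least one training row:

def predict(pressure, diabetes, age, data=df):
    # Route the sample through the root split on age...
    if age < 68:
        leaf = data[(data[:, 2] < 68) & (data[:, 1] == diabetes)]   # ...then split on diabetes
    else:
        leaf = data[(data[:, 2] >= 68) & (data[:, 0] == pressure)]  # ...then split on pressure
    # Predict the majority stroke label among the training rows in that leaf.
    return int(leaf[:, -1].mean() >= 0.5)

print(predict(pressure=1, diabetes=0, age=40))  # lands in the "young, no diabetes" leaf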
Why Should You Care?
Decision trees are one of the simplest yet most powerful tools in machine learning. They:
- Are easy to interpret.
- Handle both continuous and categorical data.
- Serve as building blocks for more advanced models like Random Forests and Gradient Boosting Machines.
By using Gini impurity to measure how “pure” your splits are, you can ensure your decision tree makes smarter decisions at every step.
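And if you’d rather not hand-roll the splits, scikit-learn’s DecisionTreeClassifier uses Gini impurity by default. Here is a minimal sketch, assuming scikit-learn is installed (its thresholds and tie-breaking may differ slightly from our hand-rolled version):

from sklearn.tree import DecisionTreeClassifier, export_text

X, y = df[:, :3], df[:, -1]  # pressure, diabetes, age -> stroke
clf = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=0).fit(X, y)
print(export_text(clf, feature_names=["pressure", "diabetes", "age"]))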
Wrapping Up
So there you have it: a crash course in decision trees and Gini impurity! We started with a dataset about strokes and used Python to figure out which features matter most. Along the way:
- We learned how to calculate Gini impurity.
- We built a simple decision tree step by step.
If you want to explore further or play with the code yourself (trust me, it’s more fun than sorting socks), check out my GitHub repo! Happy coding!