Imbalanced classes put "accuracy" out of business. This is a surprisingly common problem in machine learning (specifically in classification), occurring in datasets with a disproportionate ratio of observations in each class.
Standard accuracy no longer reliably measures performance, which makes model training much trickier.
Imbalanced classes appear in many domains, including:
- Fraud detection
- Spam filtering
- Disease screening
- SaaS subscription churn
- Advertising click-throughs
In this guide, we'll explore 5 effective ways to handle imbalanced classes.
Intuition: Disease Screening Example
Let's say your client is a leading research hospital, and they've asked you to train a model for detecting a disease based on biological inputs collected from patients.
But here's the catch... the disease is relatively rare; it occurs in only 8% of patients who are screened.
Now, before you even start, do you see how the problem might break? Imagine if you didn't bother training a model at all. Instead, what if you just wrote a single line of code that always predicts "No Disease"?
Well, guess what? Your "solution" would have 92% accuracy!
Unfortunately, that accuracy is misleading.
- For patients who do not have the disease, you'd have 100% accuracy.
- For patients who do have the disease, you'd have 0% accuracy.
- Your overall accuracy would be high simply because most patients do not have the disease (not because your model is any good).
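To make the arithmetic concrete, here is a small sketch. The cohort of 1,000 patients is made up for illustration; only the 8% prevalence comes from the example above.

```python
# Hypothetical screening cohort: 1 = disease, 0 = no disease (8% prevalence).
y_true = [1] * 80 + [0] * 920

# The "single line of code" model: always predict "No Disease".
y_pred = [0] * len(y_true)

# Overall accuracy across all patients.
overall = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Accuracy within each group of patients.
healthy_hits = [t == p for t, p in zip(y_true, y_pred) if t == 0]
sick_hits = [t == p for t, p in zip(y_true, y_pred) if t == 1]
acc_healthy = sum(healthy_hits) / len(healthy_hits)
acc_sick = sum(sick_hits) / len(sick_hits)

print(overall, acc_healthy, acc_sick)  # 0.92 1.0 0.0
```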
This is clearly a problem because many machine learning algorithms are designed to maximize overall accuracy. The rest of this guide will illustrate different tactics for handling imbalanced classes.
Important notes before we begin:
First, please note that we're not going to split out a separate test set, tune hyperparameters, or implement cross-validation. In other words, we're not going to follow best practices (which are covered in our Data Science Primer).
Instead, this tutorial is focused purely on addressing imbalanced classes.
In addition, not every technique below will work for every problem. However, 9 times out of 10, at least one of these techniques should do the trick.
Balance Scale Dataset
For this guide, we'll use a synthetic dataset called Balance Scale Data, which you can download from the UCI Machine Learning Repository.
This dataset was originally generated to model psychological experiment results, but it's useful for us because it's a manageable size and has imbalanced classes.
The dataset contains information about whether a scale is balanced or not, based on weights and distances of the two arms.
- It has 1 target variable, which we've labeled balance.
- It has 4 input features, which we've labeled var1 through var4 .
The target variable has 3 classes.
- R for right-heavy, i.e. when var3 * var4 > var1 * var2
- L for left-heavy, i.e. when var3 * var4 < var1 * var2
- B for balanced, i.e. when var3 * var4 = var1 * var2
However, for this tutorial, we're going to turn this into a binary classification problem.
We're going to label each observation as 1 (positive class) if the scale is balanced or 0 (negative class) if the scale is not balanced:
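A minimal sketch of this step follows. Rather than reading the downloaded file, it rebuilds the data from the rules above, assuming the dataset is the full grid of weights and distances from 1 to 5 (625 rows). That grid and the use of `np.where` are assumptions for illustration, not necessarily how the data was originally loaded.

```python
from itertools import product

import numpy as np
import pandas as pd

# Assumption: every combination of var1..var4 in 1..5 appears exactly once.
rows = list(product(range(1, 6), repeat=4))
df = pd.DataFrame(rows, columns=["var1", "var2", "var3", "var4"])

# Three-class target, following the rules listed above.
left = df["var1"] * df["var2"]
right = df["var3"] * df["var4"]
df["balance"] = np.where(right > left, "R", np.where(right < left, "L", "B"))

# Binary target: 1 if the scale is balanced, 0 otherwise.
df["balance"] = (df["balance"] == "B").astype(int)

# Share of each class in the data.
print(df["balance"].value_counts(normalize=True))
```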
As you can see, only about 8% of the observations were balanced. Therefore, if we were to always predict 0, we'd achieve an accuracy of 92%.
The Danger of Imbalanced Classes
Now that we have a dataset, we can really show the dangers of imbalanced classes.
First, let's import the Logistic Regression algorithm and the accuracy metric from Scikit-Learn.
As mentioned above, many machine learning algorithms are designed to maximize overall accuracy by default.
We can confirm this:
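One way to check this is sketched below. It rebuilds the data as the full 1 to 5 grid (an assumption, so the block runs on its own), fits a Logistic Regression with default settings, and scores it on the same data it was trained on, as noted earlier. The variable names are illustrative.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

# Separate input features (X) and target variable (y).
X = df.drop(columns="balance")
y = df["balance"]

# Train on the imbalanced data and predict on the training set.
clf_0 = LogisticRegression().fit(X, y)
pred_y_0 = clf_0.predict(X)

print(accuracy_score(y, pred_y_0))  # overall accuracy
print(np.unique(pred_y_0))          # which classes the model actually predicts
```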
As you can see, this model is only predicting 0, which means it's completely ignoring the minority class in favor of the majority class.
Next, we'll look at the first technique for handling imbalanced classes: up-sampling the minority class.
1. Up-sample Minority Class
Up-sampling is the process of randomly duplicating observations from the minority class in order to reinforce its signal.
There are several heuristics for doing so, but the most common way is to simply resample with replacement.
First, we'll import the resampling module from Scikit-Learn:
Next, we'll create a new DataFrame with an up-sampled minority class. Here are the steps:
- First, we'll separate observations from each class into different DataFrames.
- Next, we'll resample the minority class with replacement, setting the number of samples to match that of the majority class.
- Finally, we'll combine the up-sampled minority class DataFrame with the original majority class DataFrame.
Here's the code:
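A minimal sketch of these steps, using `resample` from `sklearn.utils`. The data is rebuilt as the full 1 to 5 grid (an assumption, so the block runs on its own), and the seed value is arbitrary.

```python
from itertools import product

import pandas as pd
from sklearn.utils import resample

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

# Separate majority and minority classes.
df_majority = df[df["balance"] == 0]
df_minority = df[df["balance"] == 1]

# Up-sample the minority class with replacement to match the majority class.
df_minority_upsampled = resample(df_minority,
                                 replace=True,
                                 n_samples=len(df_majority),
                                 random_state=123)  # arbitrary seed

# Combine the majority class with the up-sampled minority class.
df_upsampled = pd.concat([df_majority, df_minority_upsampled])

print(df_upsampled["balance"].value_counts())
```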
As you can see, the new DataFrame has more observations than the original, and the ratio of the two classes is now 1:1.
Let's train another model using Logistic Regression, this time on the balanced dataset:
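As a sketch, the training step could look like the following. The block repeats the data rebuild (assumed to be the full 1 to 5 grid) and the up-sampling so that it runs on its own. The exact accuracy depends on which rows the resampling happens to duplicate.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import resample

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

# Up-sample the minority class to a 1:1 ratio.
df_majority = df[df["balance"] == 0]
df_minority = df[df["balance"] == 1]
df_upsampled = pd.concat([
    df_majority,
    resample(df_minority, replace=True,
             n_samples=len(df_majority), random_state=123),
])

# Train and evaluate on the balanced (up-sampled) data.
X = df_upsampled.drop(columns="balance")
y = df_upsampled["balance"]
clf_1 = LogisticRegression().fit(X, y)
pred_y_1 = clf_1.predict(X)

print(np.unique(pred_y_1))          # classes the model predicts
print(accuracy_score(y, pred_y_1))  # accuracy on the up-sampled data
```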
Great, now the model is no longer predicting just one class. While the accuracy also took a nosedive, it's now more meaningful as a performance metric.
2. Down-sample Majority Class
Down-sampling involves randomly removing observations from the majority class to prevent its signal from dominating the learning algorithm.
The most common heuristic for doing so is resampling without replacement.
The process is similar to that of up-sampling. Here are the steps:
- First, we'll separate observations from each class into different DataFrames.
- Next, we'll resample the majority class without replacement, setting the number of samples to match that of the minority class.
- Finally, we'll combine the down-sampled majority class DataFrame with the original minority class DataFrame.
Here's the code:
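A minimal sketch of the down-sampling steps, again using `resample` from `sklearn.utils`. The data is rebuilt as the full 1 to 5 grid (an assumption), and the seed value is arbitrary.

```python
from itertools import product

import pandas as pd
from sklearn.utils import resample

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

# Separate majority and minority classes.
df_majority = df[df["balance"] == 0]
df_minority = df[df["balance"] == 1]

# Down-sample the majority class without replacement to match the minority class.
df_majority_downsampled = resample(df_majority,
                                   replace=False,
                                   n_samples=len(df_minority),
                                   random_state=123)  # arbitrary seed

# Combine the down-sampled majority class with the minority class.
df_downsampled = pd.concat([df_majority_downsampled, df_minority])

print(df_downsampled["balance"].value_counts())
```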
This time, the new DataFrame has fewer observations than the original, and the ratio of the two classes is now 1:1.
Again, let's train a model using Logistic Regression:
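The training step, sketched under the same assumptions as before (data rebuilt as the full 1 to 5 grid, arbitrary seed). The accuracy you see depends on which majority-class rows survive the down-sampling.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import resample

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

# Down-sample the majority class to a 1:1 ratio.
df_majority = df[df["balance"] == 0]
df_minority = df[df["balance"] == 1]
df_downsampled = pd.concat([
    resample(df_majority, replace=False,
             n_samples=len(df_minority), random_state=123),
    df_minority,
])

# Train and evaluate on the balanced (down-sampled) data.
X = df_downsampled.drop(columns="balance")
y = df_downsampled["balance"]
clf_2 = LogisticRegression().fit(X, y)
pred_y_2 = clf_2.predict(X)

print(np.unique(pred_y_2))          # classes the model predicts
print(accuracy_score(y, pred_y_2))  # accuracy on the down-sampled data
```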
The model isn't predicting just one class, and the accuracy seems higher.
We'd still want to validate the model on an unseen test dataset, but the results are more encouraging.
3. Change Your Performance Metric
So far, we've looked at two ways of addressing imbalanced classes by resampling the dataset. Next, we'll look at using other performance metrics for evaluating the models.
A saying often attributed to Albert Einstein goes, "if you judge a fish on its ability to climb a tree, it will live its whole life believing that it is stupid." This quote really highlights the importance of choosing the right evaluation metric.
For a general-purpose metric for classification, we recommend Area Under ROC Curve (AUROC).
- We won't dive into its details in this guide.
- Intuitively, AUROC represents the likelihood of your model distinguishing observations from two classes.
- In other words, if you randomly select one observation from each class, what's the probability that your model will be able to "rank" them correctly?
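The ranking interpretation can be checked numerically. The sketch below uses made-up scores (not the Balance Scale data) and compares the fraction of correctly ranked positive/negative pairs, with ties counted as half, against Scikit-Learn's `roc_auc_score`.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)

# Made-up model scores: positives tend to score higher than negatives.
pos_scores = rng.normal(loc=1.0, size=50)
neg_scores = rng.normal(loc=0.0, size=500)

# Compare every positive with every negative.
# A pair is ranked correctly when the positive scores higher; ties count as half.
diff = pos_scores[:, None] - neg_scores[None, :]
pairwise = np.mean((diff > 0) + 0.5 * (diff == 0))

# The same quantity via Scikit-Learn.
y_true = np.r_[np.ones(50), np.zeros(500)]
y_score = np.r_[pos_scores, neg_scores]
auroc = roc_auc_score(y_true, y_score)

print(pairwise, auroc)  # the two values agree up to floating-point error
```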
We can import this metric from Scikit-Learn:
To calculate AUROC, you'll need predicted class probabilities instead of just the predicted classes. You can get them using the .predict_proba() function like so:
So how did this model (trained on the down-sampled dataset) do in terms of AUROC?
Ok... and how does this compare to the original model trained on the imbalanced dataset?
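A sketch of both calculations, assuming the data is rebuilt as the full 1 to 5 grid and down-sampled with an arbitrary seed. Each model is scored on the data it was trained on, and the second column of `.predict_proba()` is used because it holds the probability of class 1.

```python
from itertools import product

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.utils import resample

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

def fit_and_score(data):
    """Fit Logistic Regression on `data` and return its AUROC on the same data."""
    X, y = data.drop(columns="balance"), data["balance"]
    clf = LogisticRegression().fit(X, y)
    prob_pos = clf.predict_proba(X)[:, 1]  # column 1 = P(class 1)
    return roc_auc_score(y, prob_pos)

# Down-sampled dataset (1:1 ratio).
df_majority = df[df["balance"] == 0]
df_minority = df[df["balance"] == 1]
df_downsampled = pd.concat([
    resample(df_majority, replace=False,
             n_samples=len(df_minority), random_state=123),
    df_minority,
])

auc_downsampled = fit_and_score(df_downsampled)
auc_original = fit_and_score(df)
print(auc_downsampled, auc_original)
```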
Remember, our original model trained on the imbalanced dataset had an accuracy of 92%, which is much higher than the 58% accuracy of the model trained on the down-sampled dataset.
However, the latter model has an AUROC of 57%, which is higher than the 53% of the original model (but not by much).
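Note: if you get an AUROC below 0.5 (say, 0.47), first check that you passed the predicted probabilities for the positive class, i.e. the column of .predict_proba() that corresponds to class 1. A score below 0.5 means the model ranks the two classes in reverse, so inverting its predictions would give 1 minus that score.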
4. Penalize Algorithms (Cost-Sensitive Training)
The next tactic is to use penalized learning algorithms that increase the cost of classification mistakes on the minority class.
A popular algorithm for this technique is Penalized-SVM:
During training, we can use the argument class_weight='balanced' to penalize mistakes on the minority class by an amount proportional to how under-represented it is.
We also want to include the argument probability=True if we want to enable probability estimates for SVM algorithms.
Let's train a model using Penalized-SVM on the original imbalanced dataset:
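A sketch using Scikit-Learn's `SVC` with the two arguments described above. The linear kernel and the seed are choices made for this example, and the data is rebuilt as the full 1 to 5 grid (an assumption).

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.svm import SVC

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

X = df.drop(columns="balance")
y = df["balance"]

# class_weight="balanced" raises the cost of minority-class mistakes;
# probability=True enables predict_proba (needed for AUROC).
clf_3 = SVC(kernel="linear",  # kernel choice is an assumption for this sketch
            class_weight="balanced",
            probability=True,
            random_state=0)
clf_3.fit(X, y)

pred_y_3 = clf_3.predict(X)
prob_y_3 = clf_3.predict_proba(X)[:, 1]  # column 1 = P(class 1)

print(np.unique(pred_y_3))
print(accuracy_score(y, pred_y_3))
print(roc_auc_score(y, prob_y_3))
```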
Again, our purpose here is only to illustrate this technique. To really determine which of these tactics works best for this problem, you'd want to evaluate the models on a hold-out test set.
5. Use Tree-Based Algorithms
The final tactic we'll consider is using tree-based algorithms. Decision trees often perform well on imbalanced datasets because their hierarchical structure allows them to learn signals from both classes.
In modern applied machine learning, tree ensembles (Random Forests, Gradient Boosted Trees, etc.) almost always outperform single decision trees, so we'll jump right into those:
Now, let's train a model using a Random Forest on the original imbalanced dataset.
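A sketch with `RandomForestClassifier` at its default settings apart from an arbitrary seed. The data is rebuilt as the full 1 to 5 grid (an assumption), and the model is scored on its own training data, so the numbers will look optimistic.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score

# Assumption: rebuild the data as the full grid of weights/distances 1..5.
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=["var1", "var2", "var3", "var4"])
df["balance"] = (df["var1"] * df["var2"] == df["var3"] * df["var4"]).astype(int)

X = df.drop(columns="balance")
y = df["balance"]

# Train on the original imbalanced data.
clf_4 = RandomForestClassifier(random_state=123)  # arbitrary seed
clf_4.fit(X, y)

pred_y_4 = clf_4.predict(X)
prob_y_4 = clf_4.predict_proba(X)[:, 1]  # column 1 = P(class 1)

print(np.unique(pred_y_4))
print(accuracy_score(y, pred_y_4))
print(roc_auc_score(y, prob_y_4))
```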
Wow! 97% accuracy and nearly 100% AUROC? Is this magic? A sleight of hand? Cheating? Too good to be true?
Well, tree ensembles have become very popular because they perform extremely well on many real-world problems. We certainly recommend them wholeheartedly.
While these results are encouraging, the model could be overfitted, so you should still evaluate your model on an unseen test set before making the final decision.
Note: your numbers may differ slightly due to the randomness in the algorithm. You can set a random seed for reproducible results.
There were a few tactics that didn't make it into this tutorial:
Create Synthetic Samples (Data Augmentation)
Creating synthetic samples is a close cousin of up-sampling, and some people might categorize them together. For example, the SMOTE algorithm creates "new" minority-class samples by interpolating feature values between an existing minority observation and one of its nearest minority-class neighbors.
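To illustrate the idea only, here is a simplified SMOTE-style sketch. It is not the reference SMOTE implementation: the function name `smote_like`, its parameters, and the random input data are all hypothetical. Each synthetic point is placed at a random position on the line segment between a minority observation and one of its nearest minority-class neighbors.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def smote_like(X_min, n_new, k=5, seed=0):
    """Create n_new synthetic rows by interpolating between minority neighbors."""
    rng = np.random.default_rng(seed)
    # Ask for k + 1 neighbors because each point is typically its own nearest neighbor.
    _, idx = NearestNeighbors(n_neighbors=k + 1).fit(X_min).kneighbors(X_min)
    # Pick a random base point and one of its k neighbors for each new sample.
    base = rng.integers(0, len(X_min), size=n_new)
    neighbor = idx[base, rng.integers(1, k + 1, size=n_new)]
    # Move a random fraction of the way from the base point toward the neighbor.
    gap = rng.random((n_new, 1))
    return X_min[base] + gap * (X_min[neighbor] - X_min[base])

# Made-up minority-class observations with two features.
X_min = np.random.default_rng(1).normal(size=(20, 2))
X_new = smote_like(X_min, n_new=30)
print(X_new.shape)  # (30, 2)
```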
*Update: One of our readers, Marco, brought up a great point about the risks of using SMOTE without proper cross-validation. His blog post on the topic has more details.
Combine Minority Classes
Combining minority classes of your target variable may be appropriate for some multi-class problems.
For example, let's say you wished to predict credit card fraud. In your dataset, each method of fraud may be labeled separately, but you might not care about distinguishing them. You could combine them all into a single 'Fraud' class and treat the problem as a binary classification.
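A small sketch of that relabeling step. The label names below are invented for illustration and do not come from any real dataset.

```python
import pandas as pd

# Hypothetical transaction labels; the fraud method names are made up.
labels = pd.Series([
    "legit", "stolen_card", "legit",
    "account_takeover", "legit", "fake_merchant",
])

# Collapse every fraud method into a single positive class.
is_fraud = (labels != "legit").astype(int)

print(is_fraud.tolist())  # [0, 1, 0, 1, 0, 1]
```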
Reframe as Anomaly Detection
Anomaly detection, a.k.a. outlier detection, is for detecting outliers and rare events. Instead of building a classification model, you'd have a "profile" of a normal observation. If a new observation strays too far from that "normal profile," it would be flagged as an anomaly.
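As a rough sketch of the "normal profile" idea, the example below summarizes made-up normal observations by their per-feature mean and standard deviation, then flags anything that lies too many standard deviations away. The helper name `is_anomaly` and the threshold of 4 are arbitrary choices for illustration; dedicated anomaly detection algorithms are more sophisticated than this.

```python
import numpy as np

# Made-up "normal" observations with two features.
rng = np.random.default_rng(0)
normal_obs = rng.normal(loc=0.0, scale=1.0, size=(1000, 2))

# The "profile" of a normal observation: per-feature mean and standard deviation.
mu = normal_obs.mean(axis=0)
sigma = normal_obs.std(axis=0)

def is_anomaly(x, threshold=4.0):
    """Flag x if any feature lies more than `threshold` std devs from the mean."""
    z = np.abs((np.asarray(x) - mu) / sigma)
    return bool(np.any(z > threshold))

print(is_anomaly(mu))           # False: the center of the profile
print(is_anomaly([10.0, 0.0]))  # True: far from the profile on the first feature
```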
Conclusion & Next Steps
In this guide, we covered 5 tactics for handling imbalanced classes in machine learning:
- Up-sample the minority class
- Down-sample the majority class
- Change your performance metric
- Penalize algorithms (cost-sensitive training)
- Use tree-based algorithms
These tactics are subject to the No Free Lunch theorem, and you should try several of them and use the results from the test set to decide on the best solution for your problem.