Over the last few years, model interpretability has become the focus of a great deal of research. Interpretability plays a crucial role in trusting a model, spotting leakage (unexpected information that leaked into the training data, causing unrealistically good results), detecting biases, debugging the model, discovering new data patterns, and more.
Today, data professionals are expected to create a model that not only performs well, but can also be explained clearly. The trouble is, businesses frequently need a complicated model in order to obtain reliable predictions, and that usually comes at the price of interpretability.
Different Types of Interpretability
The good news is, we can shed some light on the black box models we create.
First off, we need to explain interpretability. Broadly speaking, there are two ways to classify it:
- Global interpretability explains how the model behaves across the global data population
- Local interpretability explains the prediction for a specific instance
Different applications require different types of interpretability, and in many cases we need to use both types. For example, if a bank uses a model to predict which borrowers to approve for a loan, the data professional first needs to understand what the model is doing in a global manner, and then understand why each specific applicant was refused a loan.
In this blog series, I’m going to dig deep into practical techniques for global interpretability of generic, black box models, which can be extremely complex.
Interpretability Use Case: Customer Churn at a Telco Company
To make things concrete, let’s use customer churn in the telco industry as a use case to illustrate methods for global interpretability.
Our goal is to build a model that predicts if a customer is going to churn, so that the business can develop focused customer retention programs.
The dataset includes:
- Churn data: customers who left within the past month.
- Demographic data: the customers’ gender, whether or not they have a partner and/or dependents.
- Services data: which services the customer signed up for, e.g. online backup, device protection, movie streaming, etc.
- Customer account data: how long the customer has been with Telco, their contract type, payment method, and charges.
We want a model that produces the most accurate predictions about whether or not a customer will churn, but we also want to be able to understand what causes them to churn, so that we can take action to prevent it from happening.
Choosing an Explainable Machine Learning Model
There are simple machine learning (ML) models, such as decision trees and linear models, that are quite intuitive and easy to interpret. These models are considered highly interpretable, but we usually need to use a more complex model, which is less explainable, in order to achieve the best scores.
Having said that, simple models can still be useful in some circumstances.
- In some industries, compliance leads companies to compromise on the performance of their model in order to obtain better explainability.
- Building a complex model takes a lot of time. Running a simple model first allows you to validate the data, understand it from a high level, and grasp the complexity of the problem, so that you can be sure the complex model is needed before you spend a huge amount of time building one.
Applying an Explainable Model
Let’s begin modeling our churn problem with a decision tree. I used our Firefly.ai system to create this one:
The decision tree model is pretty straightforward. Take a look at the orange pie charts in the decision tree, indicating that there is no churn. It’s easy to see that the probability of churn is very low for cases in which the contract is not month-to-month, and even lower for cases in which the contract is not a one-year contract (i.e. a two-year contract). It’s also easy to explain how we reached this decision; we only need to guide people through the rules in the decision tree.
However, although decision trees are explainable models, the deeper the tree, the harder it will be to fully grasp the explanation and predictions for slightly varying instances.
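As an illustration of how such rule-based explanations work, here is a minimal sketch of fitting and printing a shallow decision tree with scikit-learn. The toy data and feature names are hypothetical stand-ins, not the actual telco dataset:

```python
# A minimal sketch of an interpretable decision tree on churn-style data.
# The data here is a tiny, made-up stand-in for the telco dataset.
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy data: contract type (0 = month-to-month, 1 = one-year,
# 2 = two-year) and tenure in months, with a churn label (1 = churned).
X = pd.DataFrame({
    "contract": [0, 0, 0, 1, 1, 2, 2, 2],
    "tenure":   [1, 3, 6, 12, 20, 30, 40, 50],
})
y = [1, 1, 0, 0, 0, 0, 0, 0]

tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# The fitted tree can be explained by walking people through its rules.
print(export_text(tree, feature_names=list(X.columns)))
```

The printed rules are exactly the kind of explanation described above: a short list of if/else conditions anyone can follow.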
Choosing a Complex ML Model
Now let’s imagine using a more complex and less explainable model, because we want to improve our prediction scoring. There are a number of factors that can enhance a model, but in the process make it less explainable:
- Before we build the model, we should perform extensive data preprocessing that transforms the data. This preprocessing may include steps that change the features, such as the standardization/normalization of features and imputation (replacing of missing values); steps that add or remove features, such as feature engineering and feature selection; and steps that add or remove samples. This preprocessing stage helps produce better results, but it can also cause the features to be less understandable.
- We use more complex algorithms for the model, such as random forests and neural networks, which are difficult to explain but yield better results.
- Finally, we combine multiple models into an ensemble, producing an even more complex model that is extremely difficult to explain, but delivers very accurate results.
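To make these factors concrete, here is a minimal sketch of per-model preprocessing pipelines combined into a simple soft-voting ensemble. Everything here (library choice, toy data, model mix) is illustrative, not the system described in this post:

```python
# A minimal sketch: two base models, each with its own preprocessing
# pipeline, whose probability outputs are averaged into a simple ensemble.
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
X[rng.random(X.shape) < 0.1] = np.nan            # inject missing values
y = (np.nan_to_num(X[:, 0]) > 0).astype(int)     # toy target

# Each base model gets its own preprocessing pipeline (imputation,
# scaling), mirroring the per-model pipelines described above.
lr = make_pipeline(SimpleImputer(), StandardScaler(),
                   LogisticRegression()).fit(X, y)
rf = make_pipeline(SimpleImputer(),
                   RandomForestClassifier(n_estimators=50,
                                          random_state=0)).fit(X, y)

# A simple soft-voting ensemble: average the predicted probabilities.
proba = (lr.predict_proba(X) + rf.predict_proba(X)) / 2
pred = proba.argmax(axis=1)
```

Each extra layer here (imputation, scaling, averaging multiple models) improves scores but makes the end-to-end prediction harder to explain.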
Applying a Complex Model
In order to arrive at accurate predictions for whether or not a customer will churn, I used our automated machine learning (AutoML) system at full capacity to build a complex model. Sure enough, the resulting scores were quite close to those achieved by others who tackled the same task, which is pretty awesome.
Our new, complex model is actually made up of four models: two logistic regression models, one CatBoost model, and one XGBoost model.
Each model has its own extensive pipeline of data preprocessing.
Understanding what comprises the model and what happens in each step – which features and samples were added or removed – can be very helpful for trusting and debugging the model. For instance, perhaps during the data cleaning step of preprocessing, many data samples were removed due to missing values. With this information, we can try to find the relevant data and then rerun the model to get better results.
So, we’ve ended up with a black box model that has very good performance but is also very difficult to explain. How will we begin to unbox it?
Practical Guide: Global Interpretability for a Black Box Model
The best first step for unboxing this type of model is sensitivity analysis. Sensitivity analysis takes the full set of features that feed into our model and aims to identify which of them are most significant for it. It works by examining the impact that each feature, in turn, has on the results.
Let’s go into more detail about sensitivity analysis:
- First, we choose one feature in the test dataset. We change its value or try to ignore it in some way, while keeping all the other features constant.
- Then we take a look at the new model outcome.
- If changing the feature value had a significant impact on the model output, it reveals that this feature is important for the final prediction.
- If we don’t see that this feature has a significant impact, we can move on to change a different feature in the dataset to see if it has a bigger impact.
To use this technique of sensitivity analysis, we take one feature and change its value. There are many ways to do that; one simple and well-known way is called permutation. With the permutation technique, you shuffle (permute) the values of the column for that feature. This preserves the feature’s original distribution while breaking its relationship to the target.
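The permutation procedure described above can be sketched in a few lines. The model and toy data below are illustrative assumptions, not the churn model itself:

```python
# A minimal sketch of sensitivity analysis via permutation: shuffle one
# feature at a time and measure how much the model's score drops.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = (X[:, 0] + 0.1 * X[:, 1] > 0).astype(int)    # feature 0 matters most

model = RandomForestClassifier(random_state=0).fit(X, y)
baseline = accuracy_score(y, model.predict(X))

importances = []
for j in range(X.shape[1]):
    X_perm = X.copy()
    # Permuting keeps the feature's distribution but breaks its link to y.
    X_perm[:, j] = rng.permutation(X_perm[:, j])
    importances.append(baseline - accuracy_score(y, model.predict(X_perm)))

print(importances)  # the largest drop marks the most important feature
```

In practice you would compute this on a held-out test set and repeat the shuffle several times to average out noise; scikit-learn’s `permutation_importance` does exactly that.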
You can read more about sensitivity analysis, transformation types, and other practical issues here.
Sensitivity Analysis for Customer Churn
When we ran sensitivity analysis with permutation for our churn use case, we saw the following results:
It’s clear that the type of contract that a customer signs has the greatest impact on customer churn. That leads us to try changing our pricing plans to offer a different contract to some customers. It’s also evident that internet service has a larger impact than movie streaming, so we’ll create offers that focus on internet service rather than on streaming movies.
The Pros and Cons of Sensitivity Analysis
Pros:
- We can use it to access powerful insights into the importance of model features
- It considers all interactions with other features

Cons:
- Changing feature values can lead to some highly unrealistic data instances, which could bias the result in the wrong way
It’s important to note that different transformations may generate different feature-importance results, so you need to take the nature of the transformation into consideration when reading them. It is also important to note that the importance of correlated features may be divided among them, so a feature that seems to have low importance might actually be more significant than the sensitivity analysis results suggest.
Partial Dependence Plot – PDP
Although sensitivity analysis revealed that “contract” has the greatest impact on churn, it still leaves us asking: which contract will succeed in reducing churn?
As well as identifying which features affect churn, we also want to understand how each feature affects it. That’s where Partial Dependence Plot (PDP) comes in.
PDP is a technique that reveals the marginal effect that each feature has on the model’s predicted outcome. With PDP, we build a graph for each feature, presenting the averaged model predictions as a function of the feature values. PDP can also present the outcome as a function of multiple features, but the visualization becomes more complex; for example, for two features PDP generates a two-dimensional plot (typically shown as a heat map). We should also mention that this technique provides clearer visualization for numerical features and for categorical features with a logical order, while for categorical features with many categories the visualization may be a bit unclear.
Let’s go into more detail about how to use PDP. PDP should be done for each feature we want to explore, or for the important features revealed by sensitivity analysis. This is the procedure for creating PDP for a specific feature:
- We’ll begin with Individual Conditional Expectation (ICE) graphs, which are created per instance and per feature. Given an instance, ICE creates a graph showing how prediction alters when we change the values of a feature. We take many instances from our test set and create ICE graphs for them.
- Then, we take the average of those instances and generate a PDP graph.
- With the same method of calculation and presentation used in PDP, we can also create graphs presenting other measurements, such as the standard deviation, upper and lower percentile lines, and the minimum and maximum.
- And remember, be sure to present the feature distribution as well, so that you don’t rely heavily on a region of the feature that has very little data.
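The ICE-then-average procedure above can be sketched as follows, again with an illustrative model and toy data rather than the actual churn model:

```python
# A minimal sketch of ICE and PDP: for each candidate value of one
# feature, fix that value for every instance, predict, then average.
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = (X[:, 0] > 0).astype(int)                      # toy target
model = GradientBoostingClassifier(random_state=0).fit(X, y)

feature = 0
grid = np.linspace(X[:, feature].min(), X[:, feature].max(), 20)

# ICE: one prediction curve per instance (columns); PDP: their average.
ice = np.empty((len(grid), len(X)))
for i, v in enumerate(grid):
    X_mod = X.copy()
    X_mod[:, feature] = v                          # fix the feature everywhere
    ice[i] = model.predict_proba(X_mod)[:, 1]

pdp = ice.mean(axis=1)                             # the PDP curve
std = ice.std(axis=1)                              # spread around the average
```

Plotting `pdp` against `grid` (with `std` as a band, plus the feature’s histogram underneath) gives a graph like the one discussed below.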
Applying PDP to Churn Use Case
So now that we understand that “contract” is the feature with the greatest impact on whether or not a person will churn, we can run PDP on this feature alone to see how it affects the prediction.
Below I’ve shared a PDP graph with standard deviation, and feature distribution.
As you can see, the longer the contract, the lower the probability that the customer will churn. Also, you can see that when a customer has a two-year contract, the probability of churn is very close to zero, though this is something we already learned from the decision tree that we created earlier. According to the feature distribution, a quarter of the data relates to customers who have a two-year contract. This is enough for us to draw reliable conclusions regarding this region.
This PDP graph also shows us that the probability of churn is still low for the one-year contract. According to the standard deviation, the probability that the customer will churn is below 0.4 most of the time. However, for a month-to-month contract, the standard deviation is very high, meaning that there are other features affecting the result.
To collect further insights regarding month-to-month contracts, we could rerun the sensitivity analysis on this region of the data alone (that is, using only the instances where the contract is month-to-month). After understanding which features are important in this region, we could run PDP for those features. But we will stop here for now, and instead continue with a discussion of the pros and cons of PDP.
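Although we stop the analysis here, the region-restricted step just described could be sketched like this, using scikit-learn’s `permutation_importance` on toy data (the “contract” encoding is a hypothetical stand-in):

```python
# A minimal sketch of region-restricted sensitivity analysis: keep only
# the instances in one region (here, rows where a made-up "contract"
# column equals 0, standing in for month-to-month) and compute
# permutation importance on that subset alone.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

rng = np.random.default_rng(0)
X = rng.normal(size=(400, 3))
X[:, 0] = rng.integers(0, 3, size=400)             # contract: 0, 1, 2
y = ((X[:, 0] == 0) & (X[:, 1] > 0)).astype(int)   # churn only in region 0

model = RandomForestClassifier(random_state=0).fit(X, y)

mask = X[:, 0] == 0                                # restrict to one region
result = permutation_importance(model, X[mask], y[mask],
                                n_repeats=10, random_state=0)
print(result.importances_mean)  # within the region, feature 1 dominates
```

Within the region, "contract" is constant and its importance collapses to zero, so the remaining features reveal what drives churn for that group.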
Although in the graph above I presented only PDP and standard deviation, be sure to present other graphs (such as percentile, minimum, maximum), as mentioned above, for deeper understanding of your model.
The Pros and Cons of PDP
Pros:
- The PDP technique delivers global insight into how each feature affects the model
- It’s a simple and intuitive method that’s relatively easy to use

Cons:
- PDP presents the trend of a feature, but there can be more than one trend per feature. For example, in our churn problem and the monthly charges feature, there can be different trends for different groups within the population. To overcome this issue, one can also present ICE plots, or use another method instead of averaging (clustering, for instance).
- As in sensitivity analysis, results can be biased by unrealistic and misleading data instances.
The Benefits of Global Explainability
Businesses often need to be able to explain their black box models, but they also need complex models for prediction accuracy. We’ve briefly explored two techniques for achieving global explainability and discussed what we can gain from each approach. Which techniques to use, and how, should depend on what we want to gain from the explanation. In our churn use case, part of our desire for explainability comes from the need to trust the model, spot leakages, and debug it, but most of it comes from wanting to extract additional insights from the model, so that we can take action accordingly.
Einat Naaman is a Machine Learning Researcher at Firefly.ai.