Causal AI, exploring the integration of causal reasoning into machine learning
This article gives a practical introduction to the potential of causal graphs.
It's aimed at anyone who wants to understand more about:
- What causal graphs are and how they work
- A worked case study in Python illustrating how to build causal graphs
- How they compare to ML
- The key challenges and future considerations
The complete notebook can be found here:
Causal graphs help us disentangle causes from correlations. They are a key part of the causal inference/causal ML/causal AI toolbox and can be used to answer causal questions.
Also known as a DAG (directed acyclic graph), a causal graph contains nodes and edges. Edges link nodes that are causally related.
There are two ways to determine a causal graph:
- Expert domain knowledge
- Causal discovery algorithms
For now, we'll assume we have expert domain knowledge to determine the causal graph (we'll cover causal discovery algorithms later on).
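To make this concrete, here is a minimal sketch of a three-node causal graph built with networkx (the variable names are hypothetical and purely illustrative):

```python
import networkx as nx

# A minimal three-node DAG: demand causally affects both marketing spend and sign-ups
dag = nx.DiGraph()
dag.add_edges_from([
    ("Demand", "Marketing spend"),
    ("Demand", "Sign-ups"),
    ("Marketing spend", "Sign-ups"),
])

# A causal graph must be acyclic - networkx can check this for us
print(nx.is_directed_acyclic_graph(dag))  # True
```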
The objective of ML is to classify or predict as accurately as possible given some training data. There is no incentive for an ML algorithm to ensure the features it uses are causally linked to the target. There is no guarantee that the direction (positive/negative effect) and strength of each feature will align with the true data-generating process. ML won't take into account the following situations (a small simulation of the confounder case follows this list):
- Spurious correlations: Two variables having a spurious correlation when they share a common cause, e.g. high temperatures increasing both the number of ice cream sales and the number of shark attacks.
- Confounders: A variable affecting both your treatment and outcome, e.g. demand affecting how much we spend on marketing and how many new customers sign up.
- Colliders: A variable that is affected by two independent variables, e.g. Quality of customer care -> User satisfaction <- Size of company
- Mediators: Two variables being (indirectly) linked through a mediator, e.g. Regular exercise -> Cardiovascular fitness (the mediator) -> Overall health
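Here is the small confounder simulation (hypothetical numbers, purely illustrative): demand drives both marketing spend and new customer sign-ups, so the two downstream variables end up strongly correlated even though neither causes the other.

```python
import numpy as np
import pandas as pd

np.random.seed(42)

# Common cause: demand drives both marketing spend and new customer sign-ups
demand = np.random.normal(loc=100, scale=20, size=5000)
marketing_spend = 2.0 * demand + np.random.normal(scale=10, size=5000)
new_customers = 0.5 * demand + np.random.normal(scale=10, size=5000)

df_sim = pd.DataFrame({
    "marketing_spend": marketing_spend,
    "new_customers": new_customers,
})

# Strong correlation appears despite there being no causal link between the two columns
print(df_sim.corr())
```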
Because of these complexities and the black-box nature of ML, we can't be confident in its ability to answer causal questions.
Given a known causal graph and observed data, we can train a structural causal model (SCM). An SCM can be thought of as a series of causal models, one per node. Each model uses one node as a target and its direct parents as features. If the relationships in our observed data are linear, an SCM will be a series of linear equations, which could be modelled by a series of linear regression models. If the relationships are non-linear, they could be modelled with a series of boosted trees.
The key difference to traditional ML is that an SCM models causal relationships and accounts for spurious correlations, confounders, colliders and mediators.
It is common to use an additive noise model (ANM) for each non-root node (meaning it has at least one parent). This allows us to use a range of machine learning algorithms (plus a noise term) to estimate each non-root node:
Y := f(X) + N
Root nodes can be modelled using a stochastic model to describe their distribution.
An SCM can be seen as a generative model, as it can generate new samples of data, which allows it to answer a range of causal questions. It generates new data by sampling from the root nodes and then propagating the data through the graph.
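To make the mechanics concrete, here is a minimal hand-rolled sketch of a two-node SCM (X -> Y) that fits the ANM above and then generates new samples; in practice the gcm module used later handles all of this for us:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

np.random.seed(0)

# Observed data for a two-node graph: X (root) -> Y (non-root)
X = np.random.normal(loc=50, scale=10, size=2000)
Y = 0.8 * X + np.random.normal(scale=5, size=2000)

# Non-root node: fit f in Y := f(X) + N and keep the residuals as the noise distribution
f = LinearRegression().fit(X.reshape(-1, 1), Y)
noise = Y - f.predict(X.reshape(-1, 1))

# Generate new samples: sample the root node, then propagate through the graph
X_new = np.random.choice(X, size=1000)            # stochastic model for the root node
N_new = np.random.choice(noise, size=1000)        # resample the estimated noise
Y_new = f.predict(X_new.reshape(-1, 1)) + N_new   # propagate through the ANM
```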
The value of an SCM is that it allows us to answer causal questions by calculating counterfactuals and simulating interventions:
- Counterfactuals: Using historically observed data to calculate what would have happened to y if we had changed x, e.g. what would have happened to the number of customers churning if we had reduced call waiting time by 20% last month?
- Interventions: Very similar to counterfactuals (and often used interchangeably), but interventions simulate what would happen in the future, e.g. what will happen to the number of customers churning if we reduce call waiting time by 20% next year? (See the short sketch after this list.)
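In the gcm module used later in this article, the two questions map onto two different calls. A rough sketch, assuming the fitted causal_model and observed data df that we build in the case study below:

```python
from dowhy import gcm

# Intervention: what will happen in the future if we reduce call waiting time by 20%?
df_int = gcm.interventional_samples(
    causal_model,
    {'Call waiting time': lambda x: x * 0.8},
    num_samples_to_draw=10000,
)

# Counterfactual: what would have happened historically if we had reduced it by 20%?
df_cf = gcm.counterfactual_samples(
    causal_model,
    {'Call waiting time': lambda x: x * 0.8},
    observed_data=df,
)
```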
There are several KPIs that the customer service team monitors. One of these is call waiting time. Increasing the number of call centre staff will decrease call waiting times.
But how will decreasing call waiting time impact customer churn levels? And will this offset the cost of additional call centre staff?
The Data Science team is asked to construct and evaluate the business case.
The population of interest is customers who make an inbound call. The following time-series data is collected daily:
In this example, we use time-series data, but causal graphs can also work with customer-level data.
In this example, we use expert domain knowledge to determine the causal graph.
import numpy as np
import pandas as pd
import seaborn as sns
import networkx as nx
from dowhy import gcm
from sklearn.linear_model import RidgeCV

# Create node lookup for the graph nodes
node_lookup = {0: 'Demand',
               1: 'Call waiting time',
               2: 'Call abandoned',
               3: 'Reported problems',
               4: 'Discount sent',
               5: 'Churn'}
total_nodes = len(node_lookup)
# Create adjacency matrix - this is the basis for our graph
graph_actual = np.zeros((total_nodes, total_nodes))
# Create graph using expert domain knowledge
graph_actual[0, 1] = 1.0 # Demand -> Call waiting time
graph_actual[0, 2] = 1.0 # Demand -> Call abandoned
graph_actual[0, 3] = 1.0 # Demand -> Reported problems
graph_actual[1, 2] = 1.0 # Call waiting time -> Call abandoned
graph_actual[1, 5] = 1.0 # Call waiting time -> Churn
graph_actual[2, 3] = 1.0 # Call abandoned -> Reported problems
graph_actual[2, 5] = 1.0 # Call abandoned -> Churn
graph_actual[3, 4] = 1.0 # Reported problems -> Discount sent
graph_actual[3, 5] = 1.0 # Reported problems -> Churn
graph_actual[4, 5] = 1.0 # Discount sent -> Churn
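As a quick sanity check, we could visualise the graph before going any further (a small sketch, assuming matplotlib is available):

```python
import matplotlib.pyplot as plt

# Convert the adjacency matrix into a directed graph and apply the readable labels
g = nx.from_numpy_array(graph_actual, create_using=nx.DiGraph)
g = nx.relabel_nodes(g, node_lookup)

# Draw the causal graph so we can check it against our domain knowledge
nx.draw(g, with_labels=True, node_size=2500, node_color="lightblue", font_size=8)
plt.show()
```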
Next, we need to generate data for our case study.
We want to generate some data that will allow us to compare calculating counterfactuals using causal graphs vs ML (to keep things simple, ridge regression).
As we identified the causal graph in the last section, we can use it to create a data-generating process.
def data_generator(max_call_waiting, inbound_calls, call_reduction):
'''
A data generating function that has the flexibility to reduce the value of node 1 (Call waiting time) - this enables us to calculate ground truth counterfactuals

Args:
    max_call_waiting (int): Maximum call waiting time in seconds
    inbound_calls (int): Total number of inbound calls (observations in data)
    call_reduction (float): Reduction to apply to call waiting time
Returns:
DataFrame: Generated data
'''
df = pd.DataFrame(columns=node_lookup.values())
df[node_lookup[0]] = np.random.randint(low=10, high=max_call_waiting, size=(inbound_calls)) # Demand
df[node_lookup[1]] = (df[node_lookup[0]] * 0.5) * (call_reduction) + np.random.normal(loc=0, scale=40, size=inbound_calls) # Call waiting time
df[node_lookup[2]] = (df[node_lookup[1]] * 0.5) + (df[node_lookup[0]] * 0.2) + np.random.normal(loc=0, scale=30, size=inbound_calls) # Call abandoned
df[node_lookup[3]] = (df[node_lookup[2]] * 0.6) + (df[node_lookup[0]] * 0.3) + np.random.normal(loc=0, scale=20, size=inbound_calls) # Reported problems
df[node_lookup[4]] = (df[node_lookup[3]] * 0.7) + np.random.normal(loc=0, scale=10, size=inbound_calls) # Discount sent
df[node_lookup[5]] = (0.10 * df[node_lookup[1]] ) + (0.30 * df[node_lookup[2]]) + (0.15 * df[node_lookup[3]]) + (-0.20 * df[node_lookup[4]]) # Churn
return df
# Generate data
np.random.seed(999)
df = data_generator(max_call_waiting=600, inbound_calls=10000, call_reduction=1.00)

sns.pairplot(df)
We now have an adjacency matrix which represents our causal graph, and some data. We use the gcm module from the dowhy Python package to train an SCM.
It's important to think about which causal mechanism to use for the root and non-root nodes. If you look at our data generator function, you will see that all of the relationships are linear, so choosing ridge regression should be sufficient.
# Setup graph
graph = nx.from_numpy_array(graph_actual, create_using=nx.DiGraph)
graph = nx.relabel_nodes(graph, node_lookup)

# Create SCM
causal_model = gcm.InvertibleStructuralCausalModel(graph)
causal_model.set_causal_mechanism('Demand', gcm.EmpiricalDistribution()) # Root node
causal_model.set_causal_mechanism('Call waiting time', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
causal_model.set_causal_mechanism('Call abandoned', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
causal_model.set_causal_mechanism('Reported problems', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root node
causal_model.set_causal_mechanism('Discount sent', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root
causal_model.set_causal_mechanism('Churn', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Non-root
gcm.fit(causal_model, df)
You can also use the auto assignment function to assign the causal mechanisms automatically instead of setting them manually.
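A minimal sketch of what the auto assignment could look like (check the gcm docs for the options available in your dowhy version):

```python
# Let dowhy choose a suitable causal mechanism for each node based on the data
gcm.auto.assign_causal_mechanisms(causal_model, df)
gcm.fit(causal_model, df)
```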
For more information on the gcm package see the docs:
We also use ridge regression to help create a baseline comparison. We can look back at the data generator and see that it correctly estimates the coefficients for each variable. However, in addition to directly influencing churn, call waiting time also indirectly influences churn through abandoned calls, reported problems and discounts sent.
When it comes to estimating counterfactuals, it will be interesting to see how the SCM compares to ridge regression.
# Ridge regression
y = df['Churn'].copy()
X = df.iloc[:, 1:-1].copy()
model = RidgeCV()
model = model.fit(X, y)
y_pred = model.predict(X)

print(f'Intercept: {model.intercept_}')
print(f'Coefficient: {model.coef_}')
# Ground truth coefficients: [0.10 0.30 0.15 -0.20]
Before we move on to calculating counterfactuals using causal graphs and ridge regression, we need a ground truth benchmark. We can use our data generator to create counterfactual samples after we have reduced call waiting time by 20%.
We couldn't do this with real-world problems, but this method allows us to assess how effective the causal graph and ridge regression are.
# Set call reduction to 20%
reduce = 0.20
call_reduction = 1 - reduce

# Generate counterfactual data
np.random.seed(999)
df_cf = data_generator(max_call_waiting=600, inbound_calls=10000, call_reduction=call_reduction)
We can now estimate what would have happened if we had decreased the call waiting time by 20%, using our three methods:
- Ground truth (from the information generator)
- Ridge regression
- Causal graph
We see that ridge regression significantly underestimates the impact on churn, whilst the causal graph is very close to the ground truth.
# Ground truth counterfactual
ground_truth = round((df['Churn'].sum() - df_cf['Churn'].sum()) / df['Churn'].sum(), 2)

# Causal graph counterfactual
df_counterfactual = gcm.counterfactual_samples(causal_model, {'Call waiting time': lambda x: x*call_reduction}, observed_data=df)
causal_graph = round((df['Churn'].sum() - df_counterfactual['Churn'].sum()) / (df['Churn'].sum()), 3)
# Ridge regression counterfactual
ridge_regression = round((df['Call waiting time'].sum() * 1.0 * model.coef_[0] - (df['Call waiting time'].sum() * call_reduction * model.coef_[0])) / (df['Churn'].sum()), 3)
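Putting the three estimates side by side (a quick summary of the values computed above):

```python
# Compare the estimated % reduction in churn from each method
print(f'Ground truth: {ground_truth}')
print(f'Causal graph (SCM): {causal_graph}')
print(f'Ridge regression: {ridge_regression}')
```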
This was a simple example to get you thinking about the power of causal graphs.
For more complex situations, there are several challenges that would need some consideration:
- What assumptions are made and what’s the impact of those being violated?
- What if we don't have the expert domain knowledge to identify the causal graph?
- What if there are non-linear relationships?
- How damaging is multi-collinearity?
- What if some variables have lagged effects?
- How can we cope with high-dimensional datasets (numerous variables)?
All of these points will be covered in future blogs.
If you're interested in learning more about causal AI, I highly recommend the following resources:
“Meet Ryan, a seasoned Lead Data Scientist with a specialised focus on employing causal techniques within business contexts, spanning Marketing, Operations, and Customer Service. His proficiency lies in unravelling the intricacies of cause-and-effect relationships to drive informed decision-making and strategic enhancements across diverse organisational functions.”