Causal AI, exploring the integration of causal reasoning into machine learning
Photo by Boris Dunand on Unsplash
What is this series of articles about?
Welcome to my series on Causal AI, where we will explore the integration of causal reasoning into machine learning models. Expect to explore a number of practical applications across different business contexts.
In the last article we covered enhancing marketing mix modelling with Causal AI. In this article we will move on to safeguarding demand forecasting with causal graphs.
If you missed the last article on marketing mix modelling, check it out here:
Enhancing Marketing Mix Modelling with Causal AI
Introduction
In this article we will delve into how you can safeguard demand forecasting (or any forecasting use case to be honest) with causal graphs.
The following areas will be explored:
A quick forecasting 101.
What is demand forecasting?
A refresher on causal graphs.
How can causal graphs safeguard demand forecasting?
A Python case study illustrating how causal graphs can safeguard your forecasts from spurious correlations.
The full notebook can be found here:
Forecasting
Forecasting 101
Time series forecasting involves predicting future values based on historical observations.
User generated image
To start us off, here are a number of terms worth getting familiar with (a short code sketch illustrating a few of them follows the list):
Auto-correlation — The correlation of a series with its previous values at different time lags. Helps identify if there is a trend present.
Stationary — This is when the statistical properties of a series (e.g. mean, variance) are constant over time. Some forecasting methods assume stationarity.
Differencing — This is when we subtract the previous observation from the current observation to transform a non-stationary series into a stationary one. An important step for models which assume stationarity.
Seasonality — A regular repeating cycle which occurs at a fixed interval (e.g. daily, weekly, yearly).
Trend — The long-term movement in a series.
Lag — The number of time steps between an observation and a previous value.
Residuals — The difference between predicted and actual values.
Moving average — Used to smooth out short-term fluctuations by averaging a fixed number of past observations.
Exponential smoothing — Weights are applied to past observations, with more emphasis placed on recent values.
Seasonal decomposition — This is when we separate a time series into seasonal, trend and residual components.
User generated image
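To make a few of these terms concrete, here is a minimal sketch on a made-up daily series (the series itself is hypothetical, and statsmodels is assumed to be available for the decomposition):
import numpy as np
import pandas as pd
from statsmodels.tsa.seasonal import seasonal_decompose
# hypothetical daily series with a trend, weekly seasonality and noise
rng = np.random.default_rng(42)
idx = pd.date_range("2023-01-01", periods=200, freq="D")
y = pd.Series(0.05 * np.arange(200) + 2 * np.sin(2 * np.pi * np.arange(200) / 7) + rng.normal(0, 0.5, 200), index=idx)
# auto-correlation at lag 7 picks up the weekly cycle
print(y.autocorr(lag=7))
# differencing removes the trend, moving us towards stationarity
y_diff = y.diff().dropna()
# a 7-day moving average smooths out short-term fluctuations
y_smooth = y.rolling(window=7).mean()
# seasonal decomposition splits the series into trend, seasonal and residual components
decomposition = seasonal_decompose(y, model="additive", period=7)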
There are a number of different methods which can be used for forecasting (a quick example of fitting a couple of them follows the list):
ETS (Error, Trend, Seasonal) — An exponential smoothing method that models error, trend and seasonality components.
Autoregressive models (AR models) — Models the current value of the series as a linear combination of its previous values.
Moving average models (MA models) — Models the current value of the series as a linear combination of past forecast errors.
Autoregressive integrated moving average (ARIMA models) — Combines AR and MA models with the incorporation of differencing to make the series stationary.
State space models — Deconstructs the time series into individual components such as trend and seasonality.
Hierarchical models — A method which handles data structured in a hierarchy, such as regions.
Linear regression — Uses one or more independent variables (features) to predict the dependent variable (target).
Machine learning (ML) — Uses more flexible algorithms like boosting to capture complex relationships.
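As a quick illustration, here is a minimal sketch of fitting an ARIMA model and an exponential smoothing model with statsmodels, reusing the toy series y from the sketch above (statsmodels is again assumed to be available):
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.holtwinters import ExponentialSmoothing
# ARIMA(1, 1, 1): one autoregressive lag, first differencing, one moving average term
arima_fit = ARIMA(y, order=(1, 1, 1)).fit()
print(arima_fit.forecast(steps=14))
# additive exponential smoothing with trend and weekly seasonality
ets_fit = ExponentialSmoothing(y, trend="add", seasonal="add", seasonal_periods=7).fit()
print(ets_fit.forecast(steps=14))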
If you want to dive further into this topic, I highly recommend the following resource which is well known as the go-to guide for forecasting (the version below is free 😀):
Forecasting: Principles and Practice (3rd ed)
In terms of applying some of the forecasting models using Python, I’d recommend exploring Nixtla which has an extensive list of models implemented and an easy to use API:
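As a rough sketch of what that looks like with Nixtla's statsforecast package (the exact API may differ slightly between versions, and the input frame here just reuses the hypothetical toy series from earlier):
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA, AutoETS
# statsforecast expects a long-format frame with unique_id, ds and y columns
train_df = pd.DataFrame({
    "unique_id": "ice_cream_van",
    "ds": idx,
    "y": y.values,
})
# fit a couple of models and forecast 14 days ahead
sf = StatsForecast(models=[AutoARIMA(season_length=7), AutoETS(season_length=7)], freq="D")
sf.fit(train_df)
forecast_df = sf.predict(h=14)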
Demand forecasting
Predicting the demand for your product is important.
It can help manage your inventory, avoiding over- or understocking.
It can keep your customers satisfied, ensuring products are available when they want them.
Reducing holding costs and minimising waste is cost efficient.
It is essential for strategic planning.
Keeping demand forecasts accurate is essential. In the next section, let's start to think about how causal graphs could safeguard our forecasts…
Causal graphs
Causal graph refresher
I've covered causal graphs a few times in my series, but just in case you need a refresher, check out my first article where I cover them in detail:
Using Causal Graphs to answer causal questions
How can causal graphs safeguard demand forecasting?
Taking the graph below as an example, let's say we want to forecast our target variable. We find we have 3 variables which are correlated with it, so we use them as features. Why would including the spurious correlation be a problem? The more features we include, the better our forecast, right?
User generated image
Well, not really….
When it comes to demand forecasting, one of the major problems is data drift. Data drift in itself isn't a problem if the relationship between the feature of interest and the target remains constant. But when the relationship doesn't remain constant, our forecasting accuracy will deteriorate.
But how is a causal graph going to help us… The idea is that spurious correlations are much more likely to drift, and much more likely to cause problems when they do.
Not convinced? OK it’s time to jump into the case study then!
Case study
Background
Your friend has bought an ice cream van. They paid a consultant a lot of money to build them a demand forecast model. It worked really well for the first few months, but in the last couple of months your friend has been understocking ice cream! They remember that your job title was “data something or other” and come to you for advice.
Creating the case study data
Let me start by explaining how I created the data for this case study. I created a simple causal graph with the following characteristics:
Ice cream sales is the target node (X0)
Coastal visits is a direct cause of ice cream sales (X1)
Temperature is an indirect cause of ice cream sales (X2)
Shark attacks is a spurious correlation (X3)
User generated image
I then used the following data generating process:
User generated image
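Written out, the structural equations implied by the coefficients used in the code below are (each node also receives its own noise term):
ice cream sales(t) = 0.2 × ice cream sales(t-1) + 0.9 × coastal visits(t-1) + noise
coastal visits(t) = 0.5 × coastal visits(t-1) + 1.2 × temperature(t-1) + noise
temperature(t) = 0.7 × temperature(t-1) + noise
shark attacks(t) = 0.2 × shark attacks(t-1) + 1.8 × temperature(t-1) + noise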
You can see that each node is influenced by past values of itself and a noise term, as well as its direct parents. To create the data I use a handy module from the time series causal analysis Python package Tigramite:
Tigramite is a great package, but I am not going to cover it in detail this time around as it deserves its own article! Below we use the structural_causal_process module, following the data generating process above:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tigramite import data_processing as pp
from tigramite import plotting as tp
from tigramite.toymodels import structural_causal_processes as toys

seed = 42
np.random.seed(seed)

# create node lookup for channels
node_lookup = {0: 'ice cream sales',
               1: 'coastal visits',
               2: 'temperature',
               3: 'shark attacks',
               }

# data generating process
def lin_f(x):
    return x

links_coeffs = {0: [((0, -1), 0.2, lin_f), ((1, -1), 0.9, lin_f)],
                1: [((1, -1), 0.5, lin_f), ((2, -1), 1.2, lin_f)],
                2: [((2, -1), 0.7, lin_f)],
                3: [((3, -1), 0.2, lin_f), ((2, -1), 1.8, lin_f)],
                }

# time series length
T = 1000
data, _ = toys.structural_causal_process(links_coeffs, T=T, seed=seed)
T, N = data.shape

# create var name lookup
var_names = [node_lookup[i] for i in sorted(node_lookup.keys())]

# initialise dataframe object, specify time axis and variable names
df = pp.DataFrame(data,
                  datatime={0: np.arange(len(data))},
                  var_names=var_names)
We can then visualise our time series:
tp.plot_timeseries(df)
plt.show()
User generated image
Now you understand how I have created the data, let's get back to the case study in the next section!
Understanding the data generating process
You start by trying to understand the data generating process, using the data that was fed into the model. There are 3 features included in the model:
Coastal visits
Temperature
Shark attacks
To get an understanding of the causal graph, you use PCMCI (which has a great implementation in Tigramite), a method which is suitable for causal discovery on time series data. I am not going to cover PCMCI this time round as it needs its own dedicated article. However, if you are unfamiliar with causal discovery in general, my previous article gives a good introduction:
Making Causal Discovery work in real-world business settings
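For reference, here is a minimal sketch of how PCMCI can be run on the Tigramite dataframe we built earlier (the import path for the independence tests has moved between Tigramite versions, so treat this as indicative rather than definitive):
from tigramite.pcmci import PCMCI
from tigramite.independence_tests.parcorr import ParCorr  # older versions: from tigramite.independence_tests import ParCorr
# run PCMCI with partial correlation tests, allowing links up to one lag back
pcmci = PCMCI(dataframe=df, cond_ind_test=ParCorr())
results = pcmci.run_pcmci(tau_max=1, pc_alpha=0.05)
# plot the discovered causal graph
tp.plot_graph(graph=results['graph'],
              val_matrix=results['val_matrix'],
              var_names=var_names)
plt.show()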
User generated image
The causal graph output from PCMCI can be seen above. The following things jump out:
Coastal visits is a direct cause of ice cream sales
Temperature is an indirect cause of ice cream sales
Shark attacks is a spurious correlation
You question why anyone with any common sense would include shark attacks as a feature! Looking at the documentation it seems that the consultant used ChatGPT to get a list of features to consider for the model and then used autoML to train the model.
So if ChatGPT and autoML think shark attacks should be in the model, surely it can’t be doing any harm?
Pre-processing the case study data
Next let’s visit how I pre-processed the data to make it suitable for this case study. To create our features we need to pick up the lagged values for each column (look back at the data generating process to understand why the features need to be the lagged values):
# create dataframe
df_pd = pd.DataFrame(df.values[0], columns=var_names)

# calculate lagged values for each column
lag_periods = 1
for col in var_names:
    df_pd[f'{col}_lag{lag_periods}'] = df_pd[col].shift(lag_periods)

# remove the 1st observation where we don't have lagged values
df_pd = df_pd.iloc[1:, :]
df_pd
User generated image
We could use these lagged features to predict ice cream sales, but before we do let’s introduce some data drift to the spurious correlation:
# function to introduce feature drift based on indexes
def introduce_feature_drift(df, start_idx, end_idx, drift_amount):
    drift_period = (df.index >= start_idx) & (df.index <= end_idx)
    df.loc[drift_period, 'shark attacks_lag1'] += np.linspace(0, drift_amount, drift_period.sum())
    return df

# introduce feature drift
df_pd = introduce_feature_drift(df_pd, start_idx=500, end_idx=999, drift_amount=50.0)

# visualise drift
plt.figure(figsize=(12, 6))
sns.lineplot(data=df_pd[['shark attacks_lag1']])
plt.title('Feature Drift Over Time')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend(['shark attacks_lag1'])
plt.show()
User generated image
Let's go back to the case study and understand what we are seeing. Why has the number of shark attacks drifted? You do some research and find out that one of the causes of shark attacks is the number of people surfing. In recent months there has been a huge rise in the popularity of surfing, causing an increase in shark attacks. So how did this affect the ice cream sales forecast?
Model training
You decide to recreate the model using the same features as the consultant and then using just the direct causes:
# use the first 500 observations for training
df_train = df_pd.iloc[0:500, :]

# use the last 99 observations (covering the drift period) for evaluation
df_test = df_pd.iloc[900:, :]

# set feature lists
X_causal_cols = ["ice cream sales_lag1", "coastal visits_lag1"]
X_spurious_cols = ["ice cream sales_lag1", "coastal visits_lag1", "temperature_lag1", "shark attacks_lag1"]

# create target, train and test sets
y_train = df_train['ice cream sales'].copy()
y_test = df_test['ice cream sales'].copy()
X_causal_train = df_train[X_causal_cols].copy()
X_causal_test = df_test[X_causal_cols].copy()
X_spurious_train = df_train[X_spurious_cols].copy()
X_spurious_test = df_test[X_spurious_cols].copy()
The model trained on just the direct causes looks good on both the train and test set.
from sklearn.linear_model import RidgeCV
from sklearn.metrics import mean_squared_error, r2_score

# train and validate model
model_causal = RidgeCV()
model_causal = model_causal.fit(X_causal_train, y_train)
print(f'Coefficient: {model_causal.coef_}')

yhat_causal_train = model_causal.predict(X_causal_train)
yhat_causal_test = model_causal.predict(X_causal_test)

mse_train = mean_squared_error(y_train, yhat_causal_train)
mse_test = mean_squared_error(y_test, yhat_causal_test)
print(f"Mean Squared Error train: {round(mse_train, 2)}")
print(f"Mean Squared Error test: {round(mse_test, 2)}")

r2_train = r2_score(y_train, yhat_causal_train)
r2_test = r2_score(y_test, yhat_causal_test)
print(f"R2 train: {round(r2_train, 2)}")
print(f"R2 test: {round(r2_test, 2)}")
User generated image
However, when you train the model using all of the features, you see that the model performs well on the train set but not the test set. Seems like you have identified the problem!
# train and validate model
model_spurious = RidgeCV()
model_spurious = model_spurious.fit(X_spurious_train, y_train)
print(f'Coefficient: {model_spurious.coef_}')

yhat_spurious_train = model_spurious.predict(X_spurious_train)
yhat_spurious_test = model_spurious.predict(X_spurious_test)

mse_train = mean_squared_error(y_train, yhat_spurious_train)
mse_test = mean_squared_error(y_test, yhat_spurious_test)
print(f"Mean Squared Error train: {round(mse_train, 2)}")
print(f"Mean Squared Error test: {round(mse_test, 2)}")

r2_train = r2_score(y_train, yhat_spurious_train)
r2_test = r2_score(y_test, yhat_spurious_test)
print(f"R2 train: {round(r2_train, 2)}")
print(f"R2 test: {round(r2_test, 2)}")
User generated image
When we compare the predictions from both models on the test set, we can see why your friend has been understocking ice cream!
# combine results
df_comp = pd.DataFrame({
    'Index': np.arange(99),
    'Actual': y_test,
    'Causal prediction': yhat_causal_test,
    'Spurious prediction': yhat_spurious_test
})

# melt the DataFrame to long format for seaborn
df_melted = df_comp.melt(id_vars=['Index'], value_vars=['Actual', 'Causal prediction', 'Spurious prediction'], var_name='Series', value_name='Value')

# visualise results for test set
plt.figure(figsize=(12, 6))
sns.lineplot(data=df_melted, x='Index', y='Value', hue='Series')
plt.title('Actual vs Predicted')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend(title='Series')
plt.show()
User generated image
Closing thoughts
Today we explored how harmful including spurious correlations in your forecasting models can be. Let’s finish off with some closing thoughts:
The aim of this article was to start you thinking about how understanding the causal graph can improve your forecasts.
I know the example was a little over-exaggerated (I would hope common sense would have helped in this scenario!) but it hopefully illustrates the point.
Another interesting point to mention is that the coefficient for shark attacks was negative. This is another pitfall, as logically we would have expected this spurious correlation to be positive.
Medium to long term demand forecasting is very hard: you often need a forecasting model for each feature to be able to forecast multiple time steps ahead. Interestingly, causal graphs (specifically structural causal models) lend themselves well to this problem, as sketched below.
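To illustrate that last point, here is a minimal sketch of rolling the structural equations forward to forecast several steps ahead, reusing the coefficients from the data generating process above and ignoring the noise terms; forecast_sales_recursively is a hypothetical helper, not part of the case study code:
# hypothetical helper: roll the structural equations forward, ignoring noise terms
def forecast_sales_recursively(last_values, horizon):
    preds = []
    state = dict(last_values)
    for _ in range(horizon):
        # each node at time t depends only on values at time t-1
        new_state = {
            'temperature': 0.7 * state['temperature'],
            'coastal visits': 0.5 * state['coastal visits'] + 1.2 * state['temperature'],
            'ice cream sales': 0.2 * state['ice cream sales'] + 0.9 * state['coastal visits'],
        }
        preds.append(new_state['ice cream sales'])
        state = new_state
    return preds

# forecast 7 steps ahead from the last observed values of the causal nodes (shark attacks is not needed, as it is not a cause)
last_values = df_pd.iloc[-1][['ice cream sales', 'coastal visits', 'temperature']].to_dict()
print(forecast_sales_recursively(last_values, horizon=7))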
Follow me if you want to continue this journey into Causal AI. In the next article, we will see how we can use encouragement design to estimate the effect of product features which need to be fully rolled out (no AB test).