Skip to content

Commit 67181c6

Browse files
authored
Merge pull request #338 from pymc-labs/multi-cell-geolift
Add example analysis of multiple geo lift test analysis
2 parents e90bab9 + f89c53b commit 67181c6

File tree

8 files changed

+1716
-14
lines changed

8 files changed

+1716
-14
lines changed

causalpy/data/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"sc": {"filename": "synthetic_control.csv"},
3434
"anova1": {"filename": "ancova_generated.csv"},
3535
"geolift1": {"filename": "geolift1.csv"},
36+
"geolift_multi_cell": {"filename": "geolift_multi_cell.csv"},
3637
"risk": {"filename": "AJR2001.csv"},
3738
"nhefs": {"filename": "nhefs.csv"},
3839
"schoolReturns": {"filename": "schoolingReturns.csv"},

causalpy/data/geolift_multi_cell.csv

Lines changed: 209 additions & 0 deletions
Large diffs are not rendered by default.

causalpy/data/simulate_data.py

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def generate_synthetic_control_data(
5858
Generates data for synthetic control example.
5959
6060
:param N:
61-
Number fo data points
61+
Number of data points
6262
:param treatment_time:
6363
Index where treatment begins in the generated dataframe
6464
:param grw_mu:
@@ -324,15 +324,6 @@ def generate_geolift_data():
324324
treatment_time = pd.to_datetime("2022-01-01")
325325
causal_impact = 0.2
326326

327-
def create_series(n=52, amplitude=1, length_scale=2):
328-
"""
329-
Returns numpy tile with generated seasonality data repeated over
330-
multiple years
331-
"""
332-
return np.tile(
333-
generate_seasonality(n=n, amplitude=amplitude, length_scale=2) + 3, n_years
334-
)
335-
336327
time = pd.date_range(start="2019-01-01", periods=52 * n_years, freq="W")
337328

338329
untreated = [
@@ -345,7 +336,12 @@ def create_series(n=52, amplitude=1, length_scale=2):
345336
]
346337

347338
df = (
348-
pd.DataFrame({country: create_series() for country in untreated})
339+
pd.DataFrame(
340+
{
341+
country: create_series(n_years=n_years, intercept=3)
342+
for country in untreated
343+
}
344+
)
349345
.assign(time=time)
350346
.set_index("time")
351347
)
@@ -360,6 +356,67 @@ def create_series(n=52, amplitude=1, length_scale=2):
360356

361357
# add treatment effect
362358
df["Denmark"] += np.where(df.index < treatment_time, 0, causal_impact)
359+
360+
# ensure we never see any negative sales
361+
df = df.clip(lower=0)
362+
363+
return df
364+
365+
366+
def generate_multicell_geolift_data():
367+
"""Generate synthetic data for a geolift example. This will consists of 6 untreated
368+
countries. The treated unit `Denmark` is a weighted combination of the untreated
369+
units. We additionally specify a treatment effect which takes effect after the
370+
`treatment_time`. The timeseries data is observed at weekly resolution and has
371+
annual seasonality, with this seasonality being a drawn from a Gaussian Process with
372+
a periodic kernel."""
373+
n_years = 4
374+
treatment_time = pd.to_datetime("2022-01-01")
375+
causal_impact = 0.2
376+
time = pd.date_range(start="2019-01-01", periods=52 * n_years, freq="W")
377+
378+
untreated = [
379+
"u1",
380+
"u2",
381+
"u3",
382+
"u4",
383+
"u5",
384+
"u6",
385+
"u7",
386+
"u8",
387+
"u9",
388+
"u10",
389+
"u11",
390+
"u12",
391+
]
392+
393+
df = (
394+
pd.DataFrame(
395+
{
396+
country: create_series(n_years=n_years, intercept=3)
397+
for country in untreated
398+
}
399+
)
400+
.assign(time=time)
401+
.set_index("time")
402+
)
403+
404+
treated = ["t1", "t2", "t3", "t4"]
405+
406+
for treated_geo in treated:
407+
# create treated unit as a weighted sum of the untreated units
408+
weights = np.random.dirichlet(np.ones(len(untreated)), size=1)[0]
409+
df[treated_geo] = np.dot(df[untreated].values, weights)
410+
# add treatment effect
411+
df[treated_geo] += np.where(df.index < treatment_time, 0, causal_impact)
412+
413+
# add observation noise to all geos
414+
for col in untreated + treated:
415+
df[col] += np.random.normal(size=len(df), scale=0.1)
416+
417+
# ensure we never see any negative sales
418+
df = df.clip(lower=0)
419+
363420
return df
364421

365422

@@ -387,3 +444,14 @@ def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
387444
return amplitude**2 * np.exp(
388445
-2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
389446
)
447+
448+
449+
def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
450+
"""
451+
Returns numpy tile with generated seasonality data repeated over
452+
multiple years
453+
"""
454+
return np.tile(
455+
generate_seasonality(n=n, amplitude=amplitude, length_scale=2) + intercept,
456+
n_years,
457+
)

causalpy/tests/test_synthetic_data.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Tests for the simulated data functions
16+
"""
17+
18+
import numpy as np
19+
import pandas as pd
20+
21+
22+
def test_generate_multicell_geolift_data():
23+
"""
24+
Test the generate_multicell_geolift_data function.
25+
"""
26+
from causalpy.data.simulate_data import generate_multicell_geolift_data
27+
28+
df = generate_multicell_geolift_data()
29+
assert type(df) == pd.DataFrame
30+
assert np.all(df >= 0), "Found negative values in dataset"
31+
32+
33+
def test_generate_geolift_data():
34+
"""
35+
Test the generate_geolift_data function.
36+
"""
37+
from causalpy.data.simulate_data import generate_geolift_data
38+
39+
df = generate_geolift_data()
40+
assert type(df) == pd.DataFrame
41+
assert np.all(df >= 0), "Found negative values in dataset"

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/examples.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,15 @@ Synthetic Control
1616
notebooks/sc_pymc.ipynb
1717
notebooks/sc_skl.ipynb
1818
notebooks/sc_pymc_brexit.ipynb
19+
20+
Geographical lift testing
21+
=========================
22+
23+
.. toctree::
24+
:titlesonly:
25+
1926
notebooks/geolift1.ipynb
27+
notebooks/multi_cell_geolift.ipynb
2028

2129

2230
Difference in Differences

0 commit comments

Comments
 (0)