Matching Data

This notebook demonstrates the MatchingData class, which organizes population data for matching, along with some plotting tools for visualizing that data. You can download the notebook and run it yourself: https://github.com/Bayer-Group/pybalance/blob/main/sphinx/demos/matching_data.ipynb.

[1]:
import os
import logging
logging.basicConfig(
    format="%(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
    level='INFO',
)

import pandas as pd

import pybalance
from pybalance import MatchingData, MatchingHeaders, split_target_pool
from pybalance.visualization import (
    plot_numeric_features,
    plot_categoric_features,
    plot_binary_features,
    plot_joint_numeric_distributions,
    plot_joint_numeric_categoric_distributions,
    plot_per_feature_loss
)
from pybalance.sim import get_paper_dataset_path

Initializing MatchingData

The MatchingData class is a thin wrapper around a pandas DataFrame that additionally keeps track of metadata about the columns relevant for matching. MatchingData can be initialized from either a file path (given as a string) or a pandas DataFrame.

[2]:
# initialize MatchingData from path
data_path = get_paper_dataset_path()
m = MatchingData(data=data_path)
m
[2]:
Headers Numeric:
['age', 'height', 'weight']

Headers Categoric:
['gender', 'haircolor', 'country', 'binary_0', 'binary_1', 'binary_2', 'binary_3']

Populations
['pool', 'target']
age height weight gender haircolor country population binary_0 binary_1 binary_2 binary_3 patient_id
0 64.854093 189.466850 88.835049 1.0 1 4 pool 0 1 0 1 135740
1 52.571993 158.134940 94.215107 1.0 1 1 pool 0 1 0 1 49288
2 25.828361 154.692482 94.226222 1.0 0 3 pool 0 0 1 0 256676
3 70.177571 160.536632 94.244356 1.0 0 2 pool 0 0 0 1 338287
4 73.779164 153.551419 86.161814 0.0 0 1 pool 0 0 1 1 72849
... ... ... ... ... ... ... ... ... ... ... ... ...
274995 62.547794 186.005015 50.975051 0.0 0 1 target 0 0 1 1 579081
274996 69.879934 142.371386 100.138389 1.0 1 4 target 0 1 1 0 569939
274997 56.921402 130.639589 108.745182 1.0 1 5 target 0 1 0 0 532419
274998 34.082754 174.764051 67.998396 0.0 2 2 target 0 0 0 1 566266
274999 60.981259 137.419436 89.897817 1.0 0 5 target 1 1 1 1 544231

275000 rows × 12 columns

[3]:
# initialize MatchingData from pandas DataFrame
data = pd.read_parquet(data_path)
m = MatchingData(data=data)
m
[3]:
Headers Numeric:
['age', 'height', 'weight']

Headers Categoric:
['gender', 'haircolor', 'country', 'binary_0', 'binary_1', 'binary_2', 'binary_3']

Populations
['pool', 'target']
age height weight gender haircolor country population binary_0 binary_1 binary_2 binary_3 patient_id
0 64.854093 189.466850 88.835049 1.0 1 4 pool 0 1 0 1 135740
1 52.571993 158.134940 94.215107 1.0 1 1 pool 0 1 0 1 49288
2 25.828361 154.692482 94.226222 1.0 0 3 pool 0 0 1 0 256676
3 70.177571 160.536632 94.244356 1.0 0 2 pool 0 0 0 1 338287
4 73.779164 153.551419 86.161814 0.0 0 1 pool 0 0 1 1 72849
... ... ... ... ... ... ... ... ... ... ... ... ...
274995 62.547794 186.005015 50.975051 0.0 0 1 target 0 0 1 1 579081
274996 69.879934 142.371386 100.138389 1.0 1 4 target 0 1 1 0 569939
274997 56.921402 130.639589 108.745182 1.0 1 5 target 0 1 0 0 532419
274998 34.082754 174.764051 67.998396 0.0 2 2 target 0 0 0 1 566266
274999 60.981259 137.419436 89.897817 1.0 0 5 target 1 1 1 1 544231

275000 rows × 12 columns

Unless told otherwise, MatchingData infers which covariates to use for matching and whether each one is numeric or categoric. Here we instead specify a subset of the covariates explicitly via MatchingHeaders. Note that the unused columns are still present in the data; they are simply not used for matching.

[4]:
headers = MatchingHeaders(
    categoric=['country', 'gender', 'binary_0', 'binary_1'],
    numeric=['age', 'weight', 'height']
)
m_restricted_features = MatchingData(
    data=data,
    headers=headers
)
m_restricted_features
[4]:
Headers Numeric:
['age', 'weight', 'height']

Headers Categoric:
['country', 'gender', 'binary_0', 'binary_1']

Populations
['pool', 'target']
age height weight gender haircolor country population binary_0 binary_1 binary_2 binary_3 patient_id
0 64.854093 189.466850 88.835049 1.0 1 4 pool 0 1 0 1 135740
1 52.571993 158.134940 94.215107 1.0 1 1 pool 0 1 0 1 49288
2 25.828361 154.692482 94.226222 1.0 0 3 pool 0 0 1 0 256676
3 70.177571 160.536632 94.244356 1.0 0 2 pool 0 0 0 1 338287
4 73.779164 153.551419 86.161814 0.0 0 1 pool 0 0 1 1 72849
... ... ... ... ... ... ... ... ... ... ... ... ...
274995 62.547794 186.005015 50.975051 0.0 0 1 target 0 0 1 1 579081
274996 69.879934 142.371386 100.138389 1.0 1 4 target 0 1 1 0 569939
274997 56.921402 130.639589 108.745182 1.0 1 5 target 0 1 0 0 532419
274998 34.082754 174.764051 67.998396 0.0 2 2 target 0 0 0 1 566266
274999 60.981259 137.419436 89.897817 1.0 0 5 target 1 1 1 1 544231

275000 rows × 12 columns
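To make the idea of header inference concrete, here is a minimal sketch in plain Python. The rule it uses (a column is categoric if every value is integer-valued, numeric otherwise, with bookkeeping columns excluded) is an assumption chosen for illustration and is not necessarily how pybalance decides. The function name `infer_headers` and its `exclude` parameter are hypothetical. The rows are copied from the first three rows shown above, with the binary columns left out for brevity.

```python
# Hypothetical sketch of header inference; NOT pybalance's actual logic.
rows = [
    {"age": 64.854093, "height": 189.466850, "weight": 88.835049,
     "gender": 1.0, "haircolor": 1, "country": 4,
     "population": "pool", "patient_id": 135740},
    {"age": 52.571993, "height": 158.134940, "weight": 94.215107,
     "gender": 1.0, "haircolor": 1, "country": 1,
     "population": "pool", "patient_id": 49288},
    {"age": 25.828361, "height": 154.692482, "weight": 94.226222,
     "gender": 1.0, "haircolor": 0, "country": 3,
     "population": "pool", "patient_id": 256676},
]


def infer_headers(rows, exclude=("population", "patient_id")):
    """Split columns into numeric and categoric using a simple heuristic."""
    numeric, categoric = [], []
    for column in rows[0]:
        if column in exclude:
            # Bookkeeping columns are never used as matching covariates.
            continue
        values = [row[column] for row in rows]
        # Assumed rule: integer-valued columns (including floats such as 1.0)
        # are treated as categoric, everything else as numeric.
        if all(float(value).is_integer() for value in values):
            categoric.append(column)
        else:
            numeric.append(column)
    return {"numeric": numeric, "categoric": categoric}


headers = infer_headers(rows)
print(headers)
```

A heuristic like this misclassifies a continuous column that happens to contain only whole numbers, which is one reason to pass headers explicitly when the defaults are not what you want.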

Exploring MatchingData

The describe*() methods can be used to generate summary tables of the matching covariates.

[5]:
m.describe(normalize=False)
[5]:
                            pool    target
population size N      250000.00  25000.00
gender          0.0    120054.00  12956.00
                1.0    129946.00  12044.00
haircolor       0.0    100096.00   4924.00
                1.0     75185.00  10055.00
                2       74719.00  10021.00
country         0.0         0.00   2490.00
                1.0     25033.00   5045.00
                2       49534.00   4981.00
                3       75337.00   2474.00
                4       74934.00   5010.00
                5       25162.00   5000.00
binary_0        0.0    225028.00  17535.00
                1.0     24972.00   7465.00
binary_1        0.0    175673.00  12527.00
                1.0     74327.00  12473.00
binary_2        0.0    125113.00  17472.00
                1.0    124887.00   7528.00
binary_3        0.0     49933.00  12562.00
                1.0    200067.00  12438.00
age             mean       55.27     48.33
                std        13.18     14.39
                min        18.01     18.01
                q25        46.38     37.29
                median     57.15     48.74
                q75        66.10     59.85
                max        75.00     75.00
height          mean      159.13    153.68
                std        19.84     16.45
                min       125.00    125.00
                q25       142.09    140.29
                median    158.74    152.75
                q75       175.87    165.95
                max       195.00    195.00
weight          mean       88.30     82.25
                std        16.32     18.89
                min        50.00     50.00
                q25        76.39     66.14
                median     88.85     81.69
                q75       100.88     97.41
                max       120.00    120.00
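The numeric rows of this table are ordinary per-population summary statistics. The sketch below reproduces the same set of statistics (mean, std, min, quartiles, max) with the Python standard library. It uses only the ten ages visible in the head and tail of the data printed earlier, so its numbers will not match the full table. The use of the sample standard deviation and of linear-interpolation quartiles is an assumption, and `summarize` is a hypothetical helper, not a pybalance function.

```python
import statistics
from collections import defaultdict

# (population, age) pairs copied from the first and last five rows of the data.
records = [
    ("pool", 64.854093), ("pool", 52.571993), ("pool", 25.828361),
    ("pool", 70.177571), ("pool", 73.779164),
    ("target", 62.547794), ("target", 69.879934), ("target", 56.921402),
    ("target", 34.082754), ("target", 60.981259),
]


def summarize(values):
    """Return the statistics that appear in the numeric rows of the table."""
    # method="inclusive" interpolates linearly between the observed values.
    q25, median, q75 = statistics.quantiles(values, n=4, method="inclusive")
    return {
        "mean": statistics.mean(values),
        "std": statistics.stdev(values),  # sample standard deviation (assumed)
        "min": min(values),
        "q25": q25,
        "median": median,
        "q75": q75,
        "max": max(values),
    }


# Group the ages by population, then summarize each group separately.
ages_by_population = defaultdict(list)
for population, age in records:
    ages_by_population[population].append(age)

summary = {name: summarize(ages) for name, ages in ages_by_population.items()}
for name, stats in summary.items():
    print(name, {key: round(value, 2) for key, value in stats.items()})
```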

You can access fields on the underlying data similarly to how you would in pandas.

[6]:
m[['population', 'gender']]
[6]:
population gender
0 pool 1.0
1 pool 1.0
2 pool 1.0
3 pool 1.0
4 pool 0.0
... ... ...
274995 target 0.0
274996 target 1.0
274997 target 1.0
274998 target 0.0
274999 target 1.0

275000 rows × 2 columns

[7]:
m[m['gender'] == 0]
[7]:
age height weight gender haircolor country population binary_0 binary_1 binary_2 binary_3 patient_id
4 73.779164 153.551419 86.161814 0.0 0 1 pool 0 0 1 1 72849
7 67.404918 132.383184 67.107753 0.0 0 5 pool 0 0 0 1 171211
11 61.489148 140.780034 73.662572 0.0 0 1 pool 0 0 0 1 20695
12 73.718093 133.743721 58.879321 0.0 1 4 pool 0 0 1 0 58718
13 70.707782 156.629048 70.681391 0.0 0 2 pool 0 1 1 1 352801
... ... ... ... ... ... ... ... ... ... ... ... ...
274990 19.063519 167.704149 59.876565 0.0 2 4 target 0 1 1 1 536365
274992 58.745450 146.747313 70.291448 0.0 2 4 target 0 0 0 1 535279
274993 55.736083 132.434020 92.264209 0.0 1 4 target 0 1 0 0 595582
274995 62.547794 186.005015 50.975051 0.0 0 1 target 0 0 1 1 579081
274998 34.082754 174.764051 67.998396 0.0 2 2 target 0 0 0 1 566266

133010 rows × 12 columns

Matching data often consists of exactly two populations: a reference population, which we call the “target”, and a population to be matched, which we call the “pool”. It is sometimes convenient to separate the two, and the function split_target_pool does just that. Unless it is explicitly given the name of the target population, the function assigns the smaller population to the target. Note that the returned values are pandas DataFrame objects, not MatchingData objects.

[8]:
target, pool = split_target_pool(m)
target.head()
[8]:
age height weight gender haircolor country population binary_0 binary_1 binary_2 binary_3 patient_id
250000 57.266010 159.759575 94.325267 0.0 1 4 target 0 1 0 1 512966
250001 53.152645 145.515410 95.988094 0.0 1 2 target 1 0 1 1 540606
250002 34.079212 166.272208 73.090671 0.0 1 2 target 0 0 1 1 578266
250003 45.494927 144.336677 96.678251 1.0 2 5 target 1 1 1 1 559858
250004 18.036012 174.843524 60.586475 0.0 1 2 target 1 1 0 1 588368
[9]:
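# Explicitly designate the 'pool' population as the target. The second
# return value (named `pool` here) then holds the other population, 'target'.
target, pool = split_target_pool(m, target_name='pool')
pool.head()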
[9]:
age height weight gender haircolor country population binary_0 binary_1 binary_2 binary_3 patient_id
250000 57.266010 159.759575 94.325267 0.0 1 4 target 0 1 0 1 512966
250001 53.152645 145.515410 95.988094 0.0 1 2 target 1 0 1 1 540606
250002 34.079212 166.272208 73.090671 0.0 1 2 target 0 0 1 1 578266
250003 45.494927 144.336677 96.678251 1.0 2 5 target 1 1 1 1 559858
250004 18.036012 174.843524 60.586475 0.0 1 2 target 1 1 0 1 588368
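The splitting rule described above (the smaller population becomes the target unless a name is given) can be sketched in a few lines of plain Python. This is an illustration of the rule only: `split_smaller_as_target` is a hypothetical function operating on a list of dicts, whereas the real split_target_pool works on MatchingData and returns DataFrames. The patient IDs are taken from the rows printed earlier.

```python
from collections import defaultdict

rows = [
    {"patient_id": 135740, "population": "pool"},
    {"patient_id": 49288, "population": "pool"},
    {"patient_id": 256676, "population": "pool"},
    {"patient_id": 579081, "population": "target"},
    {"patient_id": 569939, "population": "target"},
]


def split_smaller_as_target(rows, target_name=None):
    """Return (target_rows, pool_rows); hypothetical stand-in for illustration."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["population"]].append(row)
    if len(groups) != 2:
        raise ValueError("expected exactly two populations")
    if target_name is None:
        # Default rule: the smaller population is treated as the target.
        target_name = min(groups, key=lambda name: len(groups[name]))
    elif target_name not in groups:
        raise ValueError(f"unknown population: {target_name!r}")
    (pool_name,) = [name for name in groups if name != target_name]
    return groups[target_name], groups[pool_name]


# Default: the two-row 'target' population is picked as the target.
target_rows, pool_rows = split_smaller_as_target(rows)
print(len(target_rows), len(pool_rows))

# Overriding the default swaps the roles of the two populations.
swapped_target, swapped_pool = split_smaller_as_target(rows, target_name="pool")
print(len(swapped_target), len(swapped_pool))
```

The second call mirrors the cell above in which target_name='pool' is passed: the first return value then contains the rows labeled 'pool', and the second contains the rows labeled 'target'.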

Visualizing MatchingData

Some built-in tools help you get a quick visual snapshot of the data. Many of these plotting routines are thin wrappers around seaborn plotting routines with extra logic relevant to matching situations (e.g. where one of the populations is a reference population or where variables should be treated as numeric / categoric). In most cases, the user can pass along any keyword arguments that are understood by the underlying seaborn routine.

[10]:
%matplotlib inline
[11]:
m
[11]:
Headers Numeric:
['age', 'height', 'weight']

Headers Categoric:
['gender', 'haircolor', 'country', 'binary_0', 'binary_1', 'binary_2', 'binary_3']

Populations
['pool', 'target']
age height weight gender haircolor country population binary_0 binary_1 binary_2 binary_3 patient_id
0 64.854093 189.466850 88.835049 1.0 1 4 pool 0 1 0 1 135740
1 52.571993 158.134940 94.215107 1.0 1 1 pool 0 1 0 1 49288
2 25.828361 154.692482 94.226222 1.0 0 3 pool 0 0 1 0 256676
3 70.177571 160.536632 94.244356 1.0 0 2 pool 0 0 0 1 338287
4 73.779164 153.551419 86.161814 0.0 0 1 pool 0 0 1 1 72849
... ... ... ... ... ... ... ... ... ... ... ... ...
274995 62.547794 186.005015 50.975051 0.0 0 1 target 0 0 1 1 579081
274996 69.879934 142.371386 100.138389 1.0 1 4 target 0 1 1 0 569939
274997 56.921402 130.639589 108.745182 1.0 1 5 target 0 1 0 0 532419
274998 34.082754 174.764051 67.998396 0.0 2 2 target 0 0 0 1 566266
274999 60.981259 137.419436 89.897817 1.0 0 5 target 1 1 1 1 544231

275000 rows × 12 columns

[12]:
# Plot the standardized mean difference for each feature
from pybalance.utils import BetaBalance
bc = BetaBalance(m, standardize_difference=True)
fig = plot_per_feature_loss(m, bc)
../_images/demos_matching_data_22_0.png
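For intuition about what a standardized mean difference measures, the sketch below computes one for each numeric feature directly from the means and standard deviations reported by describe() earlier in this notebook. It uses a common textbook definition (difference in means divided by the square root of the average of the two variances) and a pool-minus-target sign convention. Both are assumptions; BetaBalance with standardize_difference=True may normalize differently, so these values are not guaranteed to match the plot.

```python
import math

# (mean, std) per population, copied from the describe() output above.
summary = {
    "age": {"pool": (55.27, 13.18), "target": (48.33, 14.39)},
    "height": {"pool": (159.13, 19.84), "target": (153.68, 16.45)},
    "weight": {"pool": (88.30, 16.32), "target": (82.25, 18.89)},
}


def standardized_mean_difference(pool, target):
    """Difference in means, scaled by the root of the averaged variances."""
    (mean_pool, std_pool), (mean_target, std_target) = pool, target
    pooled_std = math.sqrt((std_pool ** 2 + std_target ** 2) / 2)
    return (mean_pool - mean_target) / pooled_std


smd = {
    feature: standardized_mean_difference(stats["pool"], stats["target"])
    for feature, stats in summary.items()
}
for feature, value in smd.items():
    print(f"{feature}: {value:.3f}")
```

Because the difference is expressed in units of standard deviation, features measured on different scales (years, centimeters, kilograms) become directly comparable.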
[13]:
fig = plot_categoric_features(m, cumulative=False, palette='colorblind', include_binary=False)
../_images/demos_matching_data_23_0.png
[14]:
fig = plot_numeric_features(m, bins=10, cumulative=False, palette='colorblind')
../_images/demos_matching_data_24_0.png
[15]:
fig = plot_binary_features(m, palette='colorblind', orient_horizontal=False, standardize_difference=False)
../_images/demos_matching_data_25_1.png