{
"cells": [
{
"cell_type": "markdown",
"id": "4f580f07",
"metadata": {},
"source": [
"# Constraint Satisfaction Matcher\n",
"\n",
"The ConstraintSatisfactionMatcher can be used to optimize any linear function of the baseline covariates. We support constraints on the size of the subset populations and the allowed mismatch.\n",
"\n",
"Here, we demonstrate the optimization of balance subject to size constraints only. Namely, we solve:\n",
"\n",
"\\begin{equation}\n",
"\\begin{aligned}\n",
"& \\underset{\\hat{P}}{\\text{minimize}}\n",
"& & \\sum_k |\\mu_{\\hat{P}k} - \\mu_{Tk}| \\\\\n",
"& \\text{subject to}\n",
"& & |\\hat{P}| = P^* \\\\\n",
"& & & |\\hat{T}| = T^* \\\\\n",
"\\end{aligned}\n",
"\\end{equation}\n",
"\n",
"where $P$ and $T$ refer to two populations we are trying to match, $\\hat{P}$ and $\\hat{T}$ are the subsets of $P$ and $T$ we are seeking, $P^*$ and $T^*$ are fixed integers, and $k$ indexes the covariates of $P$ and $T$."
]
},
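{
"cell_type": "markdown",
"id": "a1c3e5f7",
"metadata": {},
"source": [
"To see why an objective with absolute values still fits a linear framework, here is a sketch of one standard reformulation. It is not necessarily the formulation the matcher uses internally. The binary selection indicators $x_i$, the covariate values $X_{ik}$, and the auxiliary variables $z_k$ are notation introduced here for illustration only, and $\\mu_{Tk}$ is treated as a constant. Because $|\\hat{P}|$ is fixed at $P^*$, the subset mean is linear in $x$, and each absolute value can be replaced by a pair of linear inequalities:\n",
"\n",
"\\begin{equation}\n",
"\\begin{aligned}\n",
"& \\mu_{\\hat{P}k} = \\frac{1}{P^*} \\sum_{i \\in P} x_i X_{ik}, \\qquad x_i \\in \\{0, 1\\}, \\qquad \\sum_{i \\in P} x_i = P^* \\\\\n",
"& \\underset{x, z}{\\text{minimize}} \\; \\sum_k z_k \\quad \\text{subject to} \\quad -z_k \\le \\mu_{\\hat{P}k} - \\mu_{Tk} \\le z_k \\quad \\text{for all } k\n",
"\\end{aligned}\n",
"\\end{equation}\n",
"\n",
"At any optimum, $z_k = |\\mu_{\\hat{P}k} - \\mu_{Tk}|$, so this problem has the same minimizers as the one stated above. If $\\hat{T}$ is a proper subset of $T$, its mean can be written with a second set of indicators in the same way, since $|\\hat{T}|$ is also fixed."
]
},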
{
"cell_type": "code",
"execution_count": 1,
"id": "0f723264-db60-46d9-846d-b8dc17998db1",
"metadata": {},
"outputs": [],
"source": [
"import logging \n",
"logging.basicConfig(\n",
" format=\"%(levelname)-4s [%(filename)s:%(lineno)d] %(message)s\",\n",
" level='INFO',\n",
")\n",
"from pybalance.utils import (\n",
" BetaBalance, \n",
" BetaXBalance, \n",
" GammaBalance, \n",
" GammaXBalance,\n",
" GammaXTreeBalance\n",
")\n",
"from pybalance.sim import generate_toy_dataset\n",
"from pybalance.lp import ConstraintSatisfactionMatcher\n",
"from pybalance.visualization import (\n",
" plot_numeric_features, \n",
" plot_categoric_features, \n",
" plot_binary_features,\n",
" plot_per_feature_loss,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6bd8a3d5-3c19-466f-8994-91b394935bb0",
"metadata": {},
"outputs": [],
"source": [
"time_limit = 360"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2d42b61e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
Headers Numeric: ['age', 'height', 'weight']
Headers Categoric: ['gender', 'haircolor', 'country', 'binary_0', 'binary_1', 'binary_2', 'binary_3']
Populations: ['pool', 'target']

|  | age | height | weight | gender | haircolor | country | population | binary_0 | binary_1 | binary_2 | binary_3 | patient_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 62.511573 | 190.229250 | 105.165097 | 0.0 | 2 | 3 | pool | 0 | 0 | 0 | 0 | 0 |
| 1 | 68.505065 | 161.121236 | 95.001474 | 0.0 | 1 | 1 | pool | 1 | 0 | 1 | 0 | 1 |
| 2 | 50.071384 | 162.325356 | 84.290576 | 1.0 | 0 | 5 | pool | 0 | 0 | 1 | 1 | 2 |
| 3 | 44.423692 | 150.948096 | 82.031381 | 1.0 | 2 | 2 | pool | 0 | 0 | 0 | 1 | 3 |
| 4 | 41.695052 | 132.952651 | 54.857540 | 0.0 | 1 | 3 | pool | 0 | 0 | 1 | 1 | 4 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | 21.474205 | 168.602546 | 70.342128 | 0.0 | 2 | 5 | target | 0 | 0 | 0 | 1 | 10995 |
| 996 | 40.643320 | 188.188724 | 61.611744 | 0.0 | 2 | 4 | target | 1 | 0 | 0 | 1 | 10996 |
| 997 | 29.472765 | 161.408162 | 57.214095 | 0.0 | 0 | 1 | target | 0 | 1 | 1 | 1 | 10997 |
| 998 | 41.291949 | 150.968833 | 91.270798 | 0.0 | 0 | 3 | target | 0 | 0 | 0 | 0 | 10998 |
| 999 | 67.530294 | 155.124741 | 56.196505 | 1.0 | 0 | 1 | target | 1 | 0 | 0 | 0 | 10999 |

11000 rows × 12 columns

|  | age | height | weight | gender | haircolor | country | population | binary_0 | binary_1 | binary_2 | binary_3 | patient_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 55.261578 | 139.396134 | 94.438359 | 0.0 | 2 | 2 | target | 0 | 0 | 1 | 1 | 10000 |
| 1 | 63.113091 | 165.563337 | 67.433016 | 1.0 | 2 | 2 | target | 0 | 1 | 1 | 0 | 10001 |
| 2 | 58.232216 | 160.859857 | 71.915385 | 1.0 | 0 | 2 | target | 0 | 0 | 0 | 0 | 10002 |
| 3 | 58.996941 | 140.357415 | 115.606615 | 1.0 | 0 | 3 | target | 1 | 1 | 0 | 0 | 10003 |
| 4 | 36.850195 | 189.983706 | 53.000581 | 0.0 | 2 | 5 | target | 0 | 0 | 0 | 0 | 10004 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9933 | 68.194783 | 127.495418 | 69.177329 | 0.0 | 1 | 5 | pool | 1 | 1 | 0 | 0 | 9933 |
| 9946 | 22.630370 | 185.351623 | 117.381552 | 1.0 | 1 | 2 | pool | 1 | 0 | 0 | 1 | 9946 |
| 9955 | 56.736759 | 161.612045 | 72.288182 | 1.0 | 2 | 2 | pool | 1 | 0 | 1 | 1 | 9955 |
| 9981 | 39.006118 | 133.419182 | 71.135407 | 0.0 | 1 | 4 | pool | 0 | 0 | 0 | 0 | 9981 |
| 9982 | 50.575808 | 139.401060 | 89.848616 | 0.0 | 1 | 1 | pool | 0 | 0 | 0 | 1 | 9982 |

2000 rows × 12 columns

|  | age | height | weight | gender | haircolor | country | population | binary_0 | binary_1 | binary_2 | binary_3 | patient_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 55.261578 | 139.396134 | 94.438359 | 0.0 | 2 | 2 | target | 0 | 0 | 1 | 1 | 10000 |
| 1 | 63.113091 | 165.563337 | 67.433016 | 1.0 | 2 | 2 | target | 0 | 1 | 1 | 0 | 10001 |
| 2 | 58.232216 | 160.859857 | 71.915385 | 1.0 | 0 | 2 | target | 0 | 0 | 0 | 0 | 10002 |
| 3 | 58.996941 | 140.357415 | 115.606615 | 1.0 | 0 | 3 | target | 1 | 1 | 0 | 0 | 10003 |
| 4 | 36.850195 | 189.983706 | 53.000581 | 0.0 | 2 | 5 | target | 0 | 0 | 0 | 0 | 10004 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9933 | 68.194783 | 127.495418 | 69.177329 | 0.0 | 1 | 5 | pool | 1 | 1 | 0 | 0 | 9933 |
| 9946 | 22.630370 | 185.351623 | 117.381552 | 1.0 | 1 | 2 | pool | 1 | 0 | 0 | 1 | 9946 |
| 9955 | 56.736759 | 161.612045 | 72.288182 | 1.0 | 2 | 2 | pool | 1 | 0 | 1 | 1 | 9955 |
| 9981 | 39.006118 | 133.419182 | 71.135407 | 0.0 | 1 | 4 | pool | 0 | 0 | 0 | 0 | 9981 |
| 9982 | 50.575808 | 139.401060 | 89.848616 | 0.0 | 1 | 1 | pool | 0 | 0 | 0 | 1 | 9982 |

2000 rows × 12 columns

|  | age | height | weight | gender | haircolor | country | population | binary_0 | binary_1 | binary_2 | binary_3 | patient_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 55.261578 | 139.396134 | 94.438359 | 0.0 | 2 | 2 | target | 0 | 0 | 1 | 1 | 10000 |
| 1 | 63.113091 | 165.563337 | 67.433016 | 1.0 | 2 | 2 | target | 0 | 1 | 1 | 0 | 10001 |
| 2 | 58.232216 | 160.859857 | 71.915385 | 1.0 | 0 | 2 | target | 0 | 0 | 0 | 0 | 10002 |
| 3 | 58.996941 | 140.357415 | 115.606615 | 1.0 | 0 | 3 | target | 1 | 1 | 0 | 0 | 10003 |
| 4 | 36.850195 | 189.983706 | 53.000581 | 0.0 | 2 | 5 | target | 0 | 0 | 0 | 0 | 10004 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9933 | 68.194783 | 127.495418 | 69.177329 | 0.0 | 1 | 5 | pool | 1 | 1 | 0 | 0 | 9933 |
| 9947 | 64.290077 | 168.091011 | 63.511962 | 1.0 | 2 | 2 | pool | 0 | 0 | 0 | 1 | 9947 |
| 9958 | 51.722321 | 170.350117 | 80.695438 | 0.0 | 2 | 4 | pool | 0 | 1 | 0 | 1 | 9958 |
| 9982 | 50.575808 | 139.401060 | 89.848616 | 0.0 | 1 | 1 | pool | 0 | 0 | 0 | 1 | 9982 |
| 9983 | 68.616093 | 167.546870 | 58.683367 | 1.0 | 0 | 2 | pool | 0 | 0 | 0 | 0 | 9983 |

2000 rows × 12 columns

|  | age | height | weight | gender | haircolor | country | population | binary_0 | binary_1 | binary_2 | binary_3 | patient_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 55.261578 | 139.396134 | 94.438359 | 0.0 | 2 | 2 | target | 0 | 0 | 1 | 1 | 10000 |
| 1 | 63.113091 | 165.563337 | 67.433016 | 1.0 | 2 | 2 | target | 0 | 1 | 1 | 0 | 10001 |
| 2 | 58.232216 | 160.859857 | 71.915385 | 1.0 | 0 | 2 | target | 0 | 0 | 0 | 0 | 10002 |
| 3 | 58.996941 | 140.357415 | 115.606615 | 1.0 | 0 | 3 | target | 1 | 1 | 0 | 0 | 10003 |
| 4 | 36.850195 | 189.983706 | 53.000581 | 0.0 | 2 | 5 | target | 0 | 0 | 0 | 0 | 10004 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9932 | 72.514956 | 159.248205 | 118.505187 | 0.0 | 1 | 5 | pool | 1 | 1 | 0 | 1 | 9932 |
| 9933 | 68.194783 | 127.495418 | 69.177329 | 0.0 | 1 | 5 | pool | 1 | 1 | 0 | 0 | 9933 |
| 9965 | 41.035792 | 130.021437 | 80.495109 | 0.0 | 0 | 1 | pool | 0 | 1 | 1 | 1 | 9965 |
| 9966 | 40.121009 | 168.339212 | 100.428001 | 1.0 | 2 | 4 | pool | 0 | 1 | 1 | 0 | 9966 |
| 9984 | 57.366166 | 151.483411 | 82.271539 | 1.0 | 2 | 2 | pool | 0 | 1 | 0 | 1 | 9984 |

2000 rows × 12 columns
\n", "