{ "cells": [ { "cell_type": "markdown", "id": "f21133df-9ed8-4757-8963-6e2176c3128e", "metadata": {}, "source": [ "# Balance Calculators" ] }, { "cell_type": "code", "execution_count": 1, "id": "b83085cf-5a4d-4db3-85ff-317d6e376735", "metadata": {}, "outputs": [], "source": [ "from pybalance.utils.balance_calculators import *\n", "from pybalance.utils import MatchingData\n", "from pybalance.sim import load_paper_dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "917b2206-9b4e-404d-9ff7-21a251565338", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " Headers Numeric:
\n", " ['age', 'height', 'weight']

\n", " Headers Categoric:
\n", " ['gender', 'haircolor', 'country', 'binary_0', 'binary_1', 'binary_2', 'binary_3']

\n", " Populations
\n", " ['pool', 'target']
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageheightweightgenderhaircolorcountrypopulationbinary_0binary_1binary_2binary_3patient_id
064.854093189.46685088.8350491.014pool0101135740
152.571993158.13494094.2151071.011pool010149288
225.828361154.69248294.2262221.003pool0010256676
370.177571160.53663294.2443561.002pool0001338287
473.779164153.55141986.1618140.001pool001172849
.......................................
27499562.547794186.00501550.9750510.001target0011579081
27499669.879934142.371386100.1383891.014target0110569939
27499756.921402130.639589108.7451821.015target0100532419
27499834.082754174.76405167.9983960.022target0001566266
27499960.981259137.41943689.8978171.005target1111544231
\n", "

275000 rows × 12 columns

\n", "
" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m =load_paper_dataset()\n", "m" ] }, { "cell_type": "code", "execution_count": 13, "id": "49b37f1e-4069-46c4-abba-f6851f0112d1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
N
population
pool250000
target25000
\n", "
" ], "text/plain": [ " N\n", "population \n", "pool 250000\n", "target 25000" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.counts()" ] }, { "cell_type": "markdown", "id": "3eddef9b-b87d-4a0c-ae51-17d204422be9", "metadata": {}, "source": [ "## Fit Balance Calculator" ] }, { "cell_type": "code", "execution_count": 14, "id": "877b9f82-cd01-4be9-8ac4-ad191248d1d7", "metadata": {}, "outputs": [], "source": [ "# Balance calculators in general are \"fit\" to the whole population data\n", "# Fitting here means fitting preprocessors (e.g. what bins to use when binning\n", "# is involved). It's important to fit once so that all calls to distance()\n", "# can be compared meaningfully.\n", "beta = BetaBalance(m)\n", "target, pool = split_target_pool(m)" ] }, { "cell_type": "markdown", "id": "3d588103-f175-43f8-9ab5-38b0a186bb04", "metadata": {}, "source": [ "## Balance between pool and target" ] }, { "cell_type": "code", "execution_count": 15, "id": "af607a84-1eac-4a2f-a8a5-521d3646ce4f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.2353, dtype=torch.float64)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "beta.distance(pool)" ] }, { "cell_type": "code", "execution_count": 16, "id": "f6d1a23f-d7ce-4848-8d81-86f3d6bfbd9e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.2353, dtype=torch.float64)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Specifying target is optional\n", "beta.distance(pool, target)" ] }, { "cell_type": "markdown", "id": "ceb3e196-21c8-43ad-83f3-12686008bb27", "metadata": {}, "source": [ "## Balance between subset of pool and target" ] }, { "cell_type": "code", "execution_count": 17, "id": "7b07a9f6-02d8-4686-a91f-97d3ed4be40e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.2366, dtype=torch.float64)" ] }, "execution_count": 17, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "beta.distance(pool.sample(n=100))" ] }, { "cell_type": "code", "execution_count": 18, "id": "9e47e4ef-8092-45a1-a274-201b686c5c70", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.2669, dtype=torch.float64)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Can also take subsets of the target\n", "beta.distance(pool.sample(n=100), target.sample(n=100))" ] }, { "cell_type": "markdown", "id": "9cb4fcef-14a0-4344-a28a-601911160219", "metadata": {}, "source": [ "## Balance between several subsets simultaneously" ] }, { "cell_type": "code", "execution_count": 19, "id": "c367fc95-aef2-4687-bedf-e34168ea86ce", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.2404, 0.2418], dtype=torch.float64)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pool_subsets = np.array([\n", " np.random.choice(pool.reset_index().index.values, size=100, replace=False),\n", " np.random.choice(pool.reset_index().index.values, size=100, replace=False)\n", "])\n", "beta.distance(pool_subsets)" ] }, { "cell_type": "code", "execution_count": 9, "id": "10503998-f714-4cfc-921f-1a62f9aebd1d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/gmema/src/pybalance/pybalance/utils/balance_calculators.py:224: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. 
(Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:278.)\n", " subset_populations = torch.tensor(\n" ] }, { "data": { "text/plain": [ "tensor([0.2602, 0.2757], dtype=torch.float64)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pool_subsets = [\n", " np.random.choice(pool.reset_index().index.values, size=100, replace=False),\n", " np.random.choice(pool.reset_index().index.values, size=100, replace=False)\n", "]\n", "target_subsets = [\n", " np.random.choice(target.reset_index().index.values, size=100, replace=False),\n", " np.random.choice(target.reset_index().index.values, size=100, replace=False)\n", "]\n", "beta.distance(pool_subsets, target_subsets)" ] }, { "cell_type": "code", "execution_count": 28, "id": "70666f28-d6df-4c4c-8c56-29fb6ed19daa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of subset populations must be same for pool and target!\n" ] } ], "source": [ "# Must have same number of subsets! This will throw an error:\n", "pool_subsets = [\n", " np.random.choice(pool.reset_index().index.values, size=100, replace=False),\n", " np.random.choice(pool.reset_index().index.values, size=100, replace=False)\n", "]\n", "target_subsets = [\n", " np.random.choice(target.reset_index().index.values, size=100, replace=False),\n", " np.random.choice(target.reset_index().index.values, size=100, replace=False),\n", " np.random.choice(target.reset_index().index.values, size=100, replace=False)\n", "]\n", "try:\n", " beta.distance(pool_subsets, target_subsets)\n", "except ValueError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "63857014-b52b-40d9-8d7b-208c7b981129", "metadata": {}, "source": [ "## Basic Genetic Optimizer" ] }, { "cell_type": "markdown", "id": "4300c184-b990-4118-b430-1d997724f03c", "metadata": {}, "source": [ "Here is a very basic, un-optimized implementation of genetic matching! 
It's not very smart, because it doesn't mix the good populations. This is just an illustration of using the balance calculator." ] }, { "cell_type": "code", "execution_count": 27, "id": "4b8dea69-5f4a-4823-ba3b-b6deaa30bb57", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generation 0 / Best distance found 0.215\n", "Generation 10 / Best distance found 0.215\n", "Generation 20 / Best distance found 0.211\n", "Generation 30 / Best distance found 0.211\n", "Generation 40 / Best distance found 0.208\n", "Generation 50 / Best distance found 0.208\n", "Generation 60 / Best distance found 0.208\n", "Generation 70 / Best distance found 0.208\n", "Generation 80 / Best distance found 0.208\n", "Generation 90 / Best distance found 0.208\n" ] } ], "source": [ "def get_subsets(pool, target, pool_size, target_size, n_subsets):\n", "    \"\"\"Draw `n_subsets` random candidate matches.\n", "\n", "    Returns a pair (pool_subsets, target_subsets) of integer arrays with\n", "    shapes (n_subsets, pool_size) and (n_subsets, target_size). Row i of each\n", "    array holds distinct positional row indices, sampled without replacement,\n", "    into `pool` and `target` respectively.\n", "    \"\"\"\n", "    # Positions 0..n-1 are the same values reset_index().index.values yields,\n", "    # but computing them this way does not copy the frame on every call.\n", "    pool_positions = np.arange(len(pool))\n", "    target_positions = np.arange(len(target))\n", "\n", "    # Stack into single 2-D arrays: handing torch a list of separate arrays\n", "    # triggers its 'extremely slow' tensor-creation UserWarning.\n", "    pool_subsets = np.array([\n", "        np.random.choice(pool_positions, size=pool_size, replace=False)\n", "        for _ in range(n_subsets)\n", "    ])\n", "    target_subsets = np.array([\n", "        np.random.choice(target_positions, size=target_size, replace=False)\n", "        for _ in range(n_subsets)\n", "    ])\n", "    return pool_subsets, target_subsets\n", "\n", "\n", "pool_size = 1000\n", "target_size = 1000\n", "n_subsets = 100  # candidate matches drawn per generation\n", "n_generations = 100\n", "best_match = None\n", "best_distance = float('inf')  # any computed distance improves on this\n", "for j in range(n_generations):\n", "    pool_subsets, target_subsets = get_subsets(pool, target, pool_size, target_size, n_subsets)\n", "    distances = beta.distance(pool_subsets, target_subsets)\n", "    # Work with plain Python numbers rather than 0-d tensors\n", "    best_match_idx = int(distances.argmin())\n", "    this_best_distance = distances[best_match_idx].item()\n", "    if this_best_distance < best_distance:\n", "        best_distance = this_best_distance\n", "        best_match = pool_subsets[best_match_idx], target_subsets[best_match_idx]\n", "\n", "    if not j % 10:\n", "        print(f'Generation {j} / Best distance found {best_distance:.3f}')" ] }, { "cell_type": "code", "execution_count": null, "id": 
"534bfaa4-f00e-43d1-ac18-2c191466b9b0", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "pybalance", "language": "python", "name": "pybalance" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 }