"""Balance calculators for pybalance: per-feature and aggregate distance metrics between a pool and a target population (``pybalance.utils.balance_calculators``)."""

import functools
import logging
from collections import defaultdict
from typing import Dict, Optional, Union

import numpy as np
import pandas as pd
import torch

from pybalance.utils import (
    MatchingData,
    split_target_pool,
    BaseMatchingPreprocessor,
    DecisionTreeEncoder,
    StandardMatchingPreprocessor,
    BetaXPreprocessor,
    GammaPreprocessor,
    GammaXPreprocessor,
)

logger = logging.getLogger(__name__)


def map_input_output_weights(
    preprocessor: BaseMatchingPreprocessor,
    weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """
    Propagate weights defined on input features to the output features of a
    fitted preprocessor. Requires the preprocessor to implement
    get_feature_names_out(). Every output feature inherits the full weight of
    the input feature it was derived from; e.g. if "age" carries a weight of
    10, every age bin also carries a weight of 10. Features without an
    explicit weight default to 1.

    We also considered "diluting" weights such that the total initial weight
    is spread across the bins. We found, however, that this gives
    unsatisfactory results, since requirements for matching success are
    usually stated in terms of unweighted bins (e.g. |SMD| < 0.1). If for
    instance age is an important feature, then so should any feature
    constructed from age.
    """
    # Unweighted input features implicitly get weight 1.
    source_weights = defaultdict(lambda: 1)
    source_weights.update(weights or {})

    derived_weights = defaultdict(lambda: 1)
    for feature in preprocessor.input_headers["all"]:
        weight = source_weights[feature]
        for derived in preprocessor.get_feature_names_out(feature):
            # An output column must trace back to exactly one input column;
            # seeing it twice indicates a broken feature-name mapping.
            assert derived not in derived_weights.keys()
            # Pass the input weight along, undiluted, to the derived feature.
            derived_weights[derived] = weight

    # Every output column must have received a weight.
    assert set(derived_weights.keys()) == set(preprocessor.output_headers["all"])

    return derived_weights


def reshape_output(f):
    """
    Decorator that unwraps length-1 results.

    Balance calculators operate on batches of candidate populations and
    return one value per population; when exactly one population is passed,
    callers expect a scalar-like value rather than a length-1 container.

    :param f: Function returning a sized sequence (e.g. a 1-D tensor).
    :return: Wrapped function that returns ``output[0]`` when
        ``len(output) == 1``, otherwise the output unchanged.
    """

    # functools.wraps preserves f's __name__/__doc__ on the wrapper, so
    # decorated methods remain introspectable (previously lost).
    @functools.wraps(f)
    def _f(*args, **kwargs):
        output = f(*args, **kwargs)
        if len(output) == 1:
            return output[0]
        return output

    return _f


class BaseBalanceCalculator:
    """
    BaseBalanceCalculator is the low-level interface to calculating balance.
    BaseBalanceCalculator can be used with any preprocessor defined as a
    subclass of BaseMatchingPreprocessor. BaseBalanceCalculator implements
    matrix calculations in pytorch to allow for GPU acceleration.

    BaseBalanceCalculator performs two main tasks: (1) Computes a
    per-feature-loss based on the output features of the given preprocessor
    and (2) Aggregates the per-feature-loss into a single value for the
    loss. Furthermore, the calculator can compute the loss for many
    populations at a time.

    :param matching_data: Input matching data to be used for distance
        calculations. Must contain exactly two populations. The smaller
        population is used as a reference population. Calls to distance()
        compute the distance to this reference population.
    :param preprocessor: Preprocessor to use for per-feature-loss
        calculation. The per-feature-loss is, up to some normalizations,
        the mean difference in the features at the output of the
        preprocessor.
    :param feature_weights: How to weight features in aggregation of
        per-feature-loss.
    :param order: Exponent to use in combining per-feature-loss into an
        aggregate loss. Total loss is
        sum(feature_weight * feature_loss**order)**(1/order).
    :param standardize_difference: Whether to use the absolute standardized
        mean difference for the per-feature loss (otherwise uses absolute
        mean difference).
    :param device: Name of device to use for matrix computations. By
        default, will use GPU if a GPU is found on the system.
    """

    name = "base"

    def __init__(
        self,
        matching_data: MatchingData,
        preprocessor: BaseMatchingPreprocessor,
        feature_weights: Optional[Dict[str, float]] = None,
        order: float = 1,
        standardize_difference: bool = True,
        device: Optional[str] = None,
    ):
        self.order = order
        self.standardize_difference = standardize_difference
        self.device = self._get_device(device)

        # Fit the preprocessor on the full matching data so that target and
        # pool are transformed into the same output feature space.
        self.preprocessor = preprocessor
        self.preprocessor.fit(matching_data)

        self.matching_data = matching_data
        target, pool = split_target_pool(matching_data)
        # Transformed feature matrices, stored as float32 tensors on device.
        self.target = self._preprocess(target)
        self.pool = self._preprocess(pool)
        self._set_feature_weights(feature_weights)

        # Reference statistics of the target population, shape (1, n_features);
        # used when distance() is called without target subsets.
        self.target_mean = torch.mean(self.target, 0, True).to(self.device)
        self.target_std = torch.std(self.target, 0, keepdim=True).to(self.device)

        # Zero variances are bad and can lead to infinite loss.
        if any((self.target_std == 0)[0]):
            bad_columns = [
                self.preprocessor.output_headers["all"][j]
                for j, std in enumerate(self.target_std[0])
                if std == 0
            ]
            logger.warning(
                f'Detected constant feature(s) in target population: {",".join(bad_columns)}.'
            )

        pool_std = torch.std(self.pool, 0, keepdim=True).to(self.device)
        # Zero variances are bad and can lead to infinite loss.
        if any((pool_std == 0)[0]):
            bad_columns = [
                self.preprocessor.output_headers["all"][j]
                for j, std in enumerate(pool_std[0])
                if std == 0
            ]
            logger.warning(
                f'Detected constant feature(s) in pool population: {",".join(bad_columns)}.'
            )

    def _set_feature_weights(self, feature_weights):
        # Resolve user-supplied (or default) input-feature weights into a
        # normalized weight tensor over the preprocessor's OUTPUT features.
        # NOTE(review): reconstructed from collapsed source; the if/else
        # nesting follows the in-line comments — confirm against upstream.
        if feature_weights is None:
            feature_weights = {}
            try:
                feature_weights = map_input_output_weights(
                    self.preprocessor, feature_weights
                )
            except NotImplementedError:
                # User has not passed weights, so we can assume that
                # NotImplementedError can be ignored. Note that not passing
                # weights does not mean weights will always be equal!! Equal
                # weights on the input get diluted on the output, depending on
                # how many output features an input feature is mapped to. In the
                # current case, there simply is no mapping available, so we must
                # put equal weights on the output feature space.
                feature_weights = defaultdict(lambda: 1)
        else:
            # User has explicitly passed weights, so NotImplementedError should
            # not be handled
            feature_weights = map_input_output_weights(
                self.preprocessor, feature_weights
            )

        feature_weights = np.array(
            [feature_weights[c] for c in self.preprocessor.output_headers["all"]]
        )
        # Normalize so the weighted sum in distance() is a weighted mean.
        feature_weights = feature_weights / sum(feature_weights)
        self.feature_weights = torch.tensor(feature_weights).to(self.device)

    def _get_device(self, device):
        # Default to the first CUDA device when available, else CPU.
        if device is None:
            if torch.cuda.is_available():
                device = "cuda:0"
            else:
                device = "cpu"
        return torch.device(device)

    def _preprocess(self, data: Union[pd.DataFrame, MatchingData]):
        # Transform raw (MatchingData or DataFrame) input into a float32
        # feature tensor on the configured device.
        if isinstance(data, pd.DataFrame):
            data = MatchingData(
                data=data,
                headers=self.matching_data.headers,
                population_col=self.matching_data.population_col,
            )
        data = self.preprocessor.transform(data)
        data = data[data.headers["all"]].values
        # It's best to store the feature data on the GPU to avoid moving it back
        # and forth. Even for large populations, this will be cheap (e.g.
        # 500,000 patients with 100 features = less than 1GB). Be careful,
        # however: when slicing for the candidate populations, this can blow up
        # the memory footprint. Use the BatchedBalanceCalculator to stay within
        # memory limits.
        data = data.astype(np.float32)
        data = torch.tensor(
            data, dtype=torch.float32, device=self.device, requires_grad=False
        )
        return data

    def _fetch_features(self, subset_populations, full_population_data):
        # Resolve subset populations into a 3-D feature tensor of shape
        # n_subset_populations x n_patients x n_features. Input may be a raw
        # DataFrame (preprocessed here) or an index array into
        # full_population_data.
        if isinstance(subset_populations, pd.DataFrame):
            features = self._preprocess(subset_populations)
        else:
            if not isinstance(subset_populations, torch.Tensor):
                subset_populations = torch.tensor(
                    subset_populations, device=self.device, requires_grad=False
                )
            features = full_population_data[subset_populations]

        # features has shape n_subset_populations x n_patients x n_features
        # note that using the reshape() method works for both numpy and pytorch
        if len(features.shape) == 2:
            features = features.reshape(1, *features.shape)

        return features

    def _finalize_batches(self, batches):
        """
        Combine a list (batches) of batched distances calculations into a
        single pytorch tensor.
        """
        return torch.hstack(batches)

    def _to_array(self, candidate_populations):
        # Convert a (nested) list of patient indices to a device tensor.
        return torch.tensor(candidate_populations, device=self.device)

    def _to_list(self, candidate_populations):
        # Convert an index tensor back to a plain nested Python list.
        return candidate_populations.cpu().detach().numpy().tolist()

    @reshape_output
    def distance(
        self,
        pool_subsets: Union[pd.DataFrame, torch.Tensor, np.ndarray],
        target_subsets: Union[pd.DataFrame, torch.Tensor, np.ndarray] = None,
    ) -> torch.Tensor:
        """
        Compute overall distance (aka "mismatch", aka "loss") between input
        candidate populations. The per-feature loss is aggregated using a
        vector norm specified by order specified in __init__().

        :param pool_subsets: Subsets of the pool population for which to
            compute the mismatch. Input can be either a single pandas
            dataframe containing all the required feature data or a
            2-dimensional integer array whose entries are the indices of
            patients in the pool. In the latter case, the array should be an
            array of shape n_candidate_populations x
            candidate_population_size that indexes the patient pool (passed
            during __init__()). The first dimension may be omitted if only
            one subset population is present.
        :param target_subsets: Subsets of the target population. If an index
            array, must have same last dimension as pool_subsets.
        """
        per_feature_loss = self.per_feature_loss(pool_subsets, target_subsets)
        # Since feature_weights sum to 1, sum is actually a weighted mean
        return torch.sum(
            self.feature_weights * torch.abs(per_feature_loss) ** self.order, dim=1
        ) ** (1.0 / self.order)

    def balance(self, pool_subsets, target_subsets=None):
        # Balance is defined as negative distance (higher is better).
        return -self.distance(pool_subsets, target_subsets)

    def per_feature_loss(
        self,
        pool_subsets: Union[pd.DataFrame, torch.Tensor, np.ndarray],
        target_subsets: Union[pd.DataFrame, torch.Tensor, np.ndarray] = None,
    ) -> torch.Tensor:
        """
        Compute mismatch (aka "distance", aka "loss") on a per-feature basis
        for a set of candidate populations.
        """
        pool = self._fetch_features(pool_subsets, self.pool)

        if target_subsets is not None:
            target = self._fetch_features(target_subsets, self.target)
            target_std = target.std(axis=1)
            target_mean = target.mean(axis=1)
            if not pool.shape[0] == target.shape[0]:
                raise ValueError(
                    "Number of subset populations must be same for pool and target!"
                )
        else:
            # Fall back to the full-target statistics computed in __init__();
            # these broadcast across all candidate pool subsets.
            target_std = self.target_std
            target_mean = self.target_mean

        # NaNs can arise when the pool and target both have zero variance. If
        # they have zero variance and are the same value (e.g. pool all 0s and
        # target all 0s), then this should be anyway zero loss. As long as the
        # norm is non-zero, we are good. If they have zero variance and are
        # different values (pool all 1s and target all 0s), then this should
        # represent a large loss. Technically, the loss should be infinite in
        # that case, but infinities are annoying so we just set the norm to a
        # small positive number. The user will get warned in the call to
        # __init__() of the possibility of this arising.
        norm = 1
        if self.standardize_difference:
            norm *= torch.sqrt(pool.std(axis=1) ** 2 + target_std**2) + 1e-6

        return torch.nan_to_num((pool.mean(axis=1) - target_mean) / norm)
class BetaBalance(BaseBalanceCalculator):
    """
    Convenience interface to BaseBalanceCalculator to compute the distance
    between populations as the mean standardized mean difference. Uses
    StandardMatchingPreprocessor as the preprocessor.

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param drop: Passed to StandardMatchingPreprocessor; controls which
        category level is dropped during encoding.
    :param standardize_difference: See BaseBalanceCalculator.
    """

    name = "beta"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        # FIX: annotation was `bool`, but the value is a string option
        # forwarded to the encoder (consistent with the other calculators).
        drop: str = "first",
        standardize_difference: bool = True,
    ):
        preprocessor = StandardMatchingPreprocessor(drop=drop)
        super().__init__(
            matching_data=matching_data,
            preprocessor=preprocessor,
            feature_weights=feature_weights,
            order=1,
            standardize_difference=standardize_difference,
            device=device,
        )
class BetaSquaredBalance(BaseBalanceCalculator):
    """
    Same as BetaBalance, except that per-feature balances are averaged in a
    mean square fashion (order=2 aggregation).

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param drop: Passed to StandardMatchingPreprocessor; controls which
        category level is dropped during encoding.
    :param standardize_difference: See BaseBalanceCalculator.
    """

    name = "beta_squared"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        # FIX: annotation was `bool`, but the value is a string option
        # forwarded to the encoder (consistent with the other calculators).
        drop: str = "first",
        standardize_difference: bool = True,
    ):
        preprocessor = StandardMatchingPreprocessor(drop=drop)
        super().__init__(
            matching_data=matching_data,
            preprocessor=preprocessor,
            feature_weights=feature_weights,
            order=2,
            standardize_difference=standardize_difference,
            device=device,
        )
class BetaXBalance(BaseBalanceCalculator):
    """
    Convenience interface to BaseBalanceCalculator to compute the balance
    between two populations by computing the standardized mean difference,
    including cross terms. See BetaXPreprocessor for description of
    preprocessing options.

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param drop: Passed to BetaXPreprocessor.
    :param standardize_difference: See BaseBalanceCalculator.
    :param max_cross_terms: Passed to BetaXPreprocessor; limits how many
        cross terms are generated.
    """

    name = "beta_x"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        drop: str = "first",
        standardize_difference: bool = True,
        max_cross_terms="auto",
    ):
        pp_x = BetaXPreprocessor(
            drop=drop,
            max_cross_terms=max_cross_terms,
        )
        super().__init__(
            matching_data=matching_data,
            preprocessor=pp_x,
            # BUG FIX: feature_weights was accepted but never forwarded to
            # the base class, silently discarding user-supplied weights.
            feature_weights=feature_weights,
            order=1,
            standardize_difference=standardize_difference,
            device=device,
        )


class BetaXSquaredBalance(BaseBalanceCalculator):
    """
    Same as BetaXBalance, except that per-feature balances are averaged in a
    mean square fashion (order=2 aggregation).

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param drop: Passed to BetaXPreprocessor.
    :param standardize_difference: See BaseBalanceCalculator.
    :param max_cross_terms: Passed to BetaXPreprocessor; limits how many
        cross terms are generated.
    """

    name = "beta_x_squared"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        drop: str = "first",
        standardize_difference: bool = True,
        max_cross_terms="auto",
    ):
        pp_x = BetaXPreprocessor(
            drop=drop,
            max_cross_terms=max_cross_terms,
        )
        super().__init__(
            matching_data=matching_data,
            preprocessor=pp_x,
            # BUG FIX: feature_weights was accepted but never forwarded to
            # the base class, silently discarding user-supplied weights.
            feature_weights=feature_weights,
            order=2,
            standardize_difference=standardize_difference,
            device=device,
        )
class BetaMaxBalance(BaseBalanceCalculator):
    """
    Same as BetaBalance, except the worst-matched feature determines the
    loss. This class is provided as a convenience, since this balance metric
    is often a criterion used to determine if matching is "sufficiently
    good". However, be aware that using this balance metric as an
    optimization objective with the various matchers can lead to unwanted
    behavior: if improvements in the worst-matched feature are not possible,
    there is no signal from the balance function to improve any of the other
    features.

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param drop: Passed to StandardMatchingPreprocessor; controls which
        category level is dropped during encoding.
    :param standardize_difference: See BaseBalanceCalculator.
    """

    name = "beta_max"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        # FIX: annotation was `bool`, but the value is a string option
        # forwarded to the encoder (consistent with the other calculators).
        drop: str = "first",
        standardize_difference: bool = True,
    ):
        preprocessor = StandardMatchingPreprocessor(drop=drop)
        super().__init__(
            matching_data=matching_data,
            preprocessor=preprocessor,
            feature_weights=feature_weights,
            order=1,
            standardize_difference=standardize_difference,
            device=device,
        )

    @reshape_output
    def distance(
        self,
        pool_subsets: Union[pd.DataFrame, torch.Tensor, np.ndarray],
        target_subsets: Union[pd.DataFrame, torch.Tensor, np.ndarray] = None,
    ) -> torch.Tensor:
        """
        Aggregate the per-feature loss by taking the (weighted) maximum
        rather than a mean, so the worst-matched feature determines the loss.
        """
        per_feature_loss = self.per_feature_loss(pool_subsets, target_subsets)
        return torch.max(
            self.feature_weights * torch.abs(per_feature_loss), dim=1
        ).values
class GammaBalance(BaseBalanceCalculator):
    """
    Convenience wrapper around BaseBalanceCalculator that measures balance
    as the mean area between the one-dimensional marginal distributions of
    the two populations. See GammaPreprocessor for a description of the
    preprocessing options.

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param n_bins: Passed to GammaPreprocessor.
    :param encode: Passed to GammaPreprocessor.
    :param cumulative: Passed to GammaPreprocessor.
    :param drop: Passed to GammaPreprocessor.
    :param standardize_difference: See BaseBalanceCalculator.
    """

    name = "gamma"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        n_bins: int = 5,
        encode: str = "onehot-dense",
        cumulative: bool = True,
        drop: str = "first",
        standardize_difference: bool = True,
    ):
        gamma_preprocessor = GammaPreprocessor(
            n_bins=n_bins,
            encode=encode,
            cumulative=cumulative,
            drop=drop,
        )
        super().__init__(
            matching_data=matching_data,
            preprocessor=gamma_preprocessor,
            feature_weights=feature_weights,
            order=1,
            standardize_difference=standardize_difference,
            device=device,
        )
class GammaSquaredBalance(BaseBalanceCalculator):
    """
    Same as GammaBalance, except that per-feature balances are averaged in a
    mean square fashion (order=2 aggregation).

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param n_bins: Passed to GammaPreprocessor.
    :param encode: Passed to GammaPreprocessor.
    :param cumulative: Passed to GammaPreprocessor.
    :param drop: Passed to GammaPreprocessor.
    :param standardize_difference: See BaseBalanceCalculator.
    """

    name = "gamma_squared"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        n_bins: int = 5,
        encode: str = "onehot-dense",
        cumulative: bool = True,
        drop: str = "first",
        standardize_difference: bool = True,
    ):
        preprocessor = GammaPreprocessor(
            n_bins=n_bins, encode=encode, cumulative=cumulative, drop=drop
        )
        super().__init__(
            matching_data=matching_data,
            preprocessor=preprocessor,
            feature_weights=feature_weights,
            order=2,
            standardize_difference=standardize_difference,
            device=device,
        )
class GammaXBalance(BaseBalanceCalculator):
    """
    Convenience interface to BaseBalanceCalculator to compute the balance
    between two populations by computing the mean area between their
    one-dimensional marginal distributions, including cross terms. See
    GammaXPreprocessor for description of preprocessing options.

    :param matching_data: See BaseBalanceCalculator.
    :param feature_weights: See BaseBalanceCalculator.
    :param device: See BaseBalanceCalculator.
    :param n_bins: Passed to GammaXPreprocessor.
    :param encode: Passed to GammaXPreprocessor.
    :param cumulative: Passed to GammaXPreprocessor.
    :param drop: Passed to GammaXPreprocessor.
    :param standardize_difference: See BaseBalanceCalculator.
    :param max_cross_terms: Passed to GammaXPreprocessor; limits how many
        cross terms are generated.
    """

    name = "gamma_x"

    def __init__(
        self,
        matching_data: MatchingData,
        feature_weights: Optional[Dict[str, float]] = None,
        device: Optional[str] = None,
        n_bins: int = 5,
        encode: str = "onehot-dense",
        cumulative: bool = True,
        drop: str = "first",
        standardize_difference: bool = True,
        max_cross_terms="auto",
    ):
        pp_x = GammaXPreprocessor(
            n_bins=n_bins,
            encode=encode,
            cumulative=cumulative,
            drop=drop,
            max_cross_terms=max_cross_terms,
        )
        super().__init__(
            matching_data=matching_data,
            preprocessor=pp_x,
            # BUG FIX: feature_weights was accepted but never forwarded to
            # the base class, silently discarding user-supplied weights.
            feature_weights=feature_weights,
            order=1,
            standardize_difference=standardize_difference,
            device=device,
        )
class GammaXTreeBalance(BaseBalanceCalculator):
    """
    Convenience wrapper around BaseBalanceCalculator that uses a
    DecisionTreeEncoder as the preprocessor.

    :param matching_data: See BaseBalanceCalculator.
    :param keep_original_features: Passed to DecisionTreeEncoder.
    :param device: See BaseBalanceCalculator.
    :param standardize_difference: See BaseBalanceCalculator.
    :param decision_tree_params: Additional keyword arguments forwarded to
        DecisionTreeEncoder.
    """

    name = "gamma_x_tree"

    def __init__(
        self,
        matching_data,
        keep_original_features=False,
        device=None,
        standardize_difference: bool = True,
        **decision_tree_params,
    ):
        tree_preprocessor = DecisionTreeEncoder(
            keep_original_features=keep_original_features, **decision_tree_params
        )
        super().__init__(
            matching_data=matching_data,
            preprocessor=tree_preprocessor,
            order=1,
            standardize_difference=standardize_difference,
            device=device,
        )
def _get_batch_size(target_population_size, n_features, max_batch_size_gb=8): """ Get the size of batches for balance calculations such that no batch is greater in memory footprint than max_size_gb GB. This calculation is only approximate and not guaranteed to respect size requirements. n_features should be interpreted as the number of *effective* features (i.e. after binning) """ max_batch_size_gb = max_batch_size_gb / 1.25 # add some wiggle room size_of_float = ( 4 # Balance calculators are supposed to transform to float datatypes! ) # Technically, the size is given by: # batch_size_gb = n_candidate_populations * target_population_size * n_features * size_of_float / 2**30 # But this value will cancel out in the next step. batch_size_gb = target_population_size * n_features * size_of_float / 2**30 batch_size = int(max_batch_size_gb / batch_size_gb) # If you fail this assertion, something is definitely wrong and you'll pay # for it downstream. Rather stop you here. Basically, it means you can't # even hold 10 candidate populations in memory at a time. assert batch_size >= 10 return batch_size
# NOTE(review): class name is misspelled ("Caclulator"); kept as-is because
# it is part of the public API and external callers may reference it.
class BatchedBalanceCaclulator:
    """
    Batch balance calculations to avoid large peak memory usage.

    Wraps a balance calculator and splits index-array distance calculations
    into batches sized by _get_batch_size(), delegating each batch to the
    wrapped calculator.

    :param balance_calculator: The underlying calculator to delegate to.
    :param max_batch_size_gb: Approximate per-batch memory budget in GB.
    """

    def __init__(self, balance_calculator, max_batch_size_gb=8):
        # Mirror the wrapped calculator's public attributes so this object
        # can be used as a drop-in replacement.
        self.name = balance_calculator.name
        self.matching_data = balance_calculator.matching_data
        self.preprocessor = balance_calculator.preprocessor
        self.pool = balance_calculator.pool
        self.target = balance_calculator.target
        self.balance_calculator = balance_calculator
        self.max_batch_size_gb = max_batch_size_gb

    def _to_array(self, candidate_populations):
        # Delegate conversion to the wrapped calculator's backend array type.
        return self.balance_calculator._to_array(candidate_populations)

    def _to_list(self, candidate_populations):
        # Delegate conversion back to a plain Python list.
        return self.balance_calculator._to_list(candidate_populations)

    def distance(self, pool_subsets, target_subsets=None):
        """
        Compute distances for the given candidate populations, batching the
        computation when pool_subsets is a 2-D index array. See
        BaseBalanceCalculator.distance() for the argument semantics.
        """
        # If the user passes a pandas DataFrame, the DataFrame is assumed to
        # represent feature data for only one population and no need for
        # batching; just pass on to the base class. Otherwise, assume
        # pool_subsets refers to patient indices.
        if isinstance(pool_subsets, pd.DataFrame):
            if not (target_subsets is None or isinstance(target_subsets, pd.DataFrame)):
                raise ValueError(
                    "target_subsets must be of same datatype as pool_subsets if both are passed"
                )
            return self.balance_calculator.distance(pool_subsets, target_subsets)

        # If the user passes a list, convert it to the underlying backend array
        # type for further processing.
        if isinstance(pool_subsets, list):
            pool_subsets = self.balance_calculator._to_array(pool_subsets)
        if isinstance(target_subsets, list):
            target_subsets = self.balance_calculator._to_array(target_subsets)

        # If array has only one dimension, no need for batching, just pass along
        # to base class.
        if len(pool_subsets.shape) == 1:
            return self.balance_calculator.distance(pool_subsets, target_subsets)

        # Choose a batch size based on the size of the target feature matrix.
        batch_size = _get_batch_size(
            self.balance_calculator.target.shape[0],
            self.balance_calculator.target.shape[1],
            self.max_batch_size_gb,
        )

        # From here on, can assume pool_subsets is a 2D array of
        # patient indices with the 0th dimension corresponding to the
        # pool_subsets and the 1st dimension corresponding to patient.
        n_remaining = len(pool_subsets)
        distances = []
        j = 0
        while n_remaining > 0:
            # Each batch covers rows [j*batch_size, j*batch_size + N); the
            # final batch may be smaller than batch_size.
            N = min(batch_size, n_remaining)
            _pool = pool_subsets[j * batch_size : j * batch_size + N, :]
            if target_subsets is not None:
                _target = target_subsets[j * batch_size : j * batch_size + N, :]
            else:
                _target = None
            distances.append(self.balance_calculator.distance(_pool, _target))
            n_remaining -= N
            j += 1

        # Stitch the per-batch results into one tensor; the total must match
        # the number of requested candidate populations.
        distances = self.balance_calculator._finalize_batches(distances)
        assert len(distances) == len(pool_subsets)
        return distances

    def balance(self, pool_subsets, target_subsets=None):
        # Balance is defined as negative distance (higher is better).
        return -self.distance(pool_subsets, target_subsets)
# # Convenience interface to balance calculators # BALANCE_CALCULATORS = { BaseBalanceCalculator.name: BaseBalanceCalculator, BetaBalance.name: BetaBalance, BetaXBalance.name: BetaXBalance, BetaXSquaredBalance.name: BetaXSquaredBalance, BetaMaxBalance.name: BetaMaxBalance, BetaSquaredBalance.name: BetaSquaredBalance, GammaBalance.name: GammaBalance, GammaSquaredBalance.name: GammaSquaredBalance, GammaXTreeBalance.name: GammaXTreeBalance, GammaXBalance.name: GammaXBalance, }
def BalanceCalculator(matching_data, objective="gamma", **kwargs):
    """
    Factory providing a convenience interface to the balance calculators,
    allowing the user to initialize a calculator by name. Calculators are
    created with default parameters unless overridden via kwargs.

    :param matching_data: MatchingData instance containing reference to the
        data against which matching metrics will be computed
    :param objective: Name of objective function to be used for computing
        balance. Balance calculators must be implemented in
        utils.balance_calculators.py and registered in the
        BALANCE_CALCULATORS dictionary therein in order to be accessible
        from this interface.
    :param kwargs: Any additional arguments required to configure the
        specific objective function (e.g. n_bins = 10 for "gamma").
    """
    calculator_cls = BALANCE_CALCULATORS.get(objective)
    if calculator_cls is None:
        raise ValueError(
            f"Unknown objective function {objective}. Must be one of {','.join(BALANCE_CALCULATORS.keys())}"
        )
    return calculator_cls(matching_data, **kwargs)