Causal Tree
Causal Tree is a data-driven approach that partitions the data into subpopulations that differ in the magnitude of their causal effects [Athey2015]. The method is applicable when unconfoundedness holds given the adjustment set (covariate) \(V\). The causal effect of interest is the conditional average treatment effect (CATE) given \(V\).
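In do-notation, the CATE conditional on \(V = v\) can be written as below. This is a sketch that assumes a treatment \(X\) with a treated value \(x_t\) and a control value \(x_0\), and an outcome \(Y\); these symbols are introduced here for illustration and are not fixed by the text above.

```latex
\tau(v) := \mathbb{E}\left[ Y(do(X = x_t)) - Y(do(X = x_0)) \mid V = v \right]
```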
Because counterfactuals can never be observed, [Athey2015] developed an honest approach in which the loss function (the criterion for building a tree) is designed as
where \(N_{tr}\) is the number of samples in the training set \(S_{tr}\), \(p\) is the ratio of the number of samples in the treat group to that of the control group in the training set, and
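For reference, the honest splitting criterion of [Athey2015] can be sketched as follows. Two assumptions are made that the text above does not state: the estimation sample has the same size as the training sample (which gives the factor \(2/N_{tr}\)), and \(p\) is read as the fraction of treated samples so that \(1 - p\) is the control fraction. The remaining symbols are introduced here for illustration: \(\Pi\) is a partition of the covariate space into leaves \(\ell\), \(\ell(v; \Pi)\) is the leaf containing \(v\), and \(\Sigma^2_{S_{tr}^{treat}}(\ell)\) and \(\Sigma^2_{S_{tr}^{control}}(\ell)\) are the within-leaf sample variances of the outcome in the treated and control groups.

```latex
e(S_{tr}, \Pi) := \frac{1}{N_{tr}} \sum_{i \in S_{tr}} \hat{\tau}^2(V_i; S_{tr}, \Pi)
  - \frac{2}{N_{tr}} \sum_{\ell \in \Pi}
    \left( \frac{\Sigma^2_{S_{tr}^{treat}}(\ell)}{p}
         + \frac{\Sigma^2_{S_{tr}^{control}}(\ell)}{1 - p} \right)

\hat{\tau}(v; S_{tr}, \Pi) =
  \frac{1}{\#\{i \in S_{tr}^{treat} : V_i \in \ell(v; \Pi)\}}
    \sum_{\{i \in S_{tr}^{treat} : V_i \in \ell(v; \Pi)\}} Y_i
  - \frac{1}{\#\{i \in S_{tr}^{control} : V_i \in \ell(v; \Pi)\}}
    \sum_{\{i \in S_{tr}^{control} : V_i \in \ell(v; \Pi)\}} Y_i
```

In words, \(\hat{\tau}(v)\) is the difference between the mean outcome of the treated samples and the mean outcome of the control samples that fall in the same leaf as \(v\).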
Example
import numpy as np
import matplotlib.pyplot as plt
from ylearn.estimator_model.causal_tree import CausalTree
from ylearn.exp_dataset.exp_data import sq_data
from ylearn.utils._common import to_df
# build dataset
n = 2000
d = 10
n_x = 1
y, x, v = sq_data(n, d, n_x)
true_te = lambda X: np.hstack([X[:, [0]]**2 + 1, np.ones((X.shape[0], n_x - 1))])
data = to_df(treatment=x, outcome=y, v=v)
outcome = 'outcome'
treatment = 'treatment'
adjustment = data.columns[2:]
# build test data
v_test = v[:min(100, n)].copy()
v_test[:, 0] = np.linspace(np.percentile(v[:, 0], 1), np.percentile(v[:, 0], 99), min(100, n))
test_data = to_df(v=v_test)
Train the CausalTree and apply it to the test data:
ct = CausalTree(min_samples_leaf=3, max_depth=5)
ct.fit(data=data, outcome=outcome, treatment=treatment, adjustment=adjustment)
ct_pred = ct.estimate(data=test_data)
Class Structures
- class ylearn.estimator_model.causal_tree.CausalTree(*, splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, random_state=2022, max_leaf_nodes=None, max_features=None, min_impurity_decrease=0.0, min_weight_fraction_leaf=0.0, ccp_alpha=0.0, categories='auto')
- Parameters
splitter ({"best", "random"}, default="best") – The strategy used to choose the split at each node. Supported strategies are “best” to choose the best split and “random” to choose the best random split.
max_depth (int, default=None) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.
min_samples_split (int or float, default=2) –
The minimum number of samples required to split an internal node:
If int, then consider min_samples_split as the minimum number.
If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) is the minimum number of samples for each split.
min_samples_leaf (int or float, default=1) –
The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.
If int, then consider min_samples_leaf as the minimum number.
If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) is the minimum number of samples for each node.
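The int-or-float rule shared by min_samples_split and min_samples_leaf can be illustrated with a small helper. This is a minimal sketch of the rule as described above, not the library's own code; the function name resolve_min_samples is hypothetical.

```python
import math


def resolve_min_samples(value, n_samples):
    """Turn an int-or-float setting into an absolute sample count.

    Hypothetical helper: mirrors the rule in the parameter description,
    not the library's internal implementation.
    """
    if isinstance(value, bool):
        raise TypeError("value must be an int or a float, not a bool")
    if isinstance(value, int):
        # An int is already an absolute number of samples.
        return value
    # A float is a fraction of the training set, rounded up.
    return math.ceil(value * n_samples)


print(resolve_min_samples(3, 2000))    # int is used as-is
print(resolve_min_samples(0.25, 10))   # ceil(0.25 * 10) = ceil(2.5)
```

With a float setting the effective threshold grows with the size of the training set, whereas an int setting stays fixed.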
min_weight_fraction_leaf (float, default=0.0) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.
max_features (int, float or {"sqrt", "log2"}, default=None) –
The number of features to consider when looking for the best split:
If int, then consider max_features features at each split.
If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.
If “sqrt”, then max_features=sqrt(n_features).
If “log2”, then max_features=log2(n_features).
If None, then max_features=n_features.
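The options for max_features listed above can be summarized as a mapping from the setting to a number of candidate features. The sketch below follows the description; truncating the square root and logarithm to an integer and keeping at least one feature are assumptions, and resolve_max_features is a hypothetical name.

```python
import math


def resolve_max_features(max_features, n_features):
    """Map a max_features setting to the number of features tried per split.

    Hypothetical helper. Truncation to int and the floor of one feature
    for "sqrt" and "log2" are assumptions, not documented behavior.
    """
    if max_features is None:
        return n_features
    if max_features == "sqrt":
        return max(1, int(math.sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(math.log2(n_features)))
    if isinstance(max_features, int):
        return max_features
    # A float is interpreted as a fraction of all features.
    return int(max_features * n_features)


for setting in (None, "sqrt", "log2", 4, 0.5):
    print(setting, resolve_max_features(setting, n_features=10))
```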
random_state (int, default=2022) – Controls the randomness of the estimator.
max_leaf_nodes (int, default=None) – Grow a tree with max_leaf_nodes in best-first fashion. Best nodes are defined as relative reduction in impurity. If None, then the number of leaf nodes is unlimited.
min_impurity_decrease (float, default=0.0) –
A node will be split if this split induces a decrease of the impurity greater than or equal to this value. The weighted impurity decrease equation is the following
N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)
where N is the total number of samples, N_t is the number of samples at the current node, N_t_L is the number of samples in the left child, and N_t_R is the number of samples in the right child. N, N_t, N_t_R and N_t_L all refer to the weighted sum if sample_weight is passed.
categories (str, optional, default='auto') –
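The weighted impurity decrease given for min_impurity_decrease can be checked by hand on a single candidate split. The sketch below evaluates the formula exactly as written; the function name and the numbers are made up for illustration.

```python
def weighted_impurity_decrease(n, n_t, n_t_l, n_t_r,
                               impurity, left_impurity, right_impurity):
    """Evaluate the weighted impurity decrease formula for one split.

    n: total (weighted) sample count, n_t: samples at the current node,
    n_t_l / n_t_r: samples in the left / right child.
    """
    return n_t / n * (impurity
                      - n_t_r / n_t * right_impurity
                      - n_t_l / n_t * left_impurity)


# Illustrative numbers only: a node holding 40 of 100 samples is split 10 / 30.
decrease = weighted_impurity_decrease(
    n=100, n_t=40, n_t_l=10, n_t_r=30,
    impurity=0.5, left_impurity=0.2, right_impurity=0.4,
)
print(decrease)  # approximately 0.06

# The split is accepted only if the decrease reaches the threshold.
min_impurity_decrease = 0.05
print(decrease >= min_impurity_decrease)
```

The leading factor N_t / N scales the decrease by the share of samples reaching the node, so the same impurity change counts for less deep in the tree.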
- fit(data, outcome, treatment, adjustment=None, covariate=None, treat=None, control=None)
Fit the model on data to estimate the causal effect.
- Parameters
data (pandas.DataFrame) – The input samples for the est_model to estimate the causal effects and for the CEInterpreter to fit.
outcome (list of str, optional) – Names of the outcomes.
treatment (list of str, optional) – Names of the treatments.
covariate (list of str, optional, default=None) – Names of the covariate vectors.
adjustment (list of str, optional, default=None) – Names of the variables in the adjustment set. Note that we may only need the covariate set, which usually is a subset of the adjustment set.
treat (int or list, optional, default=None) –
If there is only one discrete treatment, then treat indicates the treatment group. If there are multiple treatment groups, then treat should be a list of str with length equal to the number of treatments. For example, when there are multiple discrete treatments,
array([‘run’, ‘read’])
means the treat value of the first treatment is taken as ‘run’ and that of the second treatment is taken as ‘read’.
control (int or list, optional, default=None) – See treat.
- Returns
Fitted CausalTree
- Return type
instance of CausalTree
- estimate(data=None, quantity=None)
Estimate the causal effect of the treatment on the outcome in data.
- Parameters
data (pandas.DataFrame, optional, default=None) – If None, data will be set as the training data.
quantity (str, optional, default=None) –
Option for returned estimation result. The possible values of quantity include:
'CATE' : the estimator will evaluate the CATE;
'ATE' : the estimator will evaluate the ATE;
None : the estimator will evaluate the ITE or CITE.
- Returns
The estimated causal effect with the type of the quantity.
- Return type
ndarray or float, optional
- plot_causal_tree(feature_names=None, max_depth=None, class_names=None, label='all', filled=False, node_ids=False, proportion=False, rounded=False, precision=3, ax=None, fontsize=None)
Plot the causal tree. The sample counts that are shown are weighted with any sample_weights that might be present. The visualization is fit automatically to the size of the axis. Use the figsize or dpi arguments of plt.figure to control the size of the rendering.
- Returns
List containing the artists for the annotation boxes making up the tree.
- Return type
annotations : list of artists
- decision_path(*, data=None, wv=None)
Return the decision path.
- Parameters
wv (numpy.ndarray, default=None) – The input samples as an ndarray. If None, then the DataFrame data will be used as the input samples.
data (pandas.DataFrame, default=None) – The input samples. The data must contain the columns of the covariates used for training the model. If None, the training data will be passed as input samples.
- Returns
Return a node indicator CSR matrix in which a nonzero element indicates that the sample passes through the corresponding node.
- Return type
indicator : sparse matrix of shape (n_samples, n_nodes)
- apply(*, data=None, wv=None)
Return the index of the leaf that each sample is predicted as.
- Parameters
wv (numpy.ndarray, default=None) – The input samples as an ndarray. If None, then the DataFrame data will be used as the input samples.
data (pandas.DataFrame, default=None) – The input samples. The data must contain the columns of the covariates used for training the model. If None, the training data will be passed as input samples.
- Returns
For each datapoint v_i in v, return the index of the leaf v_i ends up in. Leaves are numbered within [0; self.tree_.node_count), possibly with gaps in the numbering.
- Return type
v_leaves : array-like of shape (n_samples, )
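The relationship between decision_path and apply can be illustrated on a toy tree in plain Python. This is a sketch under stated assumptions: the parallel-list tree layout, the "go left when value <= threshold" rule, and the helper names are all hypothetical, and the indicator is returned as a dense list of rows rather than a CSR matrix.

```python
# Toy tree stored in parallel lists, one entry per node (hypothetical layout).
# Node 0 is the root; -1 in CHILD_LEFT marks a leaf.
CHILD_LEFT = [1, -1, 3, -1, -1]
CHILD_RIGHT = [2, -1, 4, -1, -1]
FEATURE = [0, -1, 1, -1, -1]          # feature index tested at each node
THRESHOLD = [0.5, 0.0, 2.0, 0.0, 0.0]  # unused for leaves


def node_path(sample):
    """Return the list of node ids a sample visits, root to leaf."""
    node = 0
    path = [node]
    while CHILD_LEFT[node] != -1:
        # Assumed rule: go left when the feature value is <= threshold.
        if sample[FEATURE[node]] <= THRESHOLD[node]:
            node = CHILD_LEFT[node]
        else:
            node = CHILD_RIGHT[node]
        path.append(node)
    return path


def decision_path(samples):
    """Dense (n_samples, n_nodes) indicator of visited nodes."""
    n_nodes = len(CHILD_LEFT)
    rows = []
    for sample in samples:
        visited = set(node_path(sample))
        rows.append([1 if k in visited else 0 for k in range(n_nodes)])
    return rows


def apply(samples):
    """Leaf id for each sample: the last node on its path."""
    return [node_path(sample)[-1] for sample in samples]


samples = [[0.2, 5.0], [0.9, 1.0], [0.9, 3.0]]
indicator = decision_path(samples)
leaves = apply(samples)
print(indicator)
print(leaves)
```

Because leaves share one numbering with internal nodes, the leaf ids in this toy tree are 1, 3 and 4, which is the kind of gap in the numbering mentioned above.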
- property feature_importance
- Returns
Normalized total reduction of criteria by feature (Gini importance).
- Return type
ndarray of shape (n_features,)
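The idea of a normalized total reduction per feature can be sketched as follows. The example assumes each split has already been assigned a weighted criterion reduction; the input format and the function name are hypothetical and do not reflect how the library stores its tree.

```python
def normalized_feature_importance(splits, n_features):
    """Sum criterion reductions per feature and normalize them to sum to 1.

    splits: list of (feature_index, weighted_reduction) pairs, one per
    internal node. This input format is an assumption for illustration.
    """
    totals = [0.0] * n_features
    for feature, reduction in splits:
        totals[feature] += reduction
    total = sum(totals)
    if total == 0:
        # A tree without splits gives every feature zero importance.
        return totals
    return [value / total for value in totals]


# Illustrative numbers: feature 0 is used twice, feature 2 never.
splits = [(0, 0.30), (1, 0.10), (0, 0.10)]
importance = normalized_feature_importance(splits, n_features=3)
print(importance)
```

A feature that is never used for a split receives an importance of zero, and the returned values sum to one whenever the tree has at least one split with a positive reduction.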