| predict.causal_survival_forest {grf} | R Documentation |
Description
Gets estimates of $\tau(X)$ using a trained causal survival forest.
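For reference, here is a sketch of the estimand. It assumes the definitions `causal_survival_forest` uses for its two `target` options. $T(w)$ is the potential failure time under treatment $w$, and $h$ is the `horizon`:

$$
\tau(x) = \mathbb{E}\left[\min\{T(1), h\} - \min\{T(0), h\} \mid X = x\right] \qquad (\texttt{target = "RMST"})
$$

$$
\tau(x) = \mathbb{P}\left[T(1) > h \mid X = x\right] - \mathbb{P}\left[T(0) > h \mid X = x\right] \qquad (\texttt{target = "survival.probability"})
$$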
Usage

```r
## S3 method for class 'causal_survival_forest'
predict(
  object,
  newdata = NULL,
  num.threads = NULL,
  estimate.variance = FALSE,
  ...
)
```
Arguments

| Argument | Description |
|---|---|
| `object` | The trained forest. |
| `newdata` | Points at which predictions should be made. If `NULL`, makes out-of-bag predictions on the training set instead (i.e., provides predictions at $X_i$ using only trees that did not use the $i$-th training example). Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order. |
| `num.threads` | Number of threads used in prediction. If set to `NULL`, the software automatically selects an appropriate number. |
| `estimate.variance` | Whether variance estimates for $\hat{\tau}(x)$ are desired (for confidence intervals). |
| `...` | Additional arguments (currently ignored). |
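The note on `newdata` asks for the columns to appear in the training order. If both matrices carry column names, one way to enforce that order is to reorder by name before predicting. The following is a minimal sketch using a hypothetical helper, `align_newdata()`, which is not part of grf. It checks the column count and then reorders a named test matrix to the training column order:

```r
# Hypothetical helper (not part of grf): reorder the columns of `newdata`
# to match the training matrix before calling predict().
align_newdata <- function(newdata, train.X) {
  newdata <- as.matrix(newdata)
  if (ncol(newdata) != ncol(train.X)) {
    stop("newdata must have the same number of columns as the training matrix.")
  }
  train.names <- colnames(train.X)
  # Without names on both sides the order cannot be checked; return unchanged.
  if (is.null(train.names) || is.null(colnames(newdata))) {
    return(newdata)
  }
  absent <- setdiff(train.names, colnames(newdata))
  if (length(absent) > 0) {
    stop("newdata is missing columns: ", paste(absent, collapse = ", "))
  }
  # Index by name so the columns appear in the training order.
  newdata[, train.names, drop = FALSE]
}

# Toy matrices: the test columns arrive in a different order.
train.X <- matrix(0, nrow = 2, ncol = 3, dimnames = list(NULL, c("a", "b", "c")))
test.X <- matrix(c(3, 1, 2), nrow = 1, dimnames = list(NULL, c("c", "a", "b")))
aligned <- align_newdata(test.X, train.X)
colnames(aligned)  # c("a", "b", "c")
```

In the example below, `X` has no column names. In that case the helper only checks the column count, so a call such as `predict(cs.forest, align_newdata(X.test, X))` behaves like `predict(cs.forest, X.test)`.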
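Value
A vector of predictions, along with optional variance estimates. The point estimates are in the `predictions` element. When `estimate.variance = TRUE`, the variance estimates are in `variance.estimates`.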
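With `estimate.variance = TRUE`, the variance estimates can be converted into pointwise confidence intervals. The following is a minimal sketch that assumes a normal approximation. It uses a hypothetical helper, `tau_confint()`, which is not part of grf and takes the prediction and variance vectors:

```r
# Hypothetical helper (not part of grf): pointwise normal-approximation
# confidence intervals for tau(x) from point predictions and variance estimates.
tau_confint <- function(predictions, variance.estimates, level = 0.95) {
  stopifnot(length(predictions) == length(variance.estimates),
            level > 0, level < 1)
  z <- qnorm(1 - (1 - level) / 2)  # about 1.96 for level = 0.95
  se <- sqrt(variance.estimates)
  data.frame(estimate = predictions,
             lower = predictions - z * se,
             upper = predictions + z * se)
}

# Toy inputs standing in for the output of predict(..., estimate.variance = TRUE).
ci <- tau_confint(predictions = c(0.2, -0.1), variance.estimates = c(0.01, 0.04))
print(ci)
```

With a fitted forest, the call would presumably be `pred <- predict(cs.forest, X.test, estimate.variance = TRUE)` followed by `tau_confint(pred$predictions, pred$variance.estimates)`. This assumes the returned object exposes the two elements under those names.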
Examples

```r
library(grf)

# Train a causal survival forest targeting a Restricted Mean Survival Time (RMST)
# with maximum follow-up time set to `horizon`.
n <- 2000
p <- 5
X <- matrix(runif(n * p), n, p)
W <- rbinom(n, 1, 0.5)
horizon <- 1
failure.time <- pmin(rexp(n) * X[, 1] + W, horizon)
censor.time <- 2 * runif(n)
# Discretizing continuous events decreases runtime.
Y <- round(pmin(failure.time, censor.time), 2)
D <- as.integer(failure.time <= censor.time)
cs.forest <- causal_survival_forest(X, Y, W, D, horizon = horizon)

# Predict using the forest.
X.test <- matrix(0.5, 10, p)
X.test[, 1] <- seq(0, 1, length.out = 10)
cs.pred <- predict(cs.forest, X.test)

# Predict on out-of-bag training samples.
cs.pred <- predict(cs.forest)

# Compute a doubly robust estimate of the average treatment effect.
average_treatment_effect(cs.forest)

# Compute the best linear projection on the first covariate.
best_linear_projection(cs.forest, X[, 1])

# Train a causal survival forest targeting an absolute risk difference
# at the time point `horizon = 0.5`.
cs.forest.prob <- causal_survival_forest(X, Y, W, D,
  target = "survival.probability",
  horizon = 0.5
)
```
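The example simulates its own data, so the true $\tau(x)$ has a closed form that can be checked against the predictions. The sketch below is derived by hand from that simulation and is not computed by grf.

- Treated units always have `failure.time` equal to `horizon = 1`.
- Control units have `failure.time = pmin(E * x1, 1)` with `E ~ Exp(1)`. Hence $\mathbb{E}[\min(E x_1, 1)] = \int_0^1 e^{-t/x_1}\,dt = x_1(1 - e^{-1/x_1})$.
- The RMST effect is therefore $\tau(x) = 1 - x_1(1 - e^{-1/x_1})$.
- The survival-probability effect at $h = 0.5$ is $1 - e^{-0.5/x_1}$.

Censoring and the rounding of `Y` affect only the observed data, not these estimands.

```r
# True CATEs implied by the simulation in the example (derived by hand,
# not computed by grf). Only the first covariate x1 matters.
true_tau_rmst <- function(x1) {
  # E[min(T(1), 1)] = 1 and E[min(E * x1, 1)] = x1 * (1 - exp(-1 / x1)).
  # At x1 = 0, exp(-1 / 0) is exp(-Inf) = 0, so the formula gives 1.
  1 - x1 * (1 - exp(-1 / x1))
}

true_tau_prob <- function(x1, h = 0.5) {
  # P[T(1) > h] = 1 and P[E * x1 > h] = exp(-h / x1) for 0 <= h < 1.
  1 - exp(-h / x1)
}

x1 <- seq(0, 1, length.out = 10)  # same grid as X.test[, 1]
truth <- data.frame(x1 = x1, rmst = true_tau_rmst(x1), prob = true_tau_prob(x1))
print(truth)

# Monte Carlo sanity check of the RMST formula at x1 = 0.5.
set.seed(1)
mc <- 1 - mean(pmin(rexp(1e6) * 0.5, 1))
```

To compare the truth with the forest, use the test-set estimates, for example `predict(cs.forest, X.test)$predictions`. The example later overwrites `cs.pred` with out-of-bag predictions, so it cannot be reused for this comparison.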