| predict.causal_survival_forest {grf} | R Documentation |
Gets estimates of tau(X) using a trained causal survival forest.
## S3 method for class 'causal_survival_forest'
predict(
  object,
  newdata = NULL,
  num.threads = NULL,
  estimate.variance = FALSE,
  ...
)
object |
The trained forest. |
newdata |
Points at which predictions should be made. If NULL, makes out-of-bag predictions on the training set instead (i.e., provides predictions at Xi using only trees that did not use the i-th training example). Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order. |
num.threads |
Number of threads used in prediction. If set to NULL, the software automatically selects an appropriate number. |
estimate.variance |
Whether variance estimates for \hat{\tau}(x) are desired (for confidence intervals). |
... |
Additional arguments (currently ignored). |
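Vector of predictions (the predictions component of the returned object), along with variance estimates (the variance.estimates component) if estimate.variance = TRUE.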
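The layout requirement on newdata (same number of columns, same column order as the training matrix) can be checked before calling predict. The following is a minimal sketch in base R; the helper name check_newdata_layout is hypothetical and not part of grf, and it assumes that comparing column counts and, when both matrices carry them, column names is a sufficient check.

```r
# Hypothetical helper (not part of grf): verify that a new covariate matrix
# matches the layout of the training matrix before prediction.
check_newdata_layout <- function(X.train, X.new) {
  X.train <- as.matrix(X.train)
  X.new <- as.matrix(X.new)
  if (ncol(X.new) != ncol(X.train)) {
    stop("newdata has ", ncol(X.new), " columns but the training matrix has ",
         ncol(X.train), ".")
  }
  train.names <- colnames(X.train)
  new.names <- colnames(X.new)
  # Column names are optional; compare their order only when both sets exist.
  if (!is.null(train.names) && !is.null(new.names) &&
      !identical(train.names, new.names)) {
    stop("newdata columns are not in the same order as the training matrix.")
  }
  invisible(TRUE)
}

# Toy matrices standing in for the training and test covariates.
X.train <- matrix(runif(20), 4, 5, dimnames = list(NULL, paste0("x", 1:5)))
X.ok <- X.train[1:2, , drop = FALSE]
X.shuffled <- X.train[1:2, 5:1, drop = FALSE]

check_newdata_layout(X.train, X.ok)
layout.error <- tryCatch(check_newdata_layout(X.train, X.shuffled),
                         error = function(e) conditionMessage(e))
```

With a trained forest, calling predict(object) without newdata gives out-of-bag predictions for the training rows, while predict(object, newdata) predicts at the supplied points; a check like the one above only matters in the second case.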
# Train a standard causal survival forest.
n <- 3000
p <- 5
X <- matrix(runif(n * p), n, p)
W <- rbinom(n, 1, 0.5)
Y.max <- 1
failure.time <- pmin(rexp(n) * X[, 1] + W, Y.max)
censor.time <- 2 * runif(n)
Y <- pmin(failure.time, censor.time)
D <- as.integer(failure.time <= censor.time)
cs.forest <- causal_survival_forest(X, Y, W, D)
# Predict using the forest.
X.test <- matrix(0.5, 10, p)
X.test[, 1] <- seq(0, 1, length.out = 10)
cs.pred <- predict(cs.forest, X.test, estimate.variance = TRUE)
# Plot the estimated CATEs along with 95% confidence bands.
r.monte.carlo <- rexp(5000)
cate <- rep(NA, 10)
for (i in 1:10) {
  cate[i] <- mean(pmin(r.monte.carlo * X.test[i, 1] + 1, Y.max) -
                    pmin(r.monte.carlo * X.test[i, 1], Y.max))
}
plot(X.test[, 1], cate, type = 'l', col = 'red')
points(X.test[, 1], cs.pred$predictions)
lines(X.test[, 1], cs.pred$predictions + 2 * sqrt(cs.pred$variance.estimates), lty = 2)
lines(X.test[, 1], cs.pred$predictions - 2 * sqrt(cs.pred$variance.estimates), lty = 2)
# Compute a doubly robust estimate of the average treatment effect.
average_treatment_effect(cs.forest)
# Compute the best linear projection on the first covariate.
best_linear_projection(cs.forest, X[, 1])
# Train the forest on a less granular grid.
cs.forest.grid <- causal_survival_forest(X, Y, W, D,
                                         failure.times = seq(min(Y), max(Y), length.out = 50))
plot(X.test[, 1], cs.pred$predictions)
points(X.test[, 1], predict(cs.forest.grid, X.test)$predictions, col = "blue")
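
The confidence bands plotted in the example use a multiplier of 2 as a rough stand-in for the normal quantile. As a sketch, the same intervals can be tabulated with the exact quantile. Here cate_confint is a hypothetical helper (not a grf function); it assumes its input has the predictions and variance.estimates components used in the example, and that a normal approximation is adequate.

```r
# Hypothetical helper (not part of grf): pointwise normal-approximation
# confidence intervals from predictions and their variance estimates.
cate_confint <- function(pred, level = 0.95) {
  stopifnot(all(c("predictions", "variance.estimates") %in% names(pred)))
  z <- qnorm(1 - (1 - level) / 2)        # about 1.96 for level = 0.95
  se <- sqrt(pred$variance.estimates)    # standard errors
  data.frame(
    estimate = pred$predictions,
    std.err = se,
    lower = pred$predictions - z * se,
    upper = pred$predictions + z * se
  )
}

# Illustrative input only: made-up numbers in the shape of cs.pred.
toy.pred <- data.frame(predictions = c(0.10, 0.25),
                       variance.estimates = c(0.0004, 0.0009))
ci <- cate_confint(toy.pred)
```

Applied to the example's cs.pred (which was computed with estimate.variance = TRUE), cate_confint(cs.pred) would return one row per row of X.test.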