Everyone loves random forests (RFs). The algorithm is powerful, intuitive, and easily implemented in most languages and statistical platforms. However, a common critique is that it is difficult to extract substantive effects from a trained random forest. This is especially true amongst the academic crowd, who are used to coefficients and magical ***’s. Fortunately, R’s randomForest package and Python’s sklearn both provide “partial plot” methods, which demonstrate the substantive effect of an independent variable (IV) on the dependent variable (DV). But there are two key issues with the resulting plots. First, they are ugly. Second, they do not provide confidence intervals (CIs), which can make interpretation difficult relative to GLM-based marginal effects plots.
Taken directly from the randomForest partialPlot documentation: “The function being plotted is defined as f̃(x) = (1/n) Σᵢ₌₁ⁿ f(x, x_iC), where x is the variable for which partial dependence is sought, and x_iC is the other variables in the data.”
Below is a crude representation of how the partialPlot function works. Consider the data below (Y is the dependent variable; var_1, var_2, and var_3 are independent variables):
Step 1: Train a random forest on the data.
Step 2: Choose an independent variable of interest. For this example, let’s choose var_1.
Step 3: Build a list of the unique values in var_1. For this example, list = [4, 6, 16, 18].
Step 4: Start with the first element in the list (4, in this case), and set all values in the column chosen in Step 2 to that element. Then, use the trained algorithm object from Step 1, and build a prediction for each row. The table below illustrates the output of Step 4. Note that all values of var_1 are set to the first element in the list for the first iteration.
Step 5: Persist the element value (4, in this case) and the mean of the predictions.
Step 6: Iterate through all elements in the list, and then plot. In the plot, the list values are the x coordinates, and the mean predictions are the y coordinates.
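The steps above can be sketched in a few lines of R. This mirrors what partialPlot computes, not its exact implementation, and uses the airquality data that appears later in the post:

```r
library(randomForest)

airquality <- na.omit(airquality)
rf <- randomForest(Ozone ~ ., airquality)

# Steps 3-6 for the variable Temp: for each unique value, overwrite the
# whole column with that value, predict, and keep the mean prediction.
x_vals <- sort(unique(airquality$Temp))
y_vals <- sapply(x_vals, function(v) {
  sim <- airquality
  sim$Temp <- v
  mean(predict(rf, sim))
})

# The list values are the x coordinates; the mean predictions are the y's.
plot(x_vals, y_vals, type = "l", xlab = "Temp", ylab = "mean prediction")
```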
Although this approach can uncover basic trends in the marginal effects of an independent variable of interest, it prevents a user from determining the variance of the suggested effects. This can be problematic for all of the reasons that ignoring variance around predictions can be problematic.
The R code at the bottom of the page provides a function called plot_partial, which adds confidence interval functionality to randomForest partialPlots. Essentially, it just removes the averaging step of the function and leverages the full vector of predictions.
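The core idea is small: for each simulated value of the IV, summarize the full prediction vector with quantiles instead of collapsing it to a single mean. A sketch of that summary step, with an illustrative prediction vector standing in for the output of predict():

```r
# preds stands in for the per-row predictions at one simulated IV value,
# i.e., what predict(rf, sim_data) would return (values are illustrative).
preds <- c(28.1, 30.4, 25.7, 33.2, 29.8)

# Keep the distribution, not just the mean: lower bound, center, upper bound.
c(lower = unname(quantile(preds, .25)),
  mid   = mean(preds),
  upper = unname(quantile(preds, .75)))
```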
Let’s first examine the output of the partialPlot function, with a continuous IV and then with a categorical IV.
> library(randomForest)
> airquality <- na.omit(airquality)
> rf_1 <- randomForest(Ozone ~ ., airquality)
> partialPlot(rf_1, airquality, Temp)
> rf_iris <- randomForest(Sepal.Length ~ ., iris)
> partialPlot(rf_iris, iris, Species)
No need to belabor the point, but these plots are ugly and ignore prediction variance. Here are the same partial plots, only this time using the plot_partial function provided in the code at the bottom of the page.
> plot_partial(rf=rf_1, data=airquality, dv="Ozone", iv="Temp", conf_int_lb=.25, conf_int_ub=.75)
> plot_partial(rf_iris, iris, dv="Sepal.Length", iv="Species", conf_int_lb=.25, conf_int_ub=.75)
Arguments for the plot_partial function:
rf = a trained random forest object
data = a data frame in the same feature space as the data frame used to train the random forest object. Often, this will be the same data set used in training.
dv = the name of the dependent variable, provided in quotes ("").
iv = the name of the independent variable of interest, provided in quotes ("").
conf_int_lb = the percentile of predictions to plot as the lower bound.
conf_int_ub = the percentile of predictions to plot as the upper bound.
num_sample = the equivalent of partialPlot's n.pt argument. If iv is continuous, the number of values to sample when evaluating partial dependence. Default = NULL.
delta = whether to plot the change in prediction between the actual iv values and the simulated values. Default = NULL.
[The full plot_partial source code, an 85-line R listing, appears here.]
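In case the listing does not render, here is a condensed sketch of what a plot_partial with the documented arguments could look like. This is an illustrative reconstruction for a continuous IV, not the original 85-line listing; the ggplot2 styling choices are assumptions:

```r
library(randomForest)
library(ggplot2)

plot_partial <- function(rf, data, dv, iv,
                         conf_int_lb = .25, conf_int_ub = .75,
                         num_sample = NULL, delta = NULL) {
  x_vals <- sort(unique(data[[iv]]))
  # If requested, thin a continuous IV to num_sample evenly spaced values.
  if (!is.null(num_sample) && is.numeric(data[[iv]])) {
    x_vals <- seq(min(x_vals), max(x_vals), length.out = num_sample)
  }
  base_preds <- predict(rf, data)
  rows <- lapply(x_vals, function(v) {
    sim <- data
    sim[[iv]] <- v                      # Step 4: overwrite the whole column
    preds <- predict(rf, sim)
    if (isTRUE(delta)) preds <- preds - base_preds
    # Keep the full prediction vector's quantiles, not just its mean.
    data.frame(x  = v,
               y  = mean(preds),
               lb = quantile(preds, conf_int_lb),
               ub = quantile(preds, conf_int_ub))
  })
  df <- do.call(rbind, rows)
  ggplot(df, aes(x = x, y = y)) +
    geom_ribbon(aes(ymin = lb, ymax = ub), alpha = .3) +
    geom_line() +
    labs(x = iv, y = dv)
}
```

A categorical IV would need point ranges rather than a line and ribbon, which the sketch omits for brevity.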