forked from vanderbilt-data-science/ancient-artifacts
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path50-reporting.Rmd
More file actions
380 lines (308 loc) · 13.8 KB
/
Copy path50-reporting.Rmd
File metadata and controls
380 lines (308 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
---
title: "50-reporting"
output: html_document
---
This file contains general functions to be applied across several model types for interpretation.
```{r required imports}
pacman::p_load(yardstick, precrec, vip,
glmnet, naivebayes, xgboost, ranger)
```
# Prediction with defined outcomes
```{r predict using fit on selected data}
get_prediction_dataframes <- function(final_fit, pred_data, ...) {
  # Builds one data frame holding the predictions made by final_fit on
  # pred_data together with the observed particle_class column.
  # Inputs:  final_fit: fitted object passed to predict() (a finalized
  #                     tidymodels workflow fit in this project)
  #          pred_data: data to predict on; must contain particle_class
  #          ...: extra arguments forwarded to the FIRST predict() call only
  #               (the probability call always uses type = "prob")
  # Returns: columns of the first predict() call, then the probability
  #          columns, then particle_class, bound column-wise in that order

  # predictions with caller-supplied options
  class_preds <- predict(final_fit, pred_data, ...)

  # class probabilities
  prob_preds <- predict(final_fit, pred_data, type = "prob")

  # observed outcome, kept alongside the predictions for later scoring
  observed <- select(pred_data, particle_class)

  bind_cols(bind_cols(class_preds, prob_preds), observed)
}
```
# Performance metric analysis
```{r calculate and print best xval model metrics}
calculate_best_performance_metrics <- function (fold_metrics, best_params, mdl_name=NULL) {
  # Extracts the cross-validation metrics belonging to the best
  # hyperparameter configuration and plots their distribution.
  # Inputs:  fold_metrics: tibble of per-fold metrics with columns .config,
  #                        .metric and .estimate
  #          best_params: object whose .config element identifies the best
  #                       configuration (only the first entry is used)
  #          mdl_name (default NULL): model name shown in the plot subtitle
  # Returns: the rows of fold_metrics for the best configuration
  # Outputs: prints a box plot with jittered points, one box per metric

  # label used in the subtitle when no name is supplied
  plot_label <- if (is.null(mdl_name)) 'selected model' else mdl_name

  # keep only the rows of the winning configuration
  best_fold_metrics <- filter(fold_metrics, .config == best_params$.config[[1]])

  # roc_auc / pr_auc / f_meas go in one facet, everything else in the other
  plot_data <- mutate(
    best_fold_metrics,
    facet_val = if_else(.metric == 'roc_auc' | .metric == 'pr_auc' | .metric == 'f_meas',
                        'Aggregate metrics',
                        'Confusion matrix metrics')
  )

  metric_plot <- ggplot(plot_data, aes(x=.metric, y=.estimate, fill=.metric)) +
    geom_boxplot(outlier.shape = NA, na.rm=TRUE) +
    geom_jitter(aes(x=.metric, y=.estimate), na.rm=TRUE) +
    facet_grid(cols=vars(facet_val), scales='free') + # separate panels per metric family
    labs(title='Distribution of cross validation metrics for best hyperparameter set',
         subtitle=str_c('By metric, for ', plot_label),
         x='metric',
         y='metric estimate') +
    theme(legend.position = "none")
  print(metric_plot)

  best_fold_metrics
}
```
```{r generate confusion matrix stats}
calculate_confusion_matrix <-function (pred_frame) {
  # Computes the confusion matrix of particle_class against .pred_class and
  # draws it next to a table of summary statistics derived from it.
  # Inputs:  pred_frame: tibble of predictions containing .pred_class and the
  #                      target column "particle_class"
  # Returns: the gridExtra table grob holding the summary statistics
  #          (note: this is the rendered table, not the confusion matrix)
  # Outputs: draws the heatmap and the summary table side by side

  # confusion matrix: particle_class is passed first, .pred_class second
  pred_conf <- conf_mat(pred_frame, particle_class, .pred_class)

  # summary statistics as a table grob, without the .estimator column
  summary_tbl <- select(summary(pred_conf), -.estimator)
  summary_grob <- gridExtra::tableGrob(summary_tbl,
                                       rows=NULL,
                                       theme=gridExtra::ttheme_default(base_size=10))

  # heatmap view of the confusion matrix
  heatmap_plot <- autoplot(pred_conf, type='heatmap') +
    labs(title='Confusion matrix for training data')

  gridExtra::grid.arrange(heatmap_plot, summary_grob, ncol=2)

  summary_grob
}
```
# Plot prediction data frame
```{r plot the probabilities for predictions}
plot_prediction_probabilities <- function(pred_frame){
  # Returns a jitter plot of the predicted site probability against the
  # predicted class, coloured by the true class, to inspect how confidently
  # observations are classified.
  # Inputs:  pred_frame: tibble of predictions containing .pred_class,
  #                      .pred_site and particle_class
  # Returns: ggplot object (displayed when printed by the caller)

  # points whose true class is "exp" are drawn more opaque (0.4) than the
  # rest (0.1); the vector is aligned row-by-row with pred_frame
  point_alpha <- ifelse(pred_frame$particle_class == "exp", 0.4, 0.1)

  ggplot(data = pred_frame,
         mapping = aes(x = .pred_class, y = .pred_site, color = particle_class)) +
    geom_jitter(width = 0.4, size = 1, alpha = point_alpha) +
    xlab("Model Predicted Class") + ylab("Probability of Site") +
    labs(title="Probability Prediction Plots")
}
```
# Calibration curve
```{r plot a calibration curve based on prediction}
plot_calibration_curve <- function(pred_frame, mdl_name=NULL){
  # Returns a calibration plot: a loess smooth of the 0/1-coded true class
  # against the predicted probability, with the identity line for reference.
  # Inputs:  pred_frame: tibble of predictions with the target column
  #                      "particle_class" and probabilities ".pred_exp";
  #                      may already contain a "mdl_name" column when several
  #                      models are stacked in one frame
  #          mdl_name (default NULL): model name used for the colour legend
  #                      when pred_frame has no mdl_name column
  # Returns: ggplot object of the calibration curve

  # add a mdl_name column when the frame does not already carry one
  if (!('mdl_name' %in% names(pred_frame))) {
    # fall back to a generic label when no name was supplied
    if (is.null(mdl_name)) {
      mdl_name <- 'mdl'
    }
    pred_frame <- pred_frame %>%
      mutate(mdl_name = mdl_name)
  }

  # recode the outcome numerically so geom_smooth can fit it.
  # The column is addressed by name rather than by position, so the recode
  # no longer depends on particle_class being the 4th column.
  # NOTE(review): assumes as.numeric() gives 1 for the experimental class and
  # 2 for site (i.e. a two-level factor in that level order) - confirm.
  new_training_preds <- as.data.frame(pred_frame)
  class_codes <- as.numeric(new_training_preds$particle_class)
  class_codes[class_codes == 2] <- 0
  new_training_preds$particle_class <- class_codes

  # plot calibration curve
  ggplot(data = new_training_preds, mapping = aes(x = .pred_exp, y = particle_class)) +
    scale_y_continuous(limits = c(0, 1), breaks = seq(0, 1, by = 0.2)) +
    scale_x_continuous(limits = c(0, 1), breaks = seq(0, 1, by = 0.2)) +
    geom_smooth(aes(x = .pred_exp, y = particle_class, color = mdl_name),
                se = FALSE, method = "loess") +
    geom_abline() +
    xlab("Model Probability of Experimental") + ylab("True Particle Class (Site [0] to Experimental[1]") +
    labs(title="Calibration Curve", color='Model Type')
}
```
# Probability-based curves
```{r function for plotting ordered probability and class relation}
plot_label_by_score <- function(preds_df){
  # Plots the predicted scores in increasing order, coloured by the actual
  # class, for visual inspection of misclassification.
  # Inputs:  preds_df: prediction dataframe with at least .pred_exp and
  #                    particle_class columns
  # Returns: ggplot object of label by score

  preds_df %>%
    arrange(.pred_exp) %>%
    # rank of each observation after sorting; row_number() also handles an
    # empty frame, where 1:nrow(.) would have produced c(1, 0)
    mutate(prob_order = row_number()) %>%
    ggplot(aes(x=prob_order, y=.pred_exp, color=particle_class)) +
    geom_point(size=1, alpha=0.6) +
    labs(x='Sorted order of scores',
         y='Scores',
         title='Actual class label based on increasing score')
}
```
```{r function for viewing misclassification by score}
plot_misclassification_rates <- function(preds_df){
  # Plots the fraction of misclassified observations within each of 20
  # equal-width bins of the predicted score.
  # Inputs:  preds_df: prediction dataframe with at least .pred_class,
  #                    .pred_exp and particle_class columns
  # Returns: ggplot object of misclassification rates per score bin

  preds_df %>%
    # an observation is wrong when predicted and actual class differ
    mutate(is_wrong = (.pred_class != particle_class)) %>%
    # lower bound of the bin, extracted as text from the interval label
    # NOTE(review): the pattern only captures digits and dots, so a label in
    # scientific notation would be truncated - confirm labels stay decimal
    mutate(bin_start = str_match(cut_interval(.$.pred_exp, n=20), "[\\(\\[]([\\d\\.]+).+]")[,2]) %>%
    # bins below 0.5 are labelled "fn", the others "fp"; the bound is
    # converted to a number first so the comparison is numeric rather than a
    # locale-dependent comparison of strings
    mutate(is_fp = ifelse(as.numeric(bin_start) < 0.5, "fn", "fp")) %>%
    group_by(bin_start) %>%
    summarise(bin_counts=sum(is_wrong)/n(), error_type=unique(is_fp)) %>%
    ggplot(aes(x=bin_start, y=bin_counts, fill=error_type)) +
    geom_col() +
    labs(x='score interval', y='Percentage misclassified in bin', title='Misclassification rate in probability interval') +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
}
```
```{r function for ROC curves}
plot_performance_curves <- function(pred_dfs, model_names=NULL, pos_class='exp'){
  # Evaluates one or more models with precrec::evalmod and returns the
  # evaluation object used for ROC / precision-recall curves.
  # Inputs:  pred_dfs: a single data frame / tibble, or a list of them, each
  #                    containing at least .pred_exp and particle_class
  #          model_names (default NULL): VECTOR of model names in the same
  #                    order as pred_dfs; defaults to Model1, Model2, ...
  #          pos_class (default 'exp'): positive class or class of interest
  # Returns: the object returned by evalmod()

  # a single data frame (tibble or plain data.frame) is wrapped in a list so
  # the code below always iterates over frames rather than over columns
  if (is.data.frame(pred_dfs)) {
    pred_dfs <- list(pred_dfs)
  }

  # scores for each model
  scores_list <- pred_dfs %>%
    map(~select(., .pred_exp)) %>%
    join_scores()

  # labels for each model
  labels_list <- pred_dfs %>%
    map(~select(., particle_class)) %>%
    join_labels()

  # default model names when none are supplied
  if (is.null(model_names)) {
    model_names <- str_c('Model', seq_along(pred_dfs))
  }

  # calculate model eval and autoplot
  model_eval <- evalmod(scores = scores_list, labels = labels_list, modnames = model_names, posclass = pos_class)
  # NOTE(review): the value of autoplot() is not printed here, so a figure
  # only appears if autoplot draws as a side effect - confirm against the
  # installed precrec version
  autoplot(model_eval)

  return(model_eval)
}
```
# Variable Importance
## Permutation importance
The permutation importance is a model-agnostic measure of variable importance. Model-specific measures of variable importance are also available below.
```{r generic plot for variable importance}
plot_variable_importance <- function(wkflw, assessment_data, mdl_name=NULL, plot_type='default',
                                     p_metric='auc', target='particle_class', positive_class='exp', smaller_is_better=NULL){
  # Computes permutation importance with vip::vi_permute and plots it.
  # Inputs:  wkflw: fitted final fit (workflow object)
  #          assessment_data: raw tibble handed to vi_permute as training data
  #          mdl_name (default NULL): string of model name or type, used in the title
  #          plot_type (default 'default'): 'default' plots and orders by signed
  #                     importance, 'abs' by absolute importance with bars
  #                     filled by the sign of the importance
  #          p_metric (default 'auc'): metric passed to vi_permute
  #          target (default 'particle_class'): name of the outcome column
  #          positive_class (default 'exp'): reference class; also selects the
  #                     probability column returned by the prediction wrapper
  #          smaller_is_better (default NULL): passed through to vi_permute
  # Output:  prints the importance plot
  # Returns: named list with permimp (tibble of permutation importance) and
  #          plt (ggplot object)

  # input validation
  if (is.null(mdl_name)) {
    mdl_name <- 'selected model'
  }
  if (length(plot_type) != 1 || !(plot_type %in% c('default', 'abs'))) {
    stop("parameter 'plot_type' must be 'abs' or 'default'.")
  }

  # the following operations on wkflw and data are a workaround
  # since vi_permute on workflows isn't currently implemented
  mdl_recipe <- pull_workflow_prepped_recipe(wkflw)

  # prediction wrapper for vi_permute: returns the probability column whose
  # name contains the reference class
  pred_wrapper <- function(object, newdata, ref_cls = positive_class){
    res <- predict(object, new_data=newdata, type='prob')
    pred_vec <- res %>%
      select(contains(ref_cls)) %>%
      pull()
    return (pred_vec)
  }

  # get permutation importance
  vip_res <- vi_permute(object=wkflw, train=assessment_data, target=target,
                        metric=p_metric, pred_wrapper=pred_wrapper, reference_class=positive_class,
                        smaller_is_better=smaller_is_better, new_data=assessment_data)

  # keep only variables the recipe uses as predictor or outcome
  keep_cols <- mdl_recipe$var_info %>%
    filter(role=='predictor' | role=='outcome') %>%
    pull(variable)

  vip_res <- vip_res %>%
    filter(Variable %in% keep_cols)

  # determine plot base
  if (plot_type == 'default') {
    g <- vip_res %>%
      ggplot(aes(x=fct_reorder(Variable, Importance), y=Importance)) +
      geom_col() +
      coord_flip()
  } else {
    # the fill column is derived here from the sign of Importance; it is not
    # created anywhere else in this function, so mapping it without this step
    # fails when the plot is printed
    g <- vip_res %>%
      mutate(permute_sign = if_else(Importance < 0, 'negative', 'non-negative')) %>%
      ggplot(aes(x=fct_reorder(Variable, abs(Importance)), y=abs(Importance), fill=permute_sign)) +
      geom_col() +
      coord_flip()
  }

  # titles and axis labels shared by both plot types
  g <- g + labs(title = str_c('Permutation importance of features in ', mdl_name),
                subtitle = 'Change is (baseline - permuted)',
                x = 'Feature',
                y = str_c('Change in ', p_metric, ' due to permutation'))
  print(g)

  # return both the table and the plot
  return(list(permimp = vip_res,
              plt = g ))
}
```
## Model-based variable importance
```{r model based variable importance helper}
get_vip <- function(final_fit, ...){
  # Helper for model_variable_importance (not meant to be called directly):
  # computes model-based importance and tags each row with an association.
  # Inputs:  final_fit: tidymodels workflow fit
  #          ...: additional parameters forwarded to vi_model (from vip)
  # Returns: tibble of variable importance with an added "association"
  #          column: 'exp' where Sign is 'NEG' and 'site' otherwise when a
  #          Sign column exists, or the constant 'mdl' when it does not

  fitted_model <- pull_workflow_fit(final_fit)
  importance_tbl <- vi_model(fitted_model, ...)

  has_sign <- 'Sign' %in% colnames(importance_tbl)

  if (has_sign) {
    # the direction of the effect is available: map it to a class label
    mutate(importance_tbl, association = if_else(Sign=='NEG', 'exp', 'site'))
  } else {
    # no direction available: use a single placeholder group
    mutate(importance_tbl, association = 'mdl')
  }
}
```
```{r generalized model based variable importance}
model_variable_importance <-function(final_fit) {
  # Computes and plots model-specific variable importance.
  # Inputs:  final_fit: tidymodels workflow fit
  # Returns: tibble of variable importance (as produced by get_vip)
  # Outputs: prints a bar plot of the importance, filled by association
  # Stops with an error for model types other than logistic_reg,
  # rand_forest and boost_tree.

  # first class of the model specification decides how vi_model is called
  model_type <- class(pull_workflow_spec(final_fit))[[1]]

  # dispatch on model type; logistic regression additionally passes the
  # penalty stored in the fitted spec as lambda
  model_vip <- switch(
    model_type,
    logistic_reg = get_vip(final_fit, lambda = final_fit$fit$fit$spec$args$penalty),
    rand_forest = ,
    boost_tree = get_vip(final_fit),
    stop('Model-based variable importance only supported for logistic regression, random forest, and boost models.')
  )

  # bar plot ordered by importance
  importance_plot <- ggplot(model_vip,
                            aes(x=fct_reorder(Variable, Importance), y=Importance, fill=association)) +
    geom_col() +
    coord_flip() +
    labs(title='Model-based variable importance',
         subtitle=str_c('for model type ', model_type),
         y='model-based importance',
         x='variable')
  print(importance_plot)

  model_vip
}
```