18  Diagnostics for classification models

18.1 Errors for a single model

To examine misclassifications, we can create a new variable indicating whether or not each observation was correctly classified, constructing it for each class and exploring in small steps. Let's do this using the random forest model fitted to the penguins data. The random forest fit has only a few misclassifications, as seen in its confusion matrix (computed from the out-of-bag predictions): two Adelie penguins are confused with Chinstrap and one with Gentoo, four Chinstrap are confused with Adelie, and one Gentoo is confused with a Chinstrap. The last of these is interesting, because the Gentoo cluster is well separated from the clusters of the other two penguin species.

Code to fit forest
library(randomForest)
library(dplyr)
load("data/penguins_sub.rda")

penguins_rf <- randomForest(species~.,
                             data=penguins_sub[,1:5],
                             importance=TRUE)

penguins_rf$confusion
          Adelie Chinstrap Gentoo class.error
Adelie       143         2      1 0.020547945
Chinstrap      4        64      0 0.058823529
Gentoo         0         1    118 0.008403361
# Flag the misclassified observations
penguins_errors <- penguins_sub %>%
  mutate(err = ifelse(penguins_rf$predicted !=
                        penguins_rf$y, 1, 0))
Code to make animated gifs
library(tourr)
# Errors are drawn as larger solid circles (pch=16)
symbols <- c(1, 16)
p_pch <- symbols[penguins_errors$err+1]
p_cex <- rep(1, length(p_pch))
p_cex[penguins_errors$err==1] <- 2
animate_xy(penguins_errors[,1:4],
           col=penguins_errors$species,
           pch=p_pch, cex=p_cex)
render_gif(penguins_errors[,1:4],
           grand_tour(),
           display_xy(col=penguins_errors$species,
                      pch=p_pch, cex=p_cex),
           gif_file="gifs/p_rf_errors.gif",
           frames=500,
           width=400,
           height=400)

animate_xy(penguins_errors[,1:4],
           guided_tour(lda_pp(penguins_errors$species)),
           col=penguins_errors$species,
           pch=p_pch, cex=p_cex)

render_gif(penguins_errors[,1:4],
           guided_tour(lda_pp(penguins_errors$species)),
           display_xy(col=penguins_errors$species,
                      pch=p_pch, cex=p_cex),
           gif_file="gifs/p_rf_errors_guided.gif",
           frames=500,
           width=400,
           height=400,
           loop=FALSE)

Figure 18.1 shows a grand tour, and a guided tour, of the penguins data, where the misclassifications are marked as large solid circles. (If the gifs are too small to see the different glyphs, you can zoom in to make the figures larger.) The one Gentoo penguin that is mistaken for a Chinstrap by the forest model always moves with the rest of its Gentoo (yellow) family. In some projections of the grand tour it can occasionally be seen on the edge of the group, closer to the Chinstraps, but in the final projection from the guided tour it hides well among the other Gentoos. This is an observation where the mistake is due to inadequacies of the forest algorithm. Forests are only as good as the trees they are constructed from, and we have seen that the single-variable splits made by trees do not adequately utilise the covariance structure in each class: they make mistakes because of the boxy nature of their boundaries. This carries through to the forest model; even though many trees are combined to generate smoother boundaries, forests do not effectively utilise covariance in clusters either. The other mistakes, where Chinstrap are predicted to be Adelie and vice versa, are more sensible. These observations lie in the border region between the two clusters, and reflect genuine uncertainty about the classification of penguins in these two species.

(a) Grand tour
(b) Guided tour
Figure 18.1: Examining the misclassified cases (marked as large solid circles) from a random forest fit to the penguins data. The one Gentoo penguin mistaken for a Chinstrap is a mistake made because the forest method suffers from the same problems as trees - cutting on single variables rather than effectively using covariance structure. The mistakes between the Adelie and Chinstrap penguins are more sensible because all of these observations lie in the border region between the two clusters.

Some errors are reasonable because there is overlap between the class clusters. Some errors are not reasonable because the model used is inadequate.
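The boxy boundaries can be seen directly with a small side calculation. The sketch below is not part of the penguins analysis above, and the variable choice and grid resolution are arbitrary: it compares the axis-parallel boundary of a single tree with the oblique boundary from LDA on two of the variables, by predicting over a fine grid.

Code
# Compare a single tree's axis-parallel boundary with LDA's oblique
# boundary on two variables. Note that MASS masks dplyr::select.
library(rpart)
library(MASS)
library(ggplot2)
tr <- rpart(species ~ bl + fl, data = penguins_sub)
ld <- lda(species ~ bl + fl, data = penguins_sub)
# Predict over a fine grid to reveal the decision regions
g <- expand.grid(
  bl = seq(min(penguins_sub$bl), max(penguins_sub$bl), length.out = 200),
  fl = seq(min(penguins_sub$fl), max(penguins_sub$fl), length.out = 200))
g$tree <- predict(tr, g, type = "class")
g$lda <- predict(ld, g)$class
# Swap fill=tree for fill=lda to see the LDA boundary
ggplot(g, aes(x = bl, y = fl, fill = tree)) +
  geom_tile(alpha = 0.4) +
  geom_point(data = penguins_sub,
             aes(x = bl, y = fl, colour = species),
             inherit.aes = FALSE) +
  theme_minimal() +
  theme(aspect.ratio = 1)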

18.2 Local explanations

It is especially important to be able to interpret or explain a model, even more so when the model is complex or black-box in nature. A good resource for learning about the range of methods available is Molnar (). Local explanations provide information about which variables are important for making the prediction for a particular observation. The method we use here is the Shapley value, as computed using the kernelshap package ().

Code
# Split the data into training and testing, as done in 17-nn chapter
library(dplyr)
library(tidyr)
library(rsample)
library(tidymodels)
library(keras)

load("data/penguins_sub.rda") # from mulgar book

set.seed(821)
p_split <- penguins_sub %>% 
  select(bl:species) %>%
  initial_split(prop = 2/3, 
                strata=species)
p_train <- training(p_split)
p_test <- testing(p_split)

# Data needs to be matrix, and response needs to be numeric
p_train_x <- p_train %>%
  select(bl:bm) %>%
  as.matrix()
p_train_y <- p_train %>% pull(species) %>% as.numeric() 
p_train_y <- p_train_y-1 # Needs to be 0, 1, 2
p_test_x <- p_test %>%
  select(bl:bm) %>%
  as.matrix()
p_test_y <- p_test %>% pull(species) %>% as.numeric() 
p_test_y <- p_test_y-1 # Needs to be 0, 1, 2

# Predict training and test set
p_nn_model <- load_model_tf("data/penguins_cnn")
p_train_pred <- p_nn_model %>% 
  predict(p_train_x, verbose = 0)
p_train_pred_cat <- levels(p_train$species)[
  apply(p_train_pred, 1,
        which.max)]
p_train_pred_cat <- factor(
  p_train_pred_cat,
  levels=levels(p_train$species))

p_test_pred <- p_nn_model %>% 
  predict(p_test_x, verbose = 0)
p_test_pred_cat <- levels(p_test$species)[
  apply(p_test_pred, 1, 
        which.max)]
p_test_pred_cat <- factor(
  p_test_pred_cat,
  levels=levels(p_test$species))
Code
# Explanations
# https://www.r-bloggers.com/2022/08/kernel-shap/
library(kernelshap)
library(shapviz)
p_explain <- kernelshap(
    p_nn_model,
    p_train_x, 
    bg_X = p_train_x,
    verbose = FALSE
  )
p_exp_sv <- shapviz(p_explain)
save(p_exp_sv, file="data/p_exp_sv.rda")

A Shapley value for an observation indicates how a variable contributes to the model prediction for that observation, relative to the other variables. It is an average, computed from the change in prediction over all combinations of presence or absence of the other variables. In the computation, for each combination, the prediction is made by substituting absent variables with their average values, like one might do when imputing missing values.
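To make the averaging concrete, here is a brute-force sketch of the Shapley value for one variable of one observation. The function shapley_1d and the prediction function f are illustrative assumptions; kernelshap computes these values far more efficiently, and handles the background data more carefully.

Code
# Brute-force Shapley value for variable j of observation x, given a
# prediction function f and a data matrix X (columns are variables).
# Absent variables are substituted by their column means. Exponential
# in the number of variables, so for exposition only.
shapley_1d <- function(f, x, X, j) {
  p <- ncol(X)
  vars <- setdiff(seq_len(p), j)
  xbar <- colMeans(X)
  # All subsets S of the other variables, including the empty set
  subsets <- list(integer(0))
  for (k in seq_along(vars))
    subsets <- c(subsets, combn(vars, k, simplify = FALSE))
  phi <- 0
  for (S in subsets) {
    x_without <- xbar
    x_without[S] <- x[S]     # variables in S take their observed values
    x_with <- x_without
    x_with[j] <- x[j]        # now add variable j
    # Shapley weight for a subset of this size: |S|!(p-|S|-1)!/p!
    w <- factorial(length(S)) * factorial(p - length(S) - 1) / factorial(p)
    phi <- phi + w * (f(x_with) - f(x_without))
  }
  phi
}
# Example (hypothetical f returning the predicted Gentoo probability):
# f <- function(z) predict(p_nn_model, matrix(z, nrow=1), verbose=0)[1, 3]
# shapley_1d(f, x = p_train_x[1, ], X = p_train_x, j = 4)  # effect of bm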

Figure 18.2 shows the Shapley values for the Gentoo observations in the training set, as a jittered dotplot and as a parallel coordinate plot. The values for the single misclassified Gentoo penguin are coloured orange. Overall, the Shapley values don't vary much for bl, bd and fl, but they do for bm: which other variables are present appears to matter only for bm.

For the misclassified penguin, the effect of bm across all combinations of the other variables is to reduce the predicted value, giving less confidence in it being a Gentoo. In contrast, for this same penguin, the effect of bl is to increase the predicted value on average.

Code
load("data/p_exp_sv.rda")
p_exp_gentoo <- p_exp_sv$Class_3$S
p_exp_gentoo <- p_exp_gentoo %>%
  as_tibble() %>%
  mutate(species = p_train$species,
         pspecies = p_train_pred_cat,
  ) %>%
  mutate(error = ifelse(species == pspecies, 0, 1))
Code
library(colorspace)
library(ggbeeswarm)
p_exp_gentoo %>%
  filter(species == "Gentoo") %>%
  pivot_longer(bl:bm, names_to="var", values_to="shap") %>%
  mutate(var = factor(var, levels=c("bl", "bd", "fl", "bm"))) %>%
  ggplot(aes(x=var, y=shap, colour=factor(error))) +
  geom_quasirandom(alpha=0.8) +
  scale_colour_discrete_divergingx(palette="Geyser") +
  #facet_wrap(~var) +
  xlab("") + ylab("SHAP") +
  theme_minimal() + 
  theme(legend.position = "none")
Code
library(colorspace)
library(ggpcp)
p_exp_gentoo %>%
  filter(species == "Gentoo") %>%
  pcp_select(1:4) %>%
  ggplot(aes_pcp()) +
    geom_pcp_axes() + 
    geom_pcp_boxes(fill="grey80") + 
    geom_pcp(aes(colour = factor(error)), 
             linewidth = 2, alpha=0.3) +
  scale_colour_discrete_divergingx(palette="Geyser") +
  xlab("") + ylab("SHAP") +
  theme_minimal() + 
  theme(legend.position = "none")
Figure 18.2: SHAP values focused on the Gentoo class, for each variable. The one misclassified penguin (orange) has a much lower value for body mass, suggesting that this variable is used differently for its prediction than for the other penguins.

If we examine the data, the explanation makes some sense: the misclassified penguin has an unusually small value of bm. That the SHAP value for bm was quite different pointed to a potential issue with the model, particularly for this penguin, whose prediction is negatively impacted by bm being in the model.

Code
library(patchwork)
# Check position on bm, by projecting the data onto the misclassified
# penguin's SHAP values, scaled to a unit vector
shap_proj <- p_exp_gentoo %>%
  filter(species == "Gentoo", error == 1) %>%
  select(bl:bm)
shap_proj <- as.matrix(shap_proj/sqrt(sum(shap_proj^2)))
p_exp_gentoo_proj <- p_exp_gentoo %>%
  rename(shap_bl = bl, 
         shap_bd = bd,
         shap_fl = fl, 
         shap_bm = bm) %>%
  bind_cols(as_tibble(p_train_x)) %>%
  mutate(shap1 = shap_proj[1]*bl+
           shap_proj[2]*bd+
           shap_proj[3]*fl+
           shap_proj[4]*bm)
sp1 <- ggplot(p_exp_gentoo_proj, aes(x=bm, y=bl, 
             colour=species, 
             shape=factor(1-error))) +
    geom_point(alpha=0.8) +
  scale_colour_discrete_divergingx(palette="Zissou 1") +
  scale_shape_manual("error", values=c(19, 1)) +
  theme_minimal() + 
  theme(aspect.ratio=1, legend.position="bottom")
# Two alternative views, not displayed in the figure: bm against the
# SHAP projection, and the density of the SHAP projection. Substitute
# either for sp2 in the patchwork call below to view them.
sp_shap <- ggplot(p_exp_gentoo_proj, aes(x=bm, y=shap1, 
             colour=species, 
             shape=factor(1-error))) +
    geom_point(alpha=0.8) +
  scale_colour_discrete_divergingx(palette="Zissou 1") +
  scale_shape_manual("error", values=c(19, 1)) +
  ylab("SHAP") +
  theme_minimal() + 
  theme(aspect.ratio=1, legend.position="bottom")
sp_dens <- ggplot(p_exp_gentoo_proj, aes(x=shap1, 
             fill=species, colour=species)) +
  geom_density(alpha=0.5) +
  geom_vline(xintercept = p_exp_gentoo_proj$shap1[
    p_exp_gentoo_proj$species=="Gentoo" &
    p_exp_gentoo_proj$error==1], colour="black") +
  scale_fill_discrete_divergingx(palette="Zissou 1") +
  scale_colour_discrete_divergingx(palette="Zissou 1") +
  theme_minimal() + 
  theme(aspect.ratio=1, legend.position="bottom")
# Second displayed panel: bm against bd
sp2 <- ggplot(p_exp_gentoo_proj, aes(x=bm, y=bd, 
             colour=species, 
             shape=factor(1-error))) +
    geom_point(alpha=0.8) +
  scale_colour_discrete_divergingx(palette="Zissou 1") +
  scale_shape_manual("error", values=c(19, 1)) +
  theme_minimal() + 
  theme(aspect.ratio=1, legend.position="bottom")
sp1 + sp2 + plot_layout(ncol=2, guides = "collect") &
  theme(legend.position="bottom",
        legend.direction="vertical")
Figure 18.3: Plots of the data to help understand what the SHAP values indicate. The misclassified Gentoo penguin has an unusually low body mass value which makes it appear to be more like an Adelie penguin, particularly when considered in relation to its bill length.
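As a quick numeric check on this explanation, we can ask where the misclassified penguin's body mass sits within the Gentoo distribution, using the objects created above (a sketch; bm here is the standardised data value, after the SHAP columns were renamed).

Code
# Compare the misclassified penguin's bm with the other Gentoos'
gentoo_bm <- p_exp_gentoo_proj %>% 
  filter(species == "Gentoo")
err_bm <- gentoo_bm %>% 
  filter(error == 1) %>% 
  pull(bm)
gentoo_bm %>%
  summarise(lowest = min(bm),
            q25 = quantile(bm, 0.25),
            err_penguin = err_bm,
            prop_below = mean(bm < err_bm))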

18.3 Examining boundaries

Figure 18.4 shows the boundaries for this NN model along with those of the LDA model. To examine the boundaries, we generate a large sample of points uniformly covering the range of each of the four explanatory variables, predict their class with the model, and display the predicted points, either in a slice tour or projected into the space of the NN's two hidden nodes.

Code
# Generate a uniform random sample covering the range of each explanatory variable
p_grid <- tibble(
  bl = runif(10000, min(penguins_sub$bl), max(penguins_sub$bl)),
  bd = runif(10000, min(penguins_sub$bd), max(penguins_sub$bd)),
  fl = runif(10000, min(penguins_sub$fl), max(penguins_sub$fl)),
  bm = runif(10000, min(penguins_sub$bm), max(penguins_sub$bm))
)
# Predict grid
p_grid_pred <- p_nn_model %>%
  predict(as.matrix(p_grid), verbose=0)
p_grid_pred_cat <- levels(p_train$species)[apply(p_grid_pred, 1, which.max)]
p_grid_pred_cat <- factor(p_grid_pred_cat,
                          levels=levels(p_train$species))

# Project into the weights of the two hidden nodes. p_nn_wgts_on is the
# 4x2 first-layer weight matrix from the NN fitted in the previous
# chapter, e.g. extracted with keras::get_weights(p_nn_model)[[1]]
p_grid_proj <- as.matrix(p_grid) %*% p_nn_wgts_on
colnames(p_grid_proj) <- c("nn1", "nn2")
p_grid_proj <- p_grid_proj %>% 
  as_tibble() %>%
  mutate(species = p_grid_pred_cat)

# Plot, overlaying the observed data (p_all_m is assumed to hold the
# observations projected into the same node space, from the previous
# chapter)
ggplot(p_grid_proj, aes(x=nn1, y=nn2, 
                     colour=species)) + 
  geom_point(alpha=0.5) +
  geom_point(data=p_all_m, aes(x=nn1, 
                               y=nn2, 
                               shape=species),
             inherit.aes = FALSE) +
  scale_colour_discrete_divergingx(palette="Zissou 1") +
  scale_shape_manual(values=c(1, 2, 3)) +
  theme_minimal() +
  theme(aspect.ratio=1, 
        legend.position = "bottom",
        legend.title = element_blank())
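The slice tours referenced in the figure captions can be generated along the following lines. This is a sketch: the slice volume (v_rel), axes position and gif file name are assumptions, following the pattern used earlier in the chapter.

Code
# Slice tour of the prediction grid, coloured by predicted class
render_gif(p_grid,
           grand_tour(),
           display_slice(col=p_grid_pred_cat,
                         v_rel=0.02,
                         axes="bottomleft"),
           gif_file="gifs/p_nn_boundaries.gif",
           frames=500,
           width=400,
           height=400)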
(a) LDA model
(b) NN model
Figure 18.4: Comparison of the boundaries produced by the LDA (a) and the NN (b) model, using a slice tour.

18.4 Comparison between LDA and CNN

18.5 Constructing data to diagnose your model

18.6 Explainability

Exercises

  1. Examine the misclassifications from a random forest model for the fake_trees data between clusters 1 and 0, using the
    1. principal components
    2. votes matrix. Describe where these errors lie relative to their true and predicted class clusters. When examining the simplex, are the misclassifications the points that are furthest from any vertices?
  2. Examine the misclassifications for the random forest model on the sketches data, focusing on cactus sketches that were mistaken for bananas. Follow up by plotting the images of these errors, and describe whether the sketches are indeed so poor that their true cactus or banana identity cannot be determined.
  3. How do the errors from the random forest model compare with those of your best fitting CNN model? Are the corresponding images poor sketches of cacti or bananas?
  4. Now examine the misclassifications of the sketches data in the
    1. votes matrix from the random forest model
    2. predictive probability distribution from the CNN model, using the simplex approach. Are they as expected, points lying in the middle or along an edge of the simplex?