Paper Summary: Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift
Last updated:
Please note: this post is mainly intended for my personal use. It is not peer-reviewed work and should not be taken as such.
WHAT
The authors discuss and compare ways to detect when the samples being scored by an ML system start to deviate from the distribution the system was trained on.
This is called data drift (the paper calls it dataset shift).
Detecting drift
Their approach to detecting data drift combines two steps:
1) project both train-time and inference-time samples into a lower-dimensional space using some form of dimensionality reduction (from simple methods such as PCA and autoencoders to fancier ones)
2) run a two-sample test on the reduced representations of the two groups. Tests used include MMD (Maximum Mean Discrepancy), the Kolmogorov-Smirnov (KS) test and the chi-squared test
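As a rough illustration of the two steps, here is a minimal sketch assuming PCA as the reducer and one KS test per principal component, with the per-component p-values combined by a Bonferroni correction (my understanding of how the paper aggregates univariate tests). The helper names (`fit_pca`, `detect_shift_ks`) and the toy data are hypothetical.

```python
import numpy as np
from scipy.stats import ks_2samp


def fit_pca(x_train, n_components):
    """Fit PCA on training data via SVD; return the mean and a (d, k) projection."""
    mean = x_train.mean(axis=0)
    # Rows of vt are principal directions, ordered by decreasing variance.
    _, _, vt = np.linalg.svd(x_train - mean, full_matrices=False)
    return mean, vt[:n_components].T


def detect_shift_ks(x_train, x_live, n_components=5, alpha=0.05):
    """Step 1: project both groups with a PCA fitted on x_train.
    Step 2: one KS test per component, combined with a Bonferroni correction."""
    mean, proj = fit_pca(x_train, n_components)
    z_train = (x_train - mean) @ proj
    z_live = (x_live - mean) @ proj
    p_values = np.array([ks_2samp(z_train[:, k], z_live[:, k]).pvalue
                         for k in range(n_components)])
    # Bonferroni: flag a shift if any p-value falls below alpha / n_components.
    shift_detected = bool(p_values.min() < alpha / n_components)
    return shift_detected, p_values


# Toy data: a few high-variance features so the top components are well defined.
rng = np.random.default_rng(0)
scales = np.ones(20)
scales[:5] = [10.0, 8.0, 6.0, 4.0, 3.0]
x_train = rng.normal(size=(1000, 20)) * scales
x_same = rng.normal(size=(500, 20)) * scales           # no drift
x_shifted = rng.normal(size=(500, 20)) * scales + 2.0  # mean shift in every feature

print(detect_shift_ks(x_train, x_same))
print(detect_shift_ks(x_train, x_shifted))
```

The PCA is fitted on the training samples only, so both groups are projected with the same fixed map.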
Measuring the effect of the drift on models
Once it has been established that drift is present, the authors suggest the following approach to measure how malignant it is.
In other words, they want to estimate how badly the drift will hurt the model's predictions. There are 3 steps:
1) train a simple model (called a domain classifier) to discriminate between train-time and inference-time samples
2) select a small number of inference-time samples that are the most different from the training data (i.e. those to which the domain classifier assigns the highest probability of belonging to the inference-time set)
3) manually label these few samples and measure how good (or bad) the original ML model's predictions are on them
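A minimal sketch of steps 1 and 2, assuming a plain logistic regression trained by gradient descent as the domain classifier (the paper's own classifier may differ). `train_domain_classifier` and `most_shifted` are hypothetical helpers; step 3 needs a human, so the sketch stops at returning the indices to send for labelling.

```python
import numpy as np


def train_domain_classifier(x_train, x_live, lr=0.1, n_iter=500):
    """Fit a logistic regression that predicts 1 for inference-time samples
    and 0 for train-time samples."""
    x = np.vstack([x_train, x_live])
    y = np.concatenate([np.zeros(len(x_train)), np.ones(len(x_live))])
    mu, sd = x.mean(axis=0), x.std(axis=0) + 1e-8
    xs = (x - mu) / sd  # standardize for stable gradient descent
    w, b = np.zeros(x.shape[1]), 0.0
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(xs @ w + b)))
        err = p - y  # per-sample gradient of the log-loss w.r.t. the logit
        w -= lr * xs.T @ err / len(y)
        b -= lr * err.mean()

    def predict_proba(z):
        """Probability that each row of z comes from the inference-time set."""
        return 1.0 / (1.0 + np.exp(-(((z - mu) / sd) @ w + b)))

    return predict_proba


def most_shifted(x_train, x_live, k=10):
    """Indices of the k inference-time samples with the highest domain-classifier score."""
    predict_proba = train_domain_classifier(x_train, x_live)
    scores = predict_proba(x_live)
    return np.argsort(scores)[::-1][:k]


rng = np.random.default_rng(0)
x_train = rng.normal(size=(1000, 2))
x_live = rng.normal(size=(500, 2))
x_live[450:, 0] += 4.0  # the last 50 inference-time samples drifted
print(most_shifted(x_train, x_live))  # indices to hand over for manual labelling
```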
CLAIMS
The best dimensionality reduction technique for detecting drift is BBSD (Black Box Shift Detection)
- This approach was introduced by an overlapping group of authors (Zachary Lipton is on both papers) in Lipton et al. 2018 (see below)
- The rough idea is to take a pretrained classifier, treat it as a black box, feed the samples to it and use its outputs (softmax probabilities or predicted labels) as the reduced representation; in the paper this is the label classifier trained on the original (training) data
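A sketch of the hard-label variant of BBSD, assuming a chi-squared test on predicted-class counts (as I understand it, the paper pairs hard labels with a chi-squared test and softmax outputs with univariate KS tests). `black_box_softmax` is a hypothetical stand-in for a pretrained classifier: a fixed random linear layer followed by a softmax.

```python
import numpy as np
from scipy.stats import chi2_contingency


def black_box_softmax(x, weights):
    """Hypothetical pretrained classifier: fixed linear layer + softmax.
    BBSD only uses its outputs, never its internals."""
    logits = x @ weights
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)


def bbsd_hard_labels(x_train, x_live, weights, alpha=0.05):
    """Compare predicted-class counts of the two groups with a chi-squared test."""
    n_classes = weights.shape[1]
    pred_train = black_box_softmax(x_train, weights).argmax(axis=1)
    pred_live = black_box_softmax(x_live, weights).argmax(axis=1)
    counts = np.vstack([np.bincount(pred_train, minlength=n_classes),
                        np.bincount(pred_live, minlength=n_classes)])
    counts = counts[:, counts.sum(axis=0) > 0]  # drop classes never predicted
    _, p_value, _, _ = chi2_contingency(counts)
    return bool(p_value < alpha), float(p_value)


rng = np.random.default_rng(0)
weights = rng.normal(size=(10, 3))  # 10 features -> 3 classes
x_train = rng.normal(size=(1000, 10))
x_live = rng.normal(size=(500, 10)) + 1.5 * weights[:, 0]  # pushes inputs toward class 0
print(bbsd_hard_labels(x_train, x_live, weights))
```

The softmax vectors themselves can also be compared dimension by dimension with univariate tests; the hard-label version trades some information for a single, simple test.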
Univariate two-sample tests (one test per dimension of the reduced representation, aggregated across dimensions) perform comparably to multivariate ones.
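For contrast with per-dimension tests, here is a sketch of a multivariate test: a permutation test on the squared MMD with an RBF kernel. The bandwidth choice (one variant of the median heuristic), the number of permutations and the function names are assumptions for illustration; in the paper's setting the inputs would be the reduced representations rather than raw features.

```python
import numpy as np


def mmd2_from_kernel(k, n):
    """Biased MMD^2 estimate when rows/cols [:n] are sample 1 and [n:] are sample 2."""
    return k[:n, :n].mean() + k[n:, n:].mean() - 2 * k[:n, n:].mean()


def mmd_permutation_test(x, y, n_perm=200, seed=0):
    """Return a permutation p-value for H0: x and y come from the same distribution."""
    rng = np.random.default_rng(seed)
    z = np.vstack([x, y])
    sq = np.sum(z ** 2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2 * z @ z.T, 0.0)  # squared distances
    gamma = 1.0 / np.median(d2[d2 > 0])  # median heuristic (one common variant)
    k = np.exp(-gamma * d2)  # RBF kernel matrix on the pooled sample
    n = len(x)
    observed = mmd2_from_kernel(k, n)
    exceed = 0
    for _ in range(n_perm):
        perm = rng.permutation(len(z))
        # Randomly relabel the pooled sample and recompute the statistic.
        if mmd2_from_kernel(k[np.ix_(perm, perm)], n) >= observed:
            exceed += 1
    return (exceed + 1) / (n_perm + 1)


rng = np.random.default_rng(1)
x = rng.normal(size=(200, 5))
y_same = rng.normal(size=(200, 5))
y_shift = rng.normal(size=(200, 5)) + 1.0
print(mmd_permutation_test(x, y_same), mmd_permutation_test(x, y_shift))
```

The kernel matrix is computed once and only re-indexed per permutation, which keeps the test cheap for a few hundred samples.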
QUOTES
On approaches that treat shift detection as a form of anomaly detection: "Several recent papers have proposed outlier detection mechanisms dubbing the task out-of-distribution (OOD) sample detection"
"[shifts in] target data exhibiting only 10% anomalous samples are hard to detect, suggesting that this setting might be better addressed via outlier detection"
NOTES
Two-sample testing is done using training data vs actual (current) data
It is assumed that shifts are non-adversarial, i.e. caused by natural changes in distributions rather than intentional production of pathological examples
They test their methods on all sorts of artificially induced data drift, such as:
- Making all samples belong to a single class
- Adding noise to samples of one class or of both classes
- Image-specific noise
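A sketch of two synthetic shifts, loosely in the spirit of the paper's experiments: Gaussian noise added to a fraction of the samples, and a "knock-out" style shift that removes a fraction of one class's samples (which changes \(p(y)\)). The function names and parameters are hypothetical.

```python
import numpy as np


def gaussian_noise_shift(x, fraction, std, rng):
    """Add N(0, std^2) noise to a random `fraction` of the samples (all features)."""
    x = x.copy()
    idx = rng.choice(len(x), size=int(round(fraction * len(x))), replace=False)
    x[idx] += rng.normal(scale=std, size=x[idx].shape)
    return x, idx


def knock_out_shift(x, y, target_class, fraction, rng):
    """Remove a random `fraction` of the samples of `target_class`."""
    cls_idx = np.flatnonzero(y == target_class)
    drop = rng.choice(cls_idx, size=int(round(fraction * len(cls_idx))), replace=False)
    keep = np.setdiff1d(np.arange(len(y)), drop)
    return x[keep], y[keep]


rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 4))
y = rng.integers(0, 2, size=1000)
x_noisy, noisy_idx = gaussian_noise_shift(x, fraction=0.1, std=1.0, rng=rng)
x_ko, y_ko = knock_out_shift(x, y, target_class=1, fraction=0.5, rng=rng)
print(len(noisy_idx), np.bincount(y), np.bincount(y_ko))
```

Varying `fraction` is what lets one probe settings like the "only 10% anomalous samples" case quoted above.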
The authors distinguish two types of dataset drift:
- Covariate (data) shift/drift, where the feature distribution \(p(x)\) changes over time while \(p(y|x)\) does not
- Label shift/drift, where the label distribution \(p(y)\) changes over time, but \(p(x|y)\) (the distribution of features within the same class) doesn't
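A short worked equation, writing \(p\) for the training distribution and \(q\) for the inference-time one (notation assumed here), shows that label shift still moves the feature distribution, because the feature marginal mixes the unchanged class-conditionals with the new class proportions:

\[
q(x) = \sum_y q(x|y)\, q(y) = \sum_y p(x|y)\, q(y) \neq \sum_y p(x|y)\, p(y) = p(x) \quad \text{(in general)}
\]

For example, with two classes and \(p(y=1) = 0.5\), moving to \(q(y=1) = 0.9\) gives \(q(x) = 0.1\, p(x|y=0) + 0.9\, p(x|y=1)\): each class looks the same as before, but the overall mix of features does not, which is why detectors that only look at features can still pick up label shift.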
MY 2¢
Obtaining ground-truth (manual) labels for the samples flagged by the domain classifier is not very realistic in production systems
If we can retrain models periodically, data drift becomes less important
References
Rabanser et al. 2019: Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift
Lipton et al. 2018: Detecting and Correcting for Label Shift with Black Box Predictors
- This is an earlier paper on detecting label shift and automatically correcting classifiers to account for the drift
- The focus here is on label shift, meaning that \(p(y)\) changes (samples of a given type become more/less frequent) but \(p(x|y)\) doesn't
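A minimal sketch of Black Box Shift Estimation (BBSE), the label-shift estimator from Lipton et al. 2018 as I understand it: estimate the joint confusion matrix \(C_{ij} = p(\hat{y}=i, y=j)\) on labelled training-distribution data, the predicted-label distribution \(\mu_i = q(\hat{y}=i)\) on unlabelled inference-time data, and solve \(C w = \mu\) for the weights \(w_y = q(y)/p(y)\). Clipping negative weights to zero and the toy data are my own choices, not the paper's; `predict` is a hypothetical black-box classifier.

```python
import numpy as np


def bbse_weights(y_true_src, y_pred_src, y_pred_tgt, n_classes):
    """Estimate w[y] = q(y) / p(y) under label shift.

    C[i, j] = p(yhat = i, y = j) on held-out labelled training-distribution data.
    mu[i]   = q(yhat = i) from the classifier's predictions on inference-time data.
    Under label shift mu = C @ w, so w solves a linear system (C must be invertible).
    """
    C = np.zeros((n_classes, n_classes))
    np.add.at(C, (y_pred_src, y_true_src), 1.0)  # joint confusion counts
    C /= len(y_true_src)
    mu = np.bincount(y_pred_tgt, minlength=n_classes) / len(y_pred_tgt)
    w = np.linalg.solve(C, mu)
    return np.clip(w, 0.0, None)  # assumption: clip negative estimates to zero


rng = np.random.default_rng(0)


def sample(n, p_pos):
    """Toy data: y ~ Bernoulli(p_pos), x | y ~ N(2y - 1, 1), so p(x|y) is fixed."""
    y = (rng.random(n) < p_pos).astype(int)
    x = rng.normal(loc=2 * y - 1, scale=1.0)
    return x, y


def predict(x):
    """Hypothetical black-box classifier: threshold at zero."""
    return (x > 0).astype(int)


x_src, y_src = sample(20000, 0.5)
x_tgt, _ = sample(20000, 0.8)  # label shift: only p(y) changes
w = bbse_weights(y_src, predict(x_src), predict(x_tgt), n_classes=2)
print(w)  # should be close to [0.2 / 0.5, 0.8 / 0.5] = [0.4, 1.6]
```

As far as I understand the paper, weights like these are then used to reweight the training loss, which is how the classifier is corrected for the shift.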