Paper Summary: Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift

Last updated:

Please note: this post is mainly intended for my personal use. It is not peer-reviewed work and should not be taken as such.

WHAT

The authors discuss and compare ways to detect when the samples being scored by an ML system start to deviate from the distribution the system was trained on.

This is called data drift (the paper uses the term dataset shift).


Detecting drift

Their approach to detecting data drift has two steps:

  • 1) Project both train-time and inference-time samples into a lower-dimensional space using some form of dimensionality reduction (from simple methods like PCA and autoencoders to fancier ones)

  • 2) Run a two-sample test on the reduced representations of both groups. Tests used include Maximum Mean Discrepancy (MMD), the Kolmogorov-Smirnov (KS) test and the chi-squared test
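The two steps can be sketched in a few lines of Python. This is a minimal sketch, assuming NumPy and SciPy are available: it uses PCA (computed via SVD and fit on training data only) as the dimensionality reduction, runs one univariate KS test per reduced dimension, and combines the per-dimension results with a Bonferroni correction. The function names, the number of components and the significance level are illustrative assumptions, not the paper's exact setup.

```python
import numpy as np
from scipy.stats import ks_2samp


def fit_pca(x_train, n_components):
    """Fit PCA on training data only; returns (mean, components)."""
    mean = x_train.mean(axis=0)
    # Rows of vt are principal directions, sorted by explained variance.
    _, _, vt = np.linalg.svd(x_train - mean, full_matrices=False)
    return mean, vt[:n_components]


def project(x, mean, components):
    """Step 1: map samples into the reduced space learned from training data."""
    return (x - mean) @ components.T


def detect_shift(x_train, x_live, n_components=5, alpha=0.05):
    """Step 2: one KS test per reduced dimension, Bonferroni-corrected.

    Returns (shift_detected, per-dimension p-values)."""
    mean, comps = fit_pca(x_train, n_components)
    z_train = project(x_train, mean, comps)
    z_live = project(x_live, mean, comps)
    p_values = np.array([
        ks_2samp(z_train[:, i], z_live[:, i]).pvalue
        for i in range(n_components)
    ])
    # Bonferroni: flag a shift if any dimension is significant at alpha / K.
    shift = bool(p_values.min() < alpha / n_components)
    return shift, p_values
```

Given training features `x_train` and a batch `x_live` of inference-time samples, `detect_shift(x_train, x_live)` returns a boolean flag plus the per-dimension p-values. The Bonferroni correction keeps the overall false-alarm rate at or below `alpha` even though several tests are run.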

Measuring the effect of the drift on models

Once a drift has been detected, the authors suggest the following approach to measure how malignant it is.

In other words, they want to estimate how many bad predictions the drift will cause. There are three steps:

1) Train a simple model (called a domain classifier) to discriminate between train-time and inference-time samples

2) Select a small number of inference-time samples that are the most different from the training data (i.e. those to which the domain classifier assigns the highest probability of being inference-time samples)

3) Manually label these few samples and measure how good (or bad) the original ML model's predictions are on them
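The three steps can be sketched as follows. This is a hypothetical NumPy-only illustration: the domain classifier is a plain logistic regression trained with gradient descent, `predict` stands in for the original ML model, and `y_manual` stands in for the labels a human reviewer would provide in step 3. For brevity, the sketch ranks the same inference-time samples the classifier was trained on; a held-out split would be more careful.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def train_domain_classifier(x_train, x_live, lr=0.1, epochs=500):
    """Step 1: logistic regression separating train-time (label 0) from
    inference-time (label 1) samples, fit with plain gradient descent."""
    x = np.vstack([x_train, x_live])
    y = np.concatenate([np.zeros(len(x_train)), np.ones(len(x_live))])
    w = np.zeros(x.shape[1])
    b = 0.0
    for _ in range(epochs):
        err = sigmoid(x @ w + b) - y  # gradient of the log-loss w.r.t. logits
        w -= lr * (x.T @ err) / len(y)
        b -= lr * err.mean()
    return w, b


def most_shifted(x_live, w, b, k):
    """Step 2: indices of the k inference-time samples that the domain
    classifier most confidently assigns to inference time."""
    scores = sigmoid(x_live @ w + b)
    return np.argsort(scores)[::-1][:k]


def accuracy_on_selected(predict, x_live, idx, y_manual):
    """Step 3: accuracy of the original model on the selected samples, given
    labels from manual review (y_manual is aligned with idx)."""
    preds = predict(x_live[idx])
    return float(np.mean(preds == y_manual))
```

A low accuracy from `accuracy_on_selected` on the top-`k` samples suggests the drift is malignant; a high one suggests the model still copes with the new region of input space.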

CLAIMS

  • The best dimensionality reduction technique for detecting drift is BBSD (Black Box Shift Detection)

    • This approach was introduced by an overlapping group of authors in Lipton et al. 2018 (see below)
    • The rough idea is to take a pretrained classifier, treat it as a black box, feed the samples to it and use its outputs (the predicted class probabilities) as the reduced-dimensionality representation
  • Univariate two-sample tests offer comparable performance to multivariate ones.
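A rough sketch of BBSD, under stated assumptions: the "black box" is a hypothetical linear classifier whose softmax outputs serve as the reduced representation, and the two-sample test is an MMD permutation test with an RBF kernel whose bandwidth is set by the median heuristic (a common choice, assumed here rather than taken from the paper). MMD is the multivariate option; per-class KS tests with a Bonferroni correction on the same softmax outputs would be the univariate alternative.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def bbsd_representation(x, weights, bias):
    """Hypothetical black-box classifier (a linear model here): its predicted
    class probabilities are the reduced representation."""
    return softmax(x @ weights + bias)


def sq_dists(a, b):
    """Pairwise squared Euclidean distances between rows of a and b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)


def mmd2(a, b, gamma):
    """Biased estimate of squared MMD with an RBF kernel exp(-gamma * d^2)."""
    def k(u, v):
        return np.exp(-gamma * sq_dists(u, v))
    return k(a, a).mean() + k(b, b).mean() - 2.0 * k(a, b).mean()


def mmd_permutation_test(a, b, n_perm=200, seed=0):
    """Returns (observed MMD^2, permutation p-value)."""
    rng = np.random.default_rng(seed)
    pooled = np.vstack([a, b])
    d2 = sq_dists(pooled, pooled)
    gamma = 1.0 / np.median(d2[d2 > 0])  # median heuristic (an assumption)
    observed = mmd2(a, b, gamma)
    n = len(a)
    exceed = 0
    for _ in range(n_perm):
        # Re-split the pooled samples at random to simulate "no shift".
        idx = rng.permutation(len(pooled))
        if mmd2(pooled[idx[:n]], pooled[idx[n:]], gamma) >= observed:
            exceed += 1
    return observed, (exceed + 1) / (n_perm + 1)
```

The appeal of BBSD is that the representation has only as many dimensions as there are classes, so the test stays cheap regardless of how large the raw inputs are.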

QUOTES

  • On approaches that treat shift detection as a form of anomaly detection: "Several recent papers have proposed outlier detection mechanisms dubbing the task out-of distribution (OOD) sample detection"

  • "[shifts in] target data exhibiting only 10% anomalous samples are hard to detect, suggesting that this setting might be better addressed via outlier detection"


NOTES

  • Two-sample tests compare training data against current (inference-time) data

  • Shifts are assumed to be non-adversarial, i.e. caused by natural changes in distributions rather than by the deliberate crafting of pathological examples

  • They test their methods on many kinds of artificially induced data drift, such as:

    • Making all samples belong to a single class
    • Adding noise to samples of one class, or of both classes
    • Image-specific noise
  • The authors mention two types of dataset drift:

    • Data shift/drift, where the distribution of features across all classes changes over time
    • Label shift/drift, where the label distribution \(p(y)\) changes over time, but \(p(x|y)\) (distribution of features within the same class) doesn't
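To make the two types of drift concrete, here is a hypothetical sketch of two of the artificial shifts listed above, assuming integer class labels stored in a NumPy array. Keeping only the samples of a single class changes \(p(y)\) but leaves \(p(x|y)\) untouched (a pure label shift), whereas adding Gaussian noise to one class leaves \(p(y)\) untouched but changes the feature distribution of that class. Function names and the noise scale are illustrative, not the paper's code.

```python
import numpy as np


def single_class_shift(x, y, cls):
    """Keep only samples of class `cls`: p(y) collapses onto one class,
    while p(x|y=cls) is untouched (a pure label shift)."""
    mask = y == cls
    return x[mask], y[mask]


def noise_shift(x, y, classes, std, seed=0):
    """Add Gaussian noise to the features of samples whose label is in
    `classes`: p(y) is unchanged, but p(x|y) changes for those classes."""
    rng = np.random.default_rng(seed)
    x_new = x.copy()
    mask = np.isin(y, classes)
    x_new[mask] += rng.normal(scale=std, size=x_new[mask].shape)
    return x_new, y.copy()


def label_distribution(y, n_classes):
    """Empirical p(y) as a vector of class frequencies."""
    return np.bincount(y, minlength=n_classes) / len(y)
```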

MY 2¢

  • Obtaining ground truth (manual labels) for the samples selected by the domain classifier is not very realistic in production systems

  • If we can retrain models periodically, data drift becomes less of a concern


References
