Robust statistics
Robust statistics studies the problem of estimating parameters from unreliable empirical data. Typically, suppose that a fraction of the training dataset is compromised. Can we design algorithms that nevertheless learn adequately from such a partially compromised dataset?
This question has arguably become crucial, as large-scale algorithms learn from users' data. Clearly, if an algorithm is used by thousands, millions or billions of users, some of the data will likely be corrupted, whether because of bugs Boddy16 or because some users maliciously try to exploit or attack the algorithm. The latter case is known as a poisoning attack.
In very high dimensions, designing efficient algorithms with strong robustness guarantees remains an open problem, though there are fascinating recent results, both for classical statistical tasks DiakonikolasKane19 and for neural networks BMGS17.
Example of the median
Suppose the data are real numbers, and we want to estimate the mean of the true data (which we shall call the inliers). Note that the naive empirical mean would be a bad estimate here, as a single malicious user could completely upset it. In fact, by choosing its input data (called outliers) adequately, the malicious user can make the empirical mean equal to whatever value it wants.
It turns out that the median of the dataset is a robust estimate. Indeed, even if 45% of the data are outliers, the median is still a quantile of the inliers, which should not be too far from the actual mean. The median is said to have a statistical breakdown point of 0.5 RousseeuwLeroy05. No statistical method can achieve a better breakdown point, but other methods also achieve a 0.5 breakdown point, such as the trimmed mean (remove sufficiently many extreme values on both sides and average the rest).
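The contrast can be checked in a few lines. Below is a minimal sketch with hypothetical parameters (550 standard normal inliers, 450 outliers placed at 10^6): the empirical mean is dragged arbitrarily far, while the median and the trimmed mean stay near the inliers.

```python
import random
import statistics

random.seed(0)

# 550 inliers drawn from N(0, 1), plus 450 adversarial outliers placed at 10^6
inliers = [random.gauss(0.0, 1.0) for _ in range(550)]
outliers = [1e6] * 450
data = inliers + outliers

def trimmed_mean(xs, trim_fraction):
    """Drop the trim_fraction most extreme values on each side and average the rest."""
    k = int(len(xs) * trim_fraction)
    xs = sorted(xs)
    return statistics.fmean(xs[k:len(xs) - k])

print(statistics.fmean(data))     # ruined: pulled to roughly 450000 by the outliers
print(statistics.median(data))    # still a quantile of the inliers, hence bounded
print(trimmed_mean(data, 0.45))   # also bounded, close to the inliers' range
```

Note that with 45% one-sided outliers the median and trimmed mean land on a fairly extreme quantile of the inliers, so they are off by a constant, not by the 10^6 that wrecks the mean.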
Another way to quantify robustness is to derive a high-probability upper bound on the distance between the empirical median and the mean μ of the true distribution of inliers. Call ε the fraction of outliers. It turns out that, assuming the true distribution is a normal distribution N(μ,1), given n = Ω((d+log(1/τ))/ε²) data points, we can guarantee |median−μ| ≤ O(ε) with probability 1−τ. This asymptotic bound is also the best possible Huber92.
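The O(ε) scaling can be observed empirically. The following sketch uses hypothetical parameters (μ = 3, n = 20000, outliers all placed far above the inliers, the worst one-sided case):

```python
import random
import statistics

random.seed(1)

def median_error(eps, mu=3.0, n=20000):
    """Empirical |median - mu| when a fraction eps of the data are outliers,
    here all placed far above the inliers to bias the median as much as possible."""
    k = int(eps * n)
    data = [random.gauss(mu, 1.0) for _ in range(n - k)] + [1e3] * k
    return abs(statistics.median(data) - mu)

for eps in (0.05, 0.1, 0.2):
    # the error grows with eps, staying of order eps in this regime
    print(eps, median_error(eps))
```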
Poisoning models
The above guarantee holds under arguably the strongest poisoning model. This is one where an adversary gets to read the full dataset before we do, is able to erase a fraction ε of the data, and can replace them with any imaginable inputs. The dataset is then analyzed by our (robust) statistics algorithm.
A weaker, but still widespread, model is one where a fraction 1−ε of the data comes from the true distribution, while the remaining fraction ε is chosen by the adversary BMGS17.
Other models include an adversary with only erasing power, or an adversary that must choose its outliers without knowledge of the values of the inliers. Evidently, any guarantee for such weaker poisoning models will also hold for stronger poisoning models DiakonikolasKane19.
High-dimensional robust mean estimates
DiakonikolasKane19 is a great resource for this problem. Below we summarize briefly the key ideas and results.
Unfortunately, results that hold in small dimensions generalize poorly to high dimensions, either because of weak robustness guarantees or because of computational slowness. Typically, the coordinate-wise median and the geometric median both yield Ω(ε√d) error, even in the limit of infinite datasets and assuming normality for the inliers. This is very bad, as today's neural networks often have d ~ 10⁶, if not 10⁹ or 10¹², parameters.
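The Ω(ε√d) behavior can be sketched with a small stdlib-only simulation (all parameters here, such as outliers placed at +10 in every coordinate, ε = 0.1 and n = 2000, are hypothetical choices for illustration): the per-coordinate bias of the coordinate-wise median is O(ε), so its ℓ2 error grows like ε√d.

```python
import random
import statistics

random.seed(2)

def coordwise_median_error(d, eps=0.1, n=2000):
    """l2 error of the coordinate-wise median when a fraction eps of the points
    are outliers shifted in every coordinate: the per-coordinate bias is O(eps),
    so the total l2 error grows like eps * sqrt(d)."""
    k = int(eps * n)
    # inliers ~ N(0, I_d); outliers all placed at +10 in every coordinate
    points = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n - k)]
    points += [[10.0] * d for _ in range(k)]
    median = [statistics.median(p[j] for p in points) for j in range(d)]
    return sum(m * m for m in median) ** 0.5

for d in (1, 25, 100):
    print(d, coordwise_median_error(d))  # error grows roughly like sqrt(d)
```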
On the other hand, assuming the true distribution is a spherical normal distribution N(0,I), Tukey proposed another approach based on identifying the directions of largest variance, since these are likely to be the "attack lines" of the adversary Tukey75. Tukey's median yields O(ε) error with high probability 1−τ for n = Ω((d+log(1/τ))/ε²) data points. Unfortunately, Tukey's median is NP-hard to compute, and known algorithms typically take time exponential in d.
But its ideas can be turned into a polynomial-time algorithm for robust mean estimation. The trick is to identify the worst-case "attack line" by computing the largest eigenvalue of the empirical covariance matrix, and to remove extremal points along this direction to reduce the variance. DKKLMS16 DKKLMS17 show that, for n = Ω(d/ε²), this yields O(ε√log(1/ε)) error with high probability for sub-Gaussian inliers, and O(σ√ε) error for inliers whose true covariance matrix Σ is such that σ²I−Σ is positive semidefinite.
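A minimal sketch of this filtering idea follows, not the algorithm of DKKLMS16/DKKLMS17 itself but its core loop: each round finds the direction of largest empirical variance via power iteration and discards the points most extreme along it. The round count, trim fraction, and test data are hypothetical illustrative choices.

```python
import random

random.seed(3)

def top_eigvec(cov, iters=200):
    """Power iteration for the leading eigenvector of a small symmetric matrix."""
    d = len(cov)
    v = [1.0] * d
    for _ in range(iters):
        w = [sum(cov[i][j] * v[j] for j in range(d)) for i in range(d)]
        norm = sum(x * x for x in w) ** 0.5
        v = [x / norm for x in w]
    return v

def filtered_mean(points, rounds=5, trim=0.05):
    """Each round: find the direction of largest empirical variance (the likely
    "attack line") and discard the points most extreme along it."""
    pts = list(points)
    d = len(pts[0])
    for _ in range(rounds):
        n = len(pts)
        mean = [sum(p[j] for p in pts) / n for j in range(d)]
        cov = [[sum((p[i] - mean[i]) * (p[j] - mean[j]) for p in pts) / n
                for j in range(d)] for i in range(d)]
        v = top_eigvec(cov)
        deviation = lambda p: abs(sum((p[j] - mean[j]) * v[j] for j in range(d)))
        pts = sorted(pts, key=deviation)[: int(n * (1 - trim))]
    n = len(pts)
    return [sum(p[j] for p in pts) / n for j in range(d)]

# 900 inliers ~ N(0, I_2), 100 outliers at (5, 5): the outliers inflate the
# variance along the (1, 1) direction, which is exactly what the filter detects
pts = [[random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)] for _ in range(900)]
pts += [[5.0, 5.0]] * 100
print(filtered_mean(pts))  # close to the true mean (0, 0)
```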
The asymptotically optimal bound O(ε) has been achieved by a more sophisticated polynomial-time filtering algorithm DKKLM+18 for Gaussian distributions in the additive contamination model, while DKS17 showed that no polynomial-time algorithm can achieve better than O(ε√log(1/ε)) in the Statistical Query model with strong contamination.
There are however numerous open questions left. First, the covariance matrix is assumed to be known. This is critical, as part of the algorithm requires rescaling the coordinates so that the rescaled covariance is the identity matrix. Can we design an efficient robust mean estimator for bounded but unknown covariance? Second, the algorithms cited above are polynomial-time with respect to d. For large neural networks this is impractical, especially if d > 10⁹. Could there be (quasi-)linear-time robust mean estimation algorithms? Finally, it is not clear how relevant all this is in practice for stochastic gradient descent (SGD), today's main machine learning algorithm. Indeed, SGD mostly relies on averaging gradients over batches of limited size, say n ~ 10³ data points, while the theorems cited above require n = Ω(d/ε²).
Robust statistics for neural networks
It is also noteworthy that, in practice, the main application of robust statistics (at least the one relevant to AI ethics) seems to be the aggregation of stochastic gradients for neural networks. This aggregation is often carried out with batches whose size n is significantly smaller than d. Is there a gain in using algorithms more complex than the coordinate-wise mean?
BMGS17 proposed Krum and Multi-Krum, aggregation algorithms for this setting that have weaker robustness guarantees but are more efficient. Is it possible to improve upon them?
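The Krum selection rule can be sketched as follows. This is a simplified illustration, not the reference implementation; the gradient values and the choice f = 2 are hypothetical. Given n gradients of which up to f are Byzantine, each gradient is scored by the sum of squared distances to its n − f − 2 nearest neighbours, and the lowest-scoring gradient is returned.

```python
import random

random.seed(4)

def krum(grads, f):
    """Sketch of the Krum rule BMGS17: score each gradient by the sum of squared
    distances to its n - f - 2 nearest neighbours; return the lowest-scoring one."""
    n = len(grads)
    def sqdist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    best_score, best = None, None
    for i, g in enumerate(grads):
        dists = sorted(sqdist(g, h) for j, h in enumerate(grads) if j != i)
        score = sum(dists[: n - f - 2])
        if best_score is None or score < best_score:
            best_score, best = score, g
    return best

# 8 honest stochastic gradients near (1, 1), 2 Byzantine ones far away:
# the Byzantine gradients have huge distances to their honest neighbours,
# so they get large scores and are never selected
grads = [[random.gauss(1.0, 0.1), random.gauss(1.0, 0.1)] for _ in range(8)]
grads += [[100.0, -100.0]] * 2
print(krum(grads, f=2))  # one of the honest gradients
```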