Wednesday, January 6, 2016

Attention: Can we formalize it?

In statistics, the bias-variance tradeoff is a core concept. Roughly speaking, bias is how far even the best hypothesis in your hypothesis class falls short of reality, whereas variance is how much additional performance degradation is introduced by having only finite training data with which to choose a hypothesis. Abu-Mostafa has a nice lecture on this.
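For concreteness, here is the standard squared-loss version of that decomposition, ignoring label noise (notation roughly follows the lecture: $f$ is the target, $g^{(\mathcal{D})}$ is the hypothesis learned from training set $\mathcal{D}$, and $\bar{g}(x) = \mathbb{E}_{\mathcal{D}} [ g^{(\mathcal{D})}(x) ]$ is the average learned hypothesis): \[
\mathbb{E}_{\mathcal{D}}\left[ E_{\mathrm{out}}\left( g^{(\mathcal{D})} \right) \right]
= \underbrace{\mathbb{E}_{x}\left[ \left( \bar{g}(x) - f(x) \right)^2 \right]}_{\text{bias}}
+ \underbrace{\mathbb{E}_{x}\left[ \mathbb{E}_{\mathcal{D}}\left[ \left( g^{(\mathcal{D})}(x) - \bar{g}(x) \right)^2 \right] \right]}_{\text{variance}}.
\]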

Last century, both data and compute were relatively scarce, so models with high bias but low variance (and low computational overhead associated with optimizing over the hypothesis class) were popular: things like generalized linear models. Data became less scarce when media went digital, and old ideas with low bias, high variance, and modest computational overhead were revisited: things like n-gram language modeling. The GLM continued to do well in this era because bias-variance tradeoffs could be exploited via feature engineering, e.g., in advertising response modeling. Old ideas with low bias and high variance but prohibitive computational overhead remained essentially irrelevant (I'm looking at you, k-nearest-neighbors).

If you were ahead of the curve (as I was not!), you could see that the continued relaxation of both data and compute constraints favored lower-bias models. However, “easy” decreases in bias that increase variance are still not viable: we are, unfortunately, still data constrained relative to the complexity of the targets we are trying to model (“AI”). So the real game is reducing bias without picking up too much variance. A Bayesian might say “good generic priors”. Yoshua Bengio realized this quite some time ago and expressed the view in one of my all-time favorite papers. Section 3.1, in particular, is pure gold: in it, the authors lay out several key generic priors, e.g., smoothness, hierarchical organization, multi-task learning, low intrinsic dimension, multiscale structure, sparsity, etc.

The closest thing to attention in the list from that great paper is sparsity, but I like the term attention better: the important thing for me is dynamic per-example sparsity that is estimated from the “complete” example (where “complete” is perhaps itself mitigated via hierarchical attention). Attention models have been crushing it lately, e.g., in vision and speech; I also suspect one important reason the deep convolutional architecture is so good at vision is that repeated nonlinear pooling operations act like an attentional mechanism, cf. figure 2 of Simonyan et al. Attention has been crushing it so much that there has to be a way to show its superiority mathematically.
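As a toy illustration of that pooling remark (my own example, taking max pooling as a representative nonlinear pooling operation): a softmax-weighted average over a pooling window recovers max pooling as the softmax temperature goes to zero, i.e., max pooling is hard attention over the window.

```python
import numpy as np

def soft_pool(a, temperature):
    # softmax attention weights over the pooled values (stabilized)
    z = np.exp((a - a.max()) / temperature)
    z /= z.sum()
    return z @ a                 # attention-weighted average of the window

a = np.array([0.1, 0.7, 0.3])
print(soft_pool(a, 10.0))        # high temperature: close to the plain average
print(soft_pool(a, 1e-3))        # low temperature: essentially max(a) = 0.7
```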

So here's my guess: attention is a good generic prior, and we can formalize this. Unfortunately, theory is not my strong suit, but I think the following might be amenable to analysis. First the setting: the task is binary classification, and the features are matrices $X \in \mathbb{R}^{d \times k}$. The attentional model consists of two vectors $w \in \mathbb{R}^d$ and $v \in \mathbb{R}^d$. The attentional model estimates via \[
\begin{aligned}
\hat y &= \mathrm{sgn\;} \left( w^\top X z \right), \\
z_i &= \frac{\exp \left( v^\top X_i \right)}{\sum_{j=1}^k \exp \left( v^\top X_j \right)},
\end{aligned}
\] i.e., $z \in \Delta^k$ is a softmax over the columns of $X$ that selects a weight for each column, and then $w$ predicts the label linearly from the reduced input $X z \in \mathbb{R}^d$. If hard attention is more your thing, I'm ok with forcing $z$ to be a vertex of the simplex.
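For concreteness, here is a minimal numpy sketch of this predictor (the function name, the `hard` flag, and the use of 1-D arrays are my own choices):

```python
import numpy as np

def attend_predict(w, v, X, hard=False):
    """Attentional predictor: w and v have shape (d,); X has shape (d, k)."""
    scores = v @ X                           # one attention score per column of X
    if hard:
        z = np.zeros_like(scores)
        z[np.argmax(scores)] = 1.0           # hard attention: a vertex of the simplex
    else:
        z = np.exp(scores - scores.max())    # stabilized softmax over the k columns
        z /= z.sum()
    return np.sign(w @ (X @ z))              # linear prediction on the reduced input X z
```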

The non-attentional model consists of a vector $u \in \mathbb{R}^{k d}$ and estimates via \[
\begin{aligned}
\hat y &= \mathrm{sgn\;} \left( u^\top \mathrm{vec\;} (X) \right),
\end{aligned}
\] i.e., it ignores the column structure in $X$, flattens the matrix, and then estimates using all the features.
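The corresponding sketch of the flattened baseline, under the same assumptions:

```python
import numpy as np

def flat_predict(u, X):
    """Non-attentional predictor: u has shape (k * d,); X has shape (d, k)."""
    return np.sign(u @ X.flatten(order='F'))   # vec(X) stacks columns, hence Fortran order
```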

Naive parameter counting (which in general is meaningless) suggests the attentional model (with $2 d$ parameters) is less complicated than the non-attentional model (with $k d$ parameters). However, I'd like to make some more formal statements regarding the bias and variance. In particular, my gut says there should be conditions under which the variance is radically reduced, because the final prediction is invariant to things that are not attended to.
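A toy check of that invariance intuition (not an argument), reusing attend_predict from the sketch above: with hard attention, perturbing a column other than the attended one, in a direction orthogonal to $v$ so that the attention scores are untouched, cannot change the prediction.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 5, 7
w, v = rng.normal(size=d), rng.normal(size=d)
X = rng.normal(size=(d, k))

attended = int(np.argmax(v @ X))            # column picked by hard attention
other = (attended + 1) % k                  # some column we do not attend to
delta = rng.normal(size=d)
delta -= (delta @ v) / (v @ v) * v          # orthogonal to v: attention scores unchanged

X_perturbed = X.copy()
X_perturbed[:, other] += delta              # arbitrary change to a non-attended column

assert attend_predict(w, v, X, hard=True) == attend_predict(w, v, X_perturbed, hard=True)
```

For soft attention the invariance is only approximate, since columns with small weight still contribute a little, which is presumably where the interesting part of the analysis lives.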

If anybody has any ideas on how to make progress, feel free to share (publicly right here is fine, or contact me directly if you feel uncomfortable exposing your sausage manufacturing process). Also feel free to enlighten me if the literature has already addressed these questions.

1 comment:

  1. We are aware of an encouraging result for the case of static attention where the "parts" are features and there is no competition among them (i.e., the attention vector z does not have to sum to 1). This is the same as learning a sparse model, and Andrew Ng's analysis of L1 regularization ( http://ai.stanford.edu/~ang/papers/icml04-l1l2.pdf ) shows that it can exponentially reduce sample complexity (from O(number of features) to O(log(number of features))). At the same time, rotationally invariant methods (cf. the paper above) have to use O(number of features) examples. When I read the paper, a long time ago, I did not find the analysis very enlightening, but perhaps the ideas in there are headed in the right direction.
