Tuesday, January 12, 2016

Attention: More Musings

The attention model I posed in my last post is still reasonable, but the comparison model is not. (These revelations are the fallout of a fun conversation between Nikos, Sham Kakade, and myself. Sham recently took a faculty position at the University of Washington, which is my neck of the woods.)

As a reminder, the attention model is a binary classifier which takes matrix-valued inputs $X \in \mathbb{R}^{d \times k}$ with $d$ features and $k$ columns $X_1, \ldots, X_k$, weights (“attends” to) some columns more than others via parameter $v \in \mathbb{R}^d$, and then predicts with parameter $u \in \mathbb{R}^d$, \[
\begin{aligned}
\hat y &= \mathrm{sgn \;} \left( u^\top X z \right), \\
z_i &= \frac{\exp \left( v^\top X_i \right)}{\sum_j \exp \left( v^\top X_j \right) }.
\end{aligned}
\] I changed the notation slightly from my last post ($w \rightarrow u$), for reasons that will become clear shortly. In the previous post, the comparison model was an unconstrained linear predictor on all columns, \[
\begin{aligned}
\hat y &= \mathrm{sgn \;} \left( w^\top \mathrm{vec\,} (X) \right),
\end{aligned}
\] with $w \in \mathbb{R}^{d k}$. But this is not a good comparison model, because the attention model is nonlinear in ways this model cannot capture: apples and oranges, really.
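To make the two models concrete, here is a minimal NumPy sketch of each prediction rule. The function names `attention_predict` and `linear_predict` are hypothetical; subtracting the largest logit before exponentiating is a standard numerical-stability trick (not from the post) that leaves the softmax weights $z$ unchanged; and $\mathrm{vec}$ is assumed to stack the columns of $X$ (column-major order).

```python
import numpy as np


def attention_predict(X, u, v):
    """Softmax-attention classifier: sgn(u^T X z) with z_i = softmax(v^T X)_i.

    X has shape (d, k): d features, k columns. u and v have shape (d,).
    Returns the predicted sign and the attention weights z.
    """
    logits = v @ X                       # (k,) attention logits v^T X_i
    logits = logits - logits.max()       # stability shift; softmax is unchanged
    z = np.exp(logits) / np.exp(logits).sum()  # (k,) positive weights summing to 1
    return np.sign(u @ X @ z), z


def linear_predict(X, w):
    """Comparison model: sgn(w^T vec(X)) with w of length d * k."""
    # order="F" flattens column by column, matching the usual vec() convention
    return np.sign(w @ X.reshape(-1, order="F"))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, k = 5, 7
    X = rng.normal(size=(d, k))
    u, v = rng.normal(size=d), rng.normal(size=d)
    y_hat, z = attention_predict(X, u, v)
    print(y_hat, z.round(3))
```

Note that setting $v = 0$ gives uniform weights $z_i = 1/k$, so the attention model then predicts with the average column; the nonlinearity comes entirely from how $z$ depends on $X$ through $v$.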

This is easier to see with linear attention and a regression task. A linear attention model weights each column $X_i$ directly by $v^\top X_i$ rather than through a softmax; ideally, $v^\top X_i$ is close to zero for “background” or “irrelevant” columns and appreciably nonzero for “foreground” or “relevant” ones. In that case, \[
\begin{aligned}
\hat y &= u^\top X (v^\top X)^\top = \mathrm{tr} \left( X X^\top v u^\top \right),
\end{aligned}
\] (using the cyclic property of the trace), which looks like a rank-1 assumption on a full model, \[
\begin{aligned}
\hat y &= \mathrm{tr} \left( X X^\top W \right) = \sum_i \left( X X^\top W \right)_{ii} = \sum_{ij} \left( X X^\top \right)_{ij} W_{ji} \\
&= \sum_{ijk} X_{ik} X_{jk} W_{ji} = \sum_{ijk} X_{ik} W_{ij} X_{jk}
\end{aligned}
\] where $W \in \mathbb{R}^{d \times d}$ is, w.l.o.g., symmetric, since $X X^\top$ is symmetric and so only the symmetric part of $W$ affects $\hat y$. (Now hopefully the notation change makes sense: the letters $U$ and $V$ are often used for the left and right singular spaces of the SVD.)
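The identities above are easy to check numerically. This is a small NumPy sketch with hypothetical function names: the linear-attention prediction $u^\top X (v^\top X)^\top$ should agree with $\mathrm{tr}(X X^\top W)$ and with the triple sum $\sum_{ijk} X_{ik} W_{ij} X_{jk}$ when $W = v u^\top$.

```python
import numpy as np


def linear_attention_predict(X, u, v):
    """Linear-attention regression: y_hat = u^T X (v^T X)^T."""
    return u @ X @ (v @ X)


def full_model_predict(X, W):
    """Full quadratic model: y_hat = tr(X X^T W)."""
    return np.trace(X @ X.T @ W)


def full_model_predict_einsum(X, W):
    # The same quantity as the explicit sum over features i, j and columns k.
    return np.einsum("ik,ij,jk->", X, W, X)


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    d, k = 5, 8
    X = rng.normal(size=(d, k))
    u, v = rng.normal(size=d), rng.normal(size=d)
    W = np.outer(v, u)              # the rank-1 matrix v u^T
    W_sym = (W + W.T) / 2           # its symmetric part
    print(linear_attention_predict(X, u, v),
          full_model_predict(X, W),
          full_model_predict_einsum(X, W),
          full_model_predict(X, W_sym))
```

All four printed values should coincide up to floating-point rounding; the last one exercises the claim that replacing $W$ by its symmetric part leaves the prediction unchanged.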

The symmetry of $W$ confuses me, because it suggests $u$ and $v$ are the same (but then the prediction is nonnegative?), so clearly more thinking is required. However, this gives a bit of insight, and perhaps it leads to some known results about sample complexity.
