Wednesday, July 13, 2011

More Importance-Aware Updates

The Importance-Aware Update

In importance-weighted regression or classification, the per-example loss is a scaled canonical loss, \[
L (w) = E_{x \sim D} \left[ \epsilon (x) L (x; w) \right],
\] which is optimized via stochastic gradient descent, \[
w_{n+1} = w_n - \eta \epsilon (x_n) \frac{\partial L (x_n; w)}{\partial w}\biggr|_{w = w_n},
\] where $\eta$ is the learning rate. In practice, when the product $\eta \epsilon (x_n)$ gets too large, this update becomes unstable. "Unstable" here means something much worse than failing to reduce the aggregate loss: the update can easily fail to reduce the loss even on this example, because it overshoots the target, i.e., it "over-runs the label." I think this is why logistic loss is so popular, since it intrinsically guards against this condition; but when using hinge loss, squared loss, or quantile loss, such instability is not difficult to encounter empirically.
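Here is a small numerical illustration of over-running the label. It assumes squared loss $l(p) = \frac{1}{2}(p - y)^2$ and a single made-up two-dimensional example. For this loss a plain step moves the residual from $p - y$ to $(p - y)(1 - \eta \epsilon\, x^\top x)$. The prediction therefore overshoots once $\eta \epsilon\, x^\top x > 1$, and the loss on this example increases once $\eta \epsilon\, x^\top x > 2$. Function names are illustrative.

```python
import numpy as np

def plain_sgd_step(w, x, y, eta, importance):
    """One importance-weighted SGD step on squared loss l(p) = (p - y)^2 / 2."""
    p = w @ x
    # gradient of l(w.x) with respect to w is (p - y) x
    return w - eta * importance * (p - y) * x

def example_loss(w, x, y):
    return 0.5 * (w @ x - y) ** 2

x = np.array([1.0, 1.0])   # x.x = 2
y = 1.0
w = np.zeros(2)
eta = 0.1
for importance in [1.0, 5.0, 10.0, 20.0]:
    w_new = plain_sgd_step(w, x, y, eta, importance)
    print(f"importance={importance:4.0f}  prediction={w_new @ x:4.1f}  "
          f"loss {example_loss(w, x, y):.2f} -> {example_loss(w_new, x, y):.2f}")
```

With $\eta\, x^\top x = 0.2$, importance 5 lands exactly on the label. Importance 10 overshoots to a prediction of 2 and leaves this example's loss unchanged. Importance 20 overshoots to 4 and increases it.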

The insight of Karampatziakis and Langford is to consider the above update as a first-order Euler integrator for the equation \[
w^\prime (h) = -\eta \epsilon (x_n) \frac{\partial L (x_n; w)}{\partial w}\biggr|_{w = w (h)},
\] which in the case of the GLM has an analytical solution for loss functions commonly used in practice. This is because for the GLM the loss function takes the form $L (x_n; w) = l (w^\top x_n)$, whose gradient is always a scaled version of $x_n$. Therefore there is a solution of the form \[
w (h) = w - s (h) x_n,
\] where \[
s^\prime (h) = \eta \epsilon (x_n) \frac{\partial l}{\partial p}\biggr|_{p = (w - s (h) x_n)^\top x_n},
\] is a one-dimensional ODE with analytical solutions for many common choices of $l$ (e.g., hinge, squared, logistic, etc.).
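As a concrete instance, take squared loss $l(p) = \frac{1}{2}(p - y)^2$. The residual $r(h) = p(h) - y$ obeys $r^\prime(h) = -\eta \epsilon\, x^\top x\, r(h)$, which gives $s(h) = \frac{w^\top x - y}{x^\top x}\left(1 - e^{-\eta \epsilon\, x^\top x\, h}\right)$. The sketch below evaluates this at $h = 1$ and checks it against brute-force Euler integration of $w^\prime(h)$. The formula is derived here from the ODE above under this particular scaling of the loss. Function names are hypothetical.

```python
import numpy as np

def squared_loss_iaw_step(w, x, y, eta, importance, h=1.0):
    """Importance-aware update for squared loss l(p) = (p - y)^2 / 2.

    Closed-form solution of s'(h) = eta * importance * (p(h) - y), where
    p(h) = (w - s(h) x) . x and s(0) = 0; returns w - s(h) x.
    """
    xx = x @ x
    s = (w @ x - y) / xx * (1.0 - np.exp(-eta * importance * xx * h))
    return w - s * x

def euler_integrate(w, x, y, eta, importance, steps=20000):
    """Brute-force check: many tiny SGD steps integrating w'(h) over h in [0, 1]."""
    dh = 1.0 / steps
    for _ in range(steps):
        w = w - dh * eta * importance * (w @ x - y) * x
    return w

w0, x = np.zeros(2), np.array([1.0, 1.0])
w_new = squared_loss_iaw_step(w0, x, 1.0, eta=0.1, importance=20.0)
print(w_new @ x)   # approaches the label 1 from below; never overshoots
```

As $\epsilon \to \infty$, the step approaches $s = (w^\top x - y)/x^\top x$, which places the prediction exactly on the label. For small $\eta \epsilon$, it matches the plain SGD step to first order.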

Although originally motivated by active learning (which generates a wide dynamic range of importance weights), the importance-aware update is useful even for non-importance weighted problems, because it provides robustness with respect to the specification of the learning rate. I've become addicted to the robustness that the importance-aware update empirically exhibits on real data, so when I encounter a new situation I'm always looking for the importance-aware update. I did some investigation of this for dyadic models previously, but there are two other situations where this has come up. These two situations involve loss functions being evaluated on more than one input at a time because of a reduction.

Ranking via Pairwise Classification

For optimizing AUC, there is a reduction to pairwise classification of mixed-label instances. When training with SGD, this means pairs of examples $(x_+, x_-)$ are presented to the classifier, where $x_+$ should be ranked above $x_-$. If the desired result is a linear scoring function $f (x; w) = w^\top x$, then the objective is to maximize $1_{w^\top (x_+ - x_-) > 0}$ (i.e., to order the pair correctly), which is relaxed to the following per-example-pair convex loss, \[
\begin{aligned}
L (x_+, x_-; w) &=\epsilon (x_+, x_-) l \left( w^\top (x_+ - x_-), 1 \right), \\
w^\prime (h) &= -\eta \frac{\partial L (x_+, x_-; w)}{\partial w}\biggr|_{w = w (h)} \\
&= -\eta \epsilon (x_+, x_-) \frac{\partial l}{\partial p}\biggr|_{p = w (h)^\top (x_+ - x_-)} (x_+ - x_-).
\end{aligned}
\] Once again all the gradients point in the same direction, so look for a solution of the form $w (h) = w - s (h) (x_+ - x_-)$, \[
\begin{aligned}
s^\prime (h) &= \eta \epsilon (x_+, x_-) \frac{\partial l}{\partial p}\biggr|_{p = (w - s (h) (x_+ - x_-))^\top (x_+ - x_-)}.
\end{aligned}
\] Perhaps unsurprisingly, the importance-aware update is the same as for classification, except that it is computed using the difference vector. In other words,
  1. Receive example pair $(x_+, x_-)$ with importance weight $\epsilon (x_+, x_-)$.
  2. Compute the standard importance-aware update $s (h)$ using $\Delta x = x_+ - x_-$ as the input and $y = 1$ as the target.
    • If it seems strange to you to always have a target of 1, consider that the same inputs might appear transposed at a later time.
    • Note $\Delta x$ does not contain the constant feature.
  3. Update weights using $s (h) \Delta x$.
One important note: this procedure is not the same as sequentially training $x_+$ with label 1 followed by $x_-$ with label 0. That procedure with squared loss corresponds to learning a regressor on a balanced training set, which, while consistent, has a worse regret bound than reduction to pairwise classification. You cannot perform the above update with Vowpal Wabbit as a black box (because of the constant feature): you have to open it up and modify the code. (Perhaps a good patch to Vowpal Wabbit would be a command-line flag that suppresses the constant feature.)
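The three steps above can be sketched with hinge loss, $l(p, y) = \max(0, 1 - y p)$ with $y \in \{-1, +1\}$. Solving the ODE for hinge loss, $p(h)$ moves toward the margin at a constant rate $\eta \epsilon \lVert \Delta x \rVert^2$ and stops exactly at $y p = 1$. This gives $s = -y \min\left(\eta \epsilon, (1 - y p)/\lVert \Delta x \rVert^2\right)$ when $y p < 1$, and $s = 0$ otherwise. The choice of hinge loss and the function names are assumptions for illustration.

```python
import numpy as np

def hinge_iaw_scale(p, xx, y, eta, importance):
    """Importance-aware step size s for hinge loss max(0, 1 - y p), y in {-1, +1}.

    Solves s'(h) = eta * importance * dl/dp at p(h) = p - s(h) * xx over h in [0, 1]:
    p moves toward the margin at constant speed and stops exactly at y p = 1.
    """
    if y * p >= 1.0 or xx == 0.0:
        return 0.0
    return -y * min(eta * importance, (1.0 - y * p) / xx)

def pairwise_rank_update(w, x_pos, x_neg, eta, importance):
    """One importance-aware SGD step for the pair (x_pos ranked above x_neg).

    Uses the difference vector as the input and y = +1 as the target. The
    difference vector carries no constant feature: a shared bias cancels.
    """
    dx = x_pos - x_neg
    s = hinge_iaw_scale(w @ dx, dx @ dx, 1.0, eta, importance)
    return w - s * dx

x_pos = np.array([1.0, 0.0, 2.0])
x_neg = np.array([0.0, 1.0, 0.0])
w = pairwise_rank_update(np.zeros(3), x_pos, x_neg, eta=0.1, importance=100.0)
print(w @ (x_pos - x_neg))   # the margin is reached, not overshot
```

Even with a huge importance weight, the pair's score difference stops at the margin of 1 instead of over-running it.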

Buffoni et al. discuss reducing DCG and NDCG to a loss function of the same form as above (equation 6 in that paper), so doing importance-aware SGD on their loss would result in a similar update.

Scoring Filter Tree

The scoring filter tree is a reduction from cost-sensitive multiclass classification to importance-weighted binary classification. Previously I've discussed it conceptually and also provided an implementation on top of Vowpal Wabbit via a perl-script wrapper. (I've subsequently written a C++ implementation, which is roughly 10x faster.)

A (somewhat) relaxed (but still not convex) objective function for training is \[
\begin{aligned}
L (x; w) &= \sum_{\nu \in T} \left| c_{\lambda (\nu; x, w)} - c_{\rho (\nu; x, w)} \right| l \left(w^\top (x_{\lambda (\nu; x, w)} - x_{\rho (\nu; x, w)}), 1_{c_{\lambda (\nu; x, w)} > c_{\rho (\nu; x, w)}} \right).
\end{aligned}
\] Here $\lambda (\nu; x, w)$ and $\rho (\nu; x, w)$ are the identities of the left and right inputs to node $\nu$. That looks really nasty: in general each input vector occurs multiple times in the tree, and the propagation of inputs to nodes also depends upon $w$ in a non-convex way (i.e., the functions $\lambda$ and $\rho$ are not convex).

Considering only a single level of the tree $k$, \[
\begin{aligned}
L_k (x; w) &= \sum_{\nu \in T_k} \left| c_{\lambda (\nu; x, w)} - c_{\rho (\nu; x, w)} \right| l \left(w^\top (x_{\lambda (\nu; x, w)} - x_{\rho (\nu; x, w)}), 1_{c_{\lambda (\nu; x, w)} > c_{\rho (\nu; x, w)}} \right),
\end{aligned}
\] some progress can be made in this case because the $x$ are orthonormal and each $x$ occurs at most once. Ignoring the dependence of $\lambda$ and $\rho$ on $w$, looking for a solution of the form \[
\begin{aligned}
w (h) &= w - \sum_{\nu \in T_k} s_\nu (h) (x_{\lambda (\nu; x, w)} - x_{\rho (\nu; x, w)}),
\end{aligned}
\] yields \[
s^\prime_\nu (h) = \eta \left| c_{\lambda (\nu; x, w)} - c_{\rho (\nu; x, w)} \right| \frac{\partial l}{\partial p}\biggr|_{p = (w - s_\nu (h) (x_{\lambda (\nu; x, w)} - x_{\rho (\nu; x, w)}))^\top (x_{\lambda (\nu; x, w)} - x_{\rho (\nu; x, w)})}.
\] This is essentially the ranking update for each node at level $k$ of the tree.
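The decoupling claim can be checked numerically. The sketch below makes the following assumptions:

- Each class's feature vector is a one-hot indicator, which is one way to satisfy orthonormality.
- Squared loss with $\pm 1$ targets replaces the generic $l$, because it is smooth and easy to integrate.
- The dependence of $\lambda$ and $\rho$ on $w$ is ignored, as in the derivation.

It integrates the coupled ODE for a whole level by brute force and compares the result with independent per-node closed-form updates. Function names are hypothetical.

```python
import numpy as np

def level_closed_form(w, pairs, eta):
    """Independent per-node importance-aware squared-loss updates for one level.

    pairs: list of (left, right, importance, y). Classes are one-hot, so each
    difference vector e_left - e_right has squared norm 2, and because each class
    appears in at most one pair, the difference vectors are mutually orthogonal.
    """
    w = w.copy()
    for left, right, imp, y in pairs:
        p = w[left] - w[right]
        s = (p - y) / 2.0 * (1.0 - np.exp(-eta * imp * 2.0))
        w[left] -= s
        w[right] += s
    return w

def level_euler(w, pairs, eta, steps=20000):
    """Brute-force Euler integration of the coupled ODE for the whole level."""
    w = w.copy()
    dh = 1.0 / steps
    for _ in range(steps):
        grad = np.zeros_like(w)
        for left, right, imp, y in pairs:
            r = imp * ((w[left] - w[right]) - y)   # importance * dl/dp for this node
            grad[left] += r
            grad[right] -= r
        w -= dh * eta * grad
    return w

w0 = np.array([0.3, -0.2, 0.0, 0.5, -1.0, 0.1])
pairs = [(0, 1, 2.0, 1.0), (2, 3, 0.5, -1.0), (4, 5, 3.0, 1.0)]
print(np.max(np.abs(level_closed_form(w0, pairs, 0.5) - level_euler(w0, pairs, 0.5))))
```

The two agree up to discretization error. This happens because each node's score difference only moves along its own difference vector, which is orthogonal to those of the other nodes on the level.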

Inspired by the above reasoning, I've gotten really good results from scoring filter trees using the following per-example procedure:
  1. Determine the overall winner of the tournament, for purposes of estimating loss.
  2. For each level of the tree, from bottom to top:
    1. Determine the inputs to each parent node $\lambda (\nu; x, w)$ and $\rho (\nu; x, w)$, i.e., who advances in the tournament.
    2. Compute $s_\nu (h)$ for each node at this level of the tree, using the above equation, and update the model.
    3. Optional speed improvement: short-circuit if all remaining cost vector entries are identical (i.e., all remaining possible importance weights are zero).
    4. Recompute the value of $w^\top x_{\lambda (\nu; x, w)}$ and $w^\top x_{\rho (\nu; x, w)}$ for each remaining class label in the tournament.
This procedure was motivated by the need to absolutely rule out the possibility of over-running the label for dyadic models. I found that it gave better losses on non-dyadic problems as well, enough to compensate for the slowdown from recomputing quantities while walking the tree. (The slowdown is even worse than it sounds, because each level of the tree introduces a synchronization barrier between concurrent estimation threads. C'est la vie.)
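Here is a single-threaded sketch of one reading of the per-example procedure. It relies on several assumptions that are not in the text:

- Class features are one-hot, so $w^\top x_c$ is just `w[c]`.
- The loss is hinge loss, with the label $1_{c_\lambda > c_\rho}$ mapped to $\pm 1$.
- A positive score difference predicts that the left input has the higher cost, so the lower score advances.
- Adjacent survivors are paired, and an odd survivor gets a bye.
- Who advances is decided from the post-update scores.

All function names are hypothetical.

```python
import numpy as np

def hinge_iaw_scale(p, xx, y, eta, importance):
    """Importance-aware step size for hinge loss max(0, 1 - y p), y in {-1, +1}."""
    if importance == 0.0 or y * p >= 1.0:
        return 0.0
    return -y * min(eta * importance, (1.0 - y * p) / xx)

def play_level(scores, survivors):
    """One round: pair adjacent survivors; the lower score advances (assumed convention).

    A positive p = scores[left] - scores[right] predicts label 1, i.e. that the left
    input has the higher cost, so left loses. An unpaired last survivor gets a bye.
    """
    winners = []
    for i in range(0, len(survivors) - 1, 2):
        left, right = survivors[i], survivors[i + 1]
        winners.append(right if scores[left] - scores[right] > 0 else left)
    if len(survivors) % 2 == 1:
        winners.append(survivors[-1])
    return winners

def tournament_winner(scores):
    survivors = list(range(len(scores)))
    while len(survivors) > 1:
        survivors = play_level(scores, survivors)
    return survivors[0]

def train_example(w, costs, eta):
    """One per-example pass of the level-by-level update; returns (new w, estimated loss)."""
    w = np.array(w, dtype=float)          # one-hot class features: score of class c is w[c]
    costs = np.asarray(costs, dtype=float)
    loss = costs[tournament_winner(w)]    # step 1: winner under the pre-update model
    survivors = list(range(len(costs)))
    while len(survivors) > 1:             # step 2: levels from bottom to top
        remaining = costs[survivors]
        if remaining.max() == remaining.min():
            break                         # step 2.3: all remaining importance weights are zero
        for i in range(0, len(survivors) - 1, 2):   # step 2.2: one ranking update per node
            left, right = survivors[i], survivors[i + 1]
            importance = abs(costs[left] - costs[right])
            y = 1.0 if costs[left] > costs[right] else -1.0
            # ||e_left - e_right||^2 = 2, and nodes on one level touch disjoint classes
            s = hinge_iaw_scale(w[left] - w[right], 2.0, y, eta, importance)
            w[left] -= s
            w[right] += s
        # steps 2.4 and 2.1: re-score with the updated w and advance the winners
        survivors = play_level(w, survivors)
    return w, loss

w, loss = train_example(np.zeros(4), [1.0, 0.0, 1.0, 1.0], eta=1.0)
print(loss, tournament_winner(w))
```

With one-hot features, "recomputing" the scores is free. In a real model with shared features, that step requires re-evaluating $w^\top x$ for every remaining class, which is the cost the text mentions.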
