Monday, October 3, 2011

A Sparse Online Update for a Hierarchical Model

There is a trick for getting sparse updates when using SGD for probabilistic models with prior distributions on the parameters, or for classifiers with regularized decision boundaries. As Alex Smola has discussed, this has been used for online learning in SVMs for some time, and it's also used in the Vowpal Wabbit LDA implementation. I'll describe it and then describe some challenges I had adapting it for online learning of a hierarchical generative model for crowdsourced data.

The story begins with a training dataset $D$, and a probabilistic model parameterized by $\lambda$ consisting of a likelihood function $L$ and a prior distribution $P$. This results in the objective function $f (D; \lambda) = \log P (\lambda) + \sum_{d \in D} \log L (d | \lambda).$ Moving the prior term into the sum yields $f (D; \lambda) = \sum_{d \in D} \left( \frac{1}{|D|} \log P (\lambda) + \log L (d | \lambda) \right),$ which looks like a candidate for SGD optimization, $\lambda (t+1) = \lambda (t) + \eta (t) \frac{\partial}{\partial \lambda} \left( \frac{1}{|D|} \log P (\lambda (t)) + \log L (d (t) | \lambda (t)) \right).$

It is often the case that the gradient of the likelihood term for a particular datum is non-zero only for a small number of components of $\lambda$, i.e., the likelihood term is sparse. For example, in the crowdsourcing generative model, workers who do not label a particular item do not contribute to the likelihood. Unfortunately, because the prior term is dense, the resulting SGD update is not sparse. A well-known trick involves noting that in between examples where the likelihood gradient is non-zero, the only update to the parameter is from the prior gradient. Approximating the discrete gradient dynamics with continuous dynamics yields $\frac{d}{d t} \lambda_i (t) = \frac{1}{|D|} \eta (t) \frac{\partial}{\partial \lambda_i (t)} \log P (\lambda (t)),$ which when integrated from the last update value $\lambda_i (s)$ sometimes has an analytical solution.

For example, with a Gaussian prior and power-law decay learning rate (with $\rho \neq 1$), \begin{aligned} \log P (\lambda (t)) &= - \frac{1}{2} \sum_i \left( \lambda_i (t) - \mu \right)^2, \\ \eta (t) &= \eta_0 (t + t_0)^{-\rho}, \\ \lambda_i (t) &= \exp \left( \frac{\eta_0}{|D|} \left( \frac{(s + t_0)^{1 - \rho}}{1 - \rho} - \frac{(t + t_0)^{1 - \rho}}{1 - \rho} \right) \right) \left( \lambda_i (s) - \mu \right) + \mu. \end{aligned} So far so good.
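The closed-form decay above can be sketched in a few lines of Python. This is only an illustration under the same assumptions as the example (unit-variance Gaussian prior, power-law learning rate with $\rho \neq 1$); the function names and arguments are hypothetical. A stepwise reference that applies every prior-only SGD step is included so the continuous approximation can be compared against the discrete dynamics it replaces.

```python
import math


def lazy_prior_update(lam_s, mu, s, t, eta0, t0, rho, n_data):
    """Closed-form prior-only update of one parameter from time s to time t.

    Assumes a unit-variance Gaussian prior with mean mu and rho != 1.
    """
    p = 1.0 - rho
    exponent = eta0 / n_data * ((s + t0) ** p - (t + t0) ** p) / p
    eps = math.exp(exponent)  # shrink factor in (0, 1] when t >= s
    return eps * (lam_s - mu) + mu


def stepwise_prior_update(lam_s, mu, s, t, eta0, t0, rho, n_data):
    """Reference: apply the discrete prior-only SGD step at times s, ..., t - 1."""
    lam = lam_s
    for u in range(s, t):
        eta = eta0 * (u + t0) ** (-rho)
        lam += eta / n_data * (mu - lam)  # gradient of the Gaussian log prior
    return lam


# Hypothetical settings, chosen only for illustration.
args = dict(lam_s=3.0, mu=1.0, s=0, t=500, eta0=0.5, t0=10.0, rho=0.5, n_data=1000)
print(lazy_prior_update(**args), stepwise_prior_update(**args))
```

The lazy version costs one exponential regardless of how many examples were skipped, which is what makes the update sparse; the two results agree closely when the per-step rate $\eta (t) / |D|$ is small.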
Unfortunately my generative models are hierarchical, so the prior distributions are themselves parameterized with hyperparameters with their own hyperpriors. The objective function becomes $f (D; \lambda, \mu) = \log P (\mu) + \log P (\lambda | \mu) + \sum_{d \in D} \log L (d | \lambda),$ and the updates become \begin{aligned} \lambda (t+1) &= \lambda (t) + \eta (t) \frac{\partial}{\partial \lambda} \left( \frac{1}{|D|} \log P (\lambda (t) | \mu) + \log L (d (t) | \lambda (t)) \right), \\ \mu (t+1) &= \mu (t) + \eta (t) \frac{\partial}{\partial \mu} \left(\frac{1}{|D|} \log P (\mu) + \frac{1}{|D|} \log P (\lambda (t) | \mu) \right). \end{aligned} Double yuck! The first problem is that, even when the likelihood gradient is zero, the parameters all have coupled dynamics due to the hyperprior. The second problem is that the gradient computation for the hyperparameter is not sparse (i.e., is linear in the number of parameters). If I were starting from scratch I'd probably say "screw it, hyperpriors are not worth it," but I'm trying to reproduce my results from the batch optimization, where hyperpriors were easier to incorporate and provided some improvement (and diagnostic value).
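To make the second problem concrete, here is a minimal sketch of the hyperparameter gradient. It assumes the unit-variance Gaussian prior and hyperprior used later in this post, with hypermean $\nu$; the function name and the sample values are hypothetical.

```python
def hyper_gradient(lam, mu, nu, n_data):
    """Gradient w.r.t. mu of (log P(mu) + log P(lam | mu)) / |D|.

    Assumes log P(lam | mu) = -0.5 * sum((l - mu)**2) and
    log P(mu) = -0.5 * (nu - mu)**2.
    """
    # The sum visits every parameter, so a single evaluation costs O(|I|),
    # even when the current example touches only a handful of them.
    return ((nu - mu) + sum(l - mu for l in lam)) / n_data


lam = [0.3, -1.2, 0.0, 2.5]  # hypothetical parameter values
print(hyper_gradient(lam, mu=0.1, nu=0.0, n_data=1000))
```

With a hashed parameter vector, `lam` would have one entry per hash slot, so paying this cost on every example defeats the purpose of a sparse update.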

How to proceed? With another (heretofore unknown?) trick, which applies any time the optimal hyperparameter for fixed parameters has an analytical expression. In my case both the prior and hyperprior are Gaussian, so for fixed $\lambda$ I know what the optimal $\mu$ is, \begin{aligned} \log P (\lambda | \mu) &= -\frac{1}{2} \sum_i (\lambda_i - \mu)^2, \\ \log P (\mu) &= -\frac{1}{2} (\nu - \mu)^2, \\ \mu^* (\lambda) &= \frac{1}{|I| + 1} \left( \nu + \sum_i \lambda_i \right), \end{aligned} where $|I|$ is the number of parameters. This indicates that any time a $\lambda_i$ changes, the optimal $\mu^*$ experiences the same change scaled by $\frac{1}{|I| + 1}$. This suggests the following procedure:
1. For each $\lambda_i$ which has a non-zero likelihood gradient contribution for the current example,
   1. Perform the "pure prior" update approximation as if $\mu$ were constant, and then update $\mu$ for the change in $\lambda_i$, \begin{aligned} \epsilon &\leftarrow \exp \left( \frac{\eta_0}{|D|} \left( \frac{(s + t_0)^{1 - \rho}}{1 - \rho} - \frac{(t + t_0)^{1 - \rho}}{1 - \rho} \right) \right), \\ \Delta \lambda_i &\leftarrow (1 - \epsilon) (\mu - \lambda_i), \\ \lambda_i &\leftarrow \lambda_i + \Delta \lambda_i, \\ \mu &\leftarrow \mu + \frac{1}{|I| + 1} \Delta \lambda_i. \end{aligned} Here $s$ is the timestamp of the last update of $\lambda_i$.
2. Perform the likelihood SGD update for each $\lambda_i$ with a non-zero gradient contribution and propagate the resulting change to $\mu$, \begin{aligned} \Delta \lambda_i &\leftarrow \eta_0 (t + t_0)^{-\rho} \frac{\partial}{\partial \lambda_i} \log L (d | \lambda), \\ \lambda_i &\leftarrow \lambda_i + \Delta \lambda_i, \\ \mu &\leftarrow \mu + \frac{1}{|I| + 1} \Delta \lambda_i. \end{aligned}
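The two steps above can be sketched as follows. This is an illustration rather than a reference implementation: the class name, the `grad_fn` callback, and the default constants are hypothetical, $\rho \neq 1$ is assumed, and $\mu$ is initialized to $\nu$ with $\lambda (0) = 0$.

```python
import math


class SparseHierarchicalSGD:
    """Lazy prior updates plus incremental tracking of the hypermean mu."""

    def __init__(self, n_params, n_data, nu=0.0, eta0=1.0, t0=1.0, rho=0.5):
        self.n_params = n_params      # |I|
        self.n_data = n_data          # |D|
        self.eta0, self.t0, self.rho = eta0, t0, rho
        self.lam = [0.0] * n_params   # lambda(0) = 0 (assumption)
        self.last = [0] * n_params    # timestamp s of each lambda_i's last update
        self.mu = nu                  # mu starts at the hyperprior mean
        self.t = 0

    def _epsilon(self, s, t):
        """Shrink factor from integrating the prior-only dynamics over [s, t]."""
        p = 1.0 - self.rho
        return math.exp(self.eta0 / self.n_data
                        * ((s + self.t0) ** p - (t + self.t0) ** p) / p)

    def step(self, active, grad_fn):
        """active: indices touched by the current example.
        grad_fn(lam) -> {i: d log L / d lambda_i} for those indices."""
        scale = 1.0 / (self.n_params + 1)
        # Step 1: catch up on the skipped prior updates, treating mu as constant,
        # then propagate each change to mu.
        for i in active:
            eps = self._epsilon(self.last[i], self.t)
            delta = (1.0 - eps) * (self.mu - self.lam[i])
            self.lam[i] += delta
            self.mu += scale * delta
            self.last[i] = self.t
        # Step 2: likelihood SGD update, again propagated to mu.
        eta = self.eta0 * (self.t + self.t0) ** (-self.rho)
        for i, g in grad_fn(self.lam).items():
            delta = eta * g
            self.lam[i] += delta
            self.mu += scale * delta
        self.t += 1


model = SparseHierarchicalSGD(n_params=5, n_data=100, nu=0.5)
model.step([0, 2], lambda lam: {0: 1.0 - lam[0], 2: -1.0})  # toy gradients
print(model.lam, model.mu)
```

Each call does work proportional to the number of active parameters only; untouched entries of `lam` keep their stale values until the next example that involves them.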
Now this is a bit unsatisfactory, because in practice $|I|$ is a free parameter due to the hashing trick, and furthermore will typically be set much larger than the actual number of parameters leveraged during training. In the limit $|I| \to \infty$ this basically reduces to the fixed prior setting where the fixed prior mean is whatever initial value is chosen for $\mu$ (and $\mu = \nu$ is the proper choice for initialization if $\lambda (0) = 0$). In practice this gives reasonable-looking hypermeans if $|I|$ is not set too large (otherwise, they stay stuck at the hyperprior mean). This is useful because in my nominal label extraction technique, the hypermean confusion matrix can help identify a problem with the task (or a bug somewhere in the data processing).
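The dependence on $|I|$ is easy to see with a little arithmetic. In the sketch below the function name and the numbers are hypothetical: it computes how far $\mu$ moves for the same total parameter change under different hash table sizes.

```python
def hypermean_shift(total_delta_lambda, n_slots):
    """Change in mu caused by a given total change in the parameters,
    when the hash table has n_slots entries (playing the role of |I|)."""
    return total_delta_lambda / (n_slots + 1)


# Same hypothetical total parameter movement, increasingly large hash tables.
for bits in (10, 18, 26):
    print(bits, hypermean_shift(50.0, 2 ** bits))
```

As the table grows the shift goes to zero, which is the "stuck at the hyperprior mean" behavior described above.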