Friday, June 24, 2011

Faster LDA

Now that I have mastered the unholy power of byte reinterpretation between integers and IEEE floating point, I've been looking around for things to speed up. The recent release of a fast LDA implementation by Smola et al. reminded me to look at Vowpal Wabbit's LDA implementation. The lda.cc file is full of expensive-looking transcendental functions that could be sped up using approximate versions, hopefully without spoiling the learning algorithm.

Faster Log Gamma and Digamma

In addition to lots of calls to exponential and logarithm, for which I already have fast approximations, there are also calls to Log Gamma and Digamma. Fortunately the leading term in Stirling's approximation is a logarithm, so I can leverage the fast logarithm approximation from my previous blog post.

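For both Log Gamma and Digamma I found that applying the shift formula a couple of times before switching to the asymptotic expansion gave the best tradeoff between accuracy and speed: twice for Digamma (the expansion is evaluated at x + 2) and, as the code below is written, three times for Log Gamma (evaluated at x + 3).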
// WARNING: this code has been updated.  Do not use this version.
// Instead get the latest version from http://code.google.com/p/fastapprox

inline float
fastlgamma (float x)
{
  float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
  float xp3 = 3.0f + x;

  return
    -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
}

inline float
fastdigamma (float x)
{
  float twopx = 2.0f + x;
  float logterm = fastlog (twopx);

  return - (1.0f + 2.0f * x) / (x * (1.0f + x))
         - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
         + logterm;
}
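The constants in the two functions above are consistent with the following derivation. This is a reconstruction from the code rather than something spelled out in the post, so treat it as a sketch: shift the argument upward with the recurrence, then keep only the leading terms of the asymptotic series at the shifted point $z$.

```latex
\begin{aligned}
\log\Gamma(x) &= \log\Gamma(x+3) - \log\bigl(x\,(x+1)\,(x+2)\bigr), \\
\log\Gamma(z) &\approx \bigl(z - \tfrac{1}{2}\bigr)\log z - z + \tfrac{1}{2}\log(2\pi) + \frac{1}{12z}, \\
\log\Gamma(x) &\approx \bigl(x + \tfrac{5}{2}\bigr)\log(x+3) - x - 3 + \tfrac{1}{2}\log(2\pi)
                 + \frac{1}{12\,(x+3)} - \log\bigl(x\,(x+1)\,(x+2)\bigr), \\[6pt]
\psi(x) &= \psi(x+2) - \frac{1}{x} - \frac{1}{x+1}, \\
\psi(z) &\approx \log z - \frac{1}{2z} - \frac{1}{12z^{2}}, \\
\psi(x) &\approx \log(x+2) - \frac{1+2x}{x\,(1+x)} - \frac{13+6x}{12\,(x+2)^{2}}.
\end{aligned}
```

Here $-3 + \tfrac{1}{2}\log(2\pi) \approx -2.081061466$ and $1/12 \approx 0.0833333$, which are the literals appearing in fastlgamma, and the two rational terms in the last line are exactly the ones in fastdigamma.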

Timing and Accuracy

Timing tests were done by compiling with -O3 -finline-functions -ffast-math, on a box running 64-bit Ubuntu Lucid (so gcc 4:4.4.3-1ubuntu1 and libc 2.11.1-0ubuntu7.6). I also measured average relative accuracy for $x \in [1/100, 10]$ (using lgammaf from libc as the gold standard for Log Gamma, and boost::math::digamma as the gold standard for Digamma).
fastlgamma relative accuracy = 0.00045967
fastdigamma relative accuracy = 0.000420604
fastlgamma million calls per second = 60.259
lgammaf million calls per second = 21.4703
boost::math::lgamma million calls per second = 1.8951
fastdigamma million calls per second = 66.0639
boost::math::digamma million calls per second = 18.9672
Well, the speedup is roughly 2.8x for Log Gamma relative to lgammaf (but why is boost::math::lgamma so slow?) and roughly 3.5x for Digamma.
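For anyone who wants to poke at the accuracy side themselves, here is a minimal sketch of a checker. It is not the harness that produced the numbers above. Assumptions: fastlog is replaced by a stand-in that just calls std::log (the real fast logarithm lives in the earlier post), so this isolates the error of the shift-plus-Stirling formula alone, and max_abs_error_lgamma is a hypothetical helper name. It reports absolute rather than relative error, because Log Gamma is zero at x = 1 and x = 2 and a relative error at those points is undefined.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// ASSUMPTION: stand-in for the post's fastlog, using the exact logarithm.
inline float fastlog(float x) { return std::log(x); }

// Same formulas as in the post (the version flagged as superseded there).
inline float fastlgamma(float x) {
  float logterm = fastlog(x * (1.0f + x) * (2.0f + x));
  float xp3 = 3.0f + x;
  return -2.081061466f - x + 0.0833333f / xp3 - logterm
         + (2.5f + x) * fastlog(xp3);
}

inline float fastdigamma(float x) {
  float twopx = 2.0f + x;
  float logterm = fastlog(twopx);
  return -(1.0f + 2.0f * x) / (x * (1.0f + x))
         - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
         + logterm;
}

// Hypothetical helper: largest absolute difference between fastlgamma and
// std::lgamma over n evenly spaced points in [lo, hi].
inline double max_abs_error_lgamma(double lo, double hi, std::size_t n) {
  assert(n >= 2);
  double worst = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    double x = lo + (hi - lo) * static_cast<double>(i)
                        / static_cast<double>(n - 1);
    double approx = static_cast<double>(fastlgamma(static_cast<float>(x)));
    double err = std::fabs(approx - std::lgamma(x));
    if (err > worst) worst = err;
  }
  return worst;
}
```

Calling max_abs_error_lgamma(0.01, 10.0, 1000) covers the same interval as the measurements above. The standard library has no digamma, so for that function a spot check against a known value such as psi(1) = -0.5772156649 (minus the Euler-Mascheroni constant) is the simplest sanity test without pulling in Boost.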

LDA Timings

So I dropped these approximations into lda.cc and did some timing tests, using the wiki1K file from the test/train-sets/ directory; but I concatenated the file with itself 100 times to slow things down.
Original Run
% (cd test; time ../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatch 128 train-sets/wiki1Kx100.dat)
your learning rate is too high, setting it to 1
using no cache
Reading from train-sets/wiki1Kx100.dat
num sources = 1
Num weight bits = 13
learning rate = 1
initial_t = 1
power_t = 0.5
learning_rate set to 1
average    since       example  example    current  current  current
loss       last        counter   weight      label  predict features
10.296875  10.296875         3      3.0    unknown   0.0000       38
10.437156  10.577436         6      6.0    unknown   0.0000       14
10.347227  10.239314        11     11.0    unknown   0.0000       32
10.498633  10.650038        22     22.0    unknown   0.0000        2
10.495566  10.492500        44     44.0    unknown   0.0000      166
10.469184  10.442189        87     87.0    unknown   0.0000       29
10.068007  9.666831        174    174.0    unknown   0.0000       17
9.477440   8.886873        348    348.0    unknown   0.0000        2
9.020482   8.563524        696    696.0    unknown   0.0000      143
8.482232   7.943982       1392   1392.0    unknown   0.0000       31
8.106654   7.731076       2784   2784.0    unknown   0.0000       21
7.850269   7.593883       5568   5568.0    unknown   0.0000       25
7.707427   7.564560      11135  11135.0    unknown   0.0000       39
7.627417   7.547399      22269  22269.0    unknown   0.0000       61
7.583820   7.540222      44537  44537.0    unknown   0.0000        5
7.558824   7.533827      89073  89073.0    unknown   0.0000      457

finished run
number of examples = 0
weighted example sum = 0
weighted label sum = 0
average loss = -nan
best constant = -nan
total feature number = 0
../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatc  69.06s user 13.60s system 101% cpu 1:21.59 total
Approximate Run
% (cd test; time ../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatch 128 train-sets/wiki1Kx100.dat)
your learning rate is too high, setting it to 1
using no cache
Reading from train-sets/wiki1Kx100.dat
num sources = 1
Num weight bits = 13
learning rate = 1
initial_t = 1
power_t = 0.5
learning_rate set to 1
average    since       example  example    current  current  current
loss       last        counter   weight      label  predict features
10.297077  10.297077         3      3.0    unknown   0.0000       38
10.437259  10.577440         6      6.0    unknown   0.0000       14
10.347348  10.239455        11     11.0    unknown   0.0000       32
10.498796  10.650243        22     22.0    unknown   0.0000        2
10.495748  10.492700        44     44.0    unknown   0.0000      166
10.469374  10.442388        87     87.0    unknown   0.0000       29
10.068179  9.666985        174    174.0    unknown   0.0000       17
9.477574   8.886968        348    348.0    unknown   0.0000        2
9.020435   8.563297        696    696.0    unknown   0.0000      143
8.482178   7.943921       1392   1392.0    unknown   0.0000       31
8.106636   7.731093       2784   2784.0    unknown   0.0000       21
7.849980   7.593324       5568   5568.0    unknown   0.0000       25
7.707124   7.564242      11135  11135.0    unknown   0.0000       39
7.627207   7.547283      22269  22269.0    unknown   0.0000       61
7.583952   7.540696      44537  44537.0    unknown   0.0000        5
7.559260   7.534566      89073  89073.0    unknown   0.0000      457

finished run
number of examples = 0
weighted example sum = 0
weighted label sum = 0
average loss = -nan
best constant = -nan
total feature number = 0
../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatc  41.97s user 7.69s system 102% cpu 48.622 total
So from 81 seconds down to 49 seconds, a roughly 40% reduction in wall-clock time (about 1.7x faster). But does it work? Well I'll let Matt be the final judge of that (I've sent him a modified lda.cc for testing), but initial indications are that using approximate versions of the expensive functions doesn't spoil the learning process: the loss columns of the two runs agree to about three decimal places.

Next Step: Vectorization

There is definitely more speed to be had, since lda.cc is loaded with code like
for (size_t i = 0; i<global.lda; i++)
    {
      sum += gamma[i];
      gamma[i] = mydigamma(gamma[i]);
    }
which just screams for vectorization. Stay tuned!
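As a first step in that direction, the loop body can be pulled out into a standalone, branch-free kernel over a contiguous float array, which is the shape that SIMD code (hand-written or compiler-generated) wants. The sketch below is only that restructuring, not a vectorized implementation. Assumptions: digamma_in_place is a hypothetical helper name that does not exist in lda.cc, fastlog is again a std::log stand-in, and whether a compiler actually emits vector instructions for it depends on the flags and on having a vectorizable logarithm.

```cpp
#include <cmath>
#include <cstddef>

// ASSUMPTION: stand-in for the post's fastlog, using the exact logarithm.
inline float fastlog(float x) { return std::log(x); }

// Same approximation as in the post: no branches, only arithmetic and a log.
inline float fastdigamma(float x) {
  float twopx = 2.0f + x;
  float logterm = fastlog(twopx);
  return -(1.0f + 2.0f * x) / (x * (1.0f + x))
         - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
         + logterm;
}

// Hypothetical helper: overwrites each gamma[i] with its approximate digamma
// and returns the sum of the ORIGINAL entries, like the loop in lda.cc.
inline float digamma_in_place(float* gamma, std::size_t n) {
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    float g = gamma[i];          // read once so the sum uses the old value
    sum += g;
    gamma[i] = fastdigamma(g);   // each element is independent of the others
  }
  return sum;
}
```

The only cross-iteration dependency is the running sum, which is a plain reduction; everything else is element-wise, so the same body can be applied to several floats at once.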
