Sunday, July 17, 2011

Fast Approximate Lambert W

This is of questionable personal utility: after all, I've only seen Lambert's W function once in a machine learning context, and there it occurs as $W (\exp (x)) - x$, which is better approximated directly. Nonetheless, generating these fast approximate functions makes for amusing sport and is much more likely to be useful to someone somewhere than, e.g., further developing my computer gaming ability.

The Branchless Lifestyle

Lambert W is a typical function in one respect: there is a nice asymptotic approximation that can be used to initialize a Householder method, but which is only applicable to part of the domain. In particular, for large $x$, $W (x) \approx \log (x) - \log \log (x) + \frac{\log \log (x)}{\log (x)},$ which is great because the logarithm is cheap. Unfortunately, for $x \leq 1$ this cannot be evaluated (since $\log \log (x)$ requires $\log (x) > 0$), and for $1 < x < 2$ it gives really poor results. A natural way to proceed would be to have a different initialization strategy for $x < 2$.
if (x < 2)
  {
    // alternate initialization
  }
else
  {
    // use asymptotic approximation
  }

// householder steps here

There are two problems with this straightforward approach. In a scalar context, a hard-to-predict branch can frustrate the pipelined architecture of modern CPUs, since each misprediction forces the pipeline to be flushed. In a vector context this is even more problematic, because components might fall into different branches of the conditional.

What to do? It turns out statements like this
a = (x < 2) ? b : c

look like conditionals but need not be, since they can be rewritten as
a = f (x < 2) * b + (1 - f (x < 2)) * c

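Here $f$ is an indicator function which returns 0 or 1 depending upon the truth value of the argument. The SSE instruction set contains indicator-like comparison instructions (e.g., _mm_cmplt_ps), which set each lane to a bit mask of all ones or all zeros; combined with the bitwise floating point "and" instructions (_mm_and_ps and _mm_andnot_ps, with _mm_or_ps to merge the results), they end up computing a branchless ternary operator.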
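The arithmetic form of the rewrite can be sketched in plain C as follows. The names indicator and blend_select are hypothetical. A C comparison already evaluates to the int 0 or 1, so computing $f$ needs no branch; whether the compiler actually emits branch-free machine code is up to it, so this illustrates the idea rather than guaranteeing the generated code.

```c
#include <assert.h>
#include <math.h>

/* Indicator f(condition): 1.0 when the condition holds, 0.0 otherwise.
   The comparison passed in is already the int 0 or 1. */
static double indicator (int condition)
{
  return (double) condition;
}

/* a = (x < 2) ? b : c, rewritten as f(x < 2) * b + (1 - f(x < 2)) * c.
   Both alternatives are always evaluated. */
double blend_select (double x, double b, double c)
{
  /* The arithmetic blend is only safe for finite inputs: 0 * inf and
     0 * NaN are NaN, which would leak the unselected value into a. */
  assert (isfinite (b) && isfinite (c));
  double f = indicator (x < 2.0);
  return f * b + (1.0 - f) * c;
}
```

The finiteness caveat is one reason the hardware route works on bit masks rather than multiplying by 0 and 1.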

The bottom line is that speculative execution can be made deterministic if both branches of a conditional are computed, and in simple enough cases there is direct hardware support for doing this quickly.
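To see what the SSE select does in a single lane without depending on x86 intrinsics, here is a portable scalar model of one lane. It mirrors the effect of _mm_cmplt_ps followed by _mm_and_ps, _mm_andnot_ps and _mm_or_ps, not their actual implementation; the helper names are hypothetical. Because it only moves bits, it has no trouble with infinities or NaNs in the unselected operand.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* The trick reinterprets a float's bits as a 32-bit unsigned integer. */
static_assert (sizeof (float) == sizeof (uint32_t), "expects 32-bit float");

static uint32_t float_bits (float f) { uint32_t u; memcpy (&u, &f, sizeof u); return u; }
static float bits_float (uint32_t u) { float f; memcpy (&f, &u, sizeof f); return f; }

/* Scalar model of one SSE lane computing (x < threshold) ? b : c. */
float mask_select_lt (float x, float threshold, float b, float c)
{
  /* Like _mm_cmplt_ps: 0xFFFFFFFF when true, 0x00000000 when false. */
  uint32_t mask = -(uint32_t) (x < threshold);
  /* Like _mm_or_ps (_mm_and_ps (mask, b), _mm_andnot_ps (mask, c)):
     keep b's bits where the mask is set and c's bits elsewhere. */
  uint32_t picked = (mask & float_bits (b)) | (~mask & float_bits (c));
  return bits_float (picked);
}
```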

Branchless Lambert W

So the big idea here is to have an alternate initialization for the Householder step such that it can be computed in a branchless fashion, given that for large inputs the asymptotic approximation is used. Therefore I looked for an approximation of the form $W (x) \approx a + \log (c x + d) - \log \log (c x + d) + \frac{\log \log (c x + d)}{\log (c x + d)},$ where for large $x$, $a = 0$, $c = 1$, and $d = 0$. I found values for $a$, $c$, $d$, and the cutoff value for $x$ via Mathematica. (The curious can check out the Mathematica notebook). The vector version ends up looking like
// WARNING: this code has been updated.  Do not use this version.

static inline v4sf
vfastlambertw (v4sf x)
{
  static const v4sf threshold = v4sfl (2.26445f);

  // per-lane mask: all ones where x < threshold, all zeros elsewhere
  v4sf under = _mm_cmplt_ps (x, threshold);
  // branchless selection of (a, c, d); (0, 1, 0) above the threshold
  v4sf c = _mm_or_ps (_mm_and_ps (under, v4sfl (1.546865557f)),
                      _mm_andnot_ps (under, v4sfl (1.0f)));
  v4sf d = _mm_and_ps (under, v4sfl (2.250366841f));
  v4sf a = _mm_and_ps (under, v4sfl (-0.737769969f));

  v4sf logterm = vfastlog (c * x + d);
  v4sf loglogterm = vfastlog (logterm);

  // initial guess, then one Halley step on f(w) = w e^w - x
  v4sf w = a + logterm - loglogterm + loglogterm / logterm;
  v4sf expw = vfastexp (w);
  v4sf z = w * expw;
  v4sf p = x + z;

  return (v4sfl (2.0f) * x + w * (v4sfl (4.0f) * x + w * p)) /
         (v4sfl (2.0f) * expw + p * (v4sfl (2.0f) + w));
}
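The return expression is a single Halley step (the second-order Householder method) applied to $f (w) = w e^w - x$, writing $E$ for expw, with $z = w E$ and $p = x + z$ as in the code. Expanding the textbook Halley update recovers it:

```latex
\begin{aligned}
f(w) &= w e^w - x, \qquad f'(w) = e^w (w + 1), \qquad f''(w) = e^w (w + 2), \\
w_{\text{new}} &= w - \frac{f}{f' - f f'' / (2 f')}
  = w - \frac{2 (w + 1)(w E - x)}{2 E (w + 1)^2 - (w + 2)(w E - x)}, \\
2 E (w + 1)^2 - (w + 2)(w E - x) &= 2 E + 2 w E + w^2 E + 2 x + w x = 2 E + p (2 + w), \\
w \left[ 2 E + p (2 + w) \right] - 2 (w + 1)(w E - x) &= 2 x + 4 w x + w^2 x + w^3 E = 2 x + w (4 x + w p), \\
\Rightarrow\; w_{\text{new}} &= \frac{2 x + w (4 x + w p)}{2 E + p (2 + w)}.
\end{aligned}
```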

You can get the complete code from the fastapprox project.

Timing and Accuracy

Timing tests are done by compiling with -O3 -finline-functions -ffast-math, on a box running 64-bit Ubuntu Lucid (so gcc 4:4.4.3-1ubuntu1 and libc 2.11.1-0ubuntu7.6). I also measured average relative accuracy for $x$ distributed as $\frac{1}{2} U (-1/e, 1) + \frac{1}{2} U (0, 100)$, i.e., a 50-50 mixture of two uniform distributions. Accuracy is measured against 20 iterations of Newton's method with an initial point of 0 when $x < 5$ and the asymptotic approximation otherwise. I also tested the GSL implementation, which is much more accurate but significantly slower.
fastlambertw average relative error = 5.26867e-05
fastlambertw max relative error (at 2.48955e-06) = 0.0631815
fasterlambertw average relative error = 0.00798678
fasterlambertw max relative error (at -0.00122776) = 0.926378
vfastlambertw average relative error = 5.42952e-05
vfastlambertw max relative error (at -2.78399e-06) = 0.0661513
vfasterlambertw average relative error = 0.00783347
vfasterlambertw max relative error (at -0.00125244) = 0.926431
gsl_sf_lambert_W0 average relative error = 5.90309e-09
gsl_sf_lambert_W0 max relative error (at -0.36782) = 6.67586e-07
fastlambertw million calls per second = 21.2236
fasterlambertw million calls per second = 53.4428
vfastlambertw million calls per second = 21.6723
vfasterlambertw million calls per second = 56.0154
gsl_sf_lambert_W0 million calls per second = 2.7433

These average accuracies hide the relatively poor performance at the minimum of the domain. Right at $x = -e^{-1}$, the fastlambertw approximation is poor (-0.938, whereas the correct answer is -1, a relative error of about $6\%$); but at $x = -e^{-1} + \frac{1}{100}$, the relative error drops to $2 \times 10^{-4}$.

1 comment:

1. Thanks, this was useful. I ran into a use for Lambert W in a machine learning context, although also one where W(exp(x)) would have been more useful.

As you hint, exponentiating y = exp(x) for large x just to compute W(y) risks overflow, only to effectively take (something close to) the logarithm again afterwards.