Friday, April 25, 2014

A Discriminative Representation Learning Technique

Nikos and I have developed a technique for learning discriminative features using numerical linear algebra, which gives good results for some problems. The basic idea is as follows. Suppose you have a multiclass problem, i.e., training data of the form $S = \{ (x, y) | x \in \mathbb{R}^d, y \in \{ 1, \ldots, k \} \}$. Here $x$ is the original representation (features) and you want to learn new features that help your classifier. In deep learning this problem is tackled by defining a multi-level parametric nonlinearity of $x$ and optimizing the parameters. Deep learning is awesome, but the resulting optimization problems are challenging, especially in the distributed setting, so we were looking for something more computationally felicitous.

First consider the two class case. Imagine looking for features of the form $\phi (w^\top x)$, where $w \in \mathbb{R}^d$ is a “weight vector” and $\phi$ is some nonlinearity. What is a simple criterion for defining a good feature? One idea is for the feature to have small average value on one class and large average value on another. Assuming $\phi$ is non-negative, that suggests maximizing the ratio \[
w^* = \arg \max_w \frac{\mathbb{E}[\phi (w^\top x) | y = 1]}{\mathbb{E}[\phi (w^\top x) | y = 0]}.
\] For the specific choice of $\phi (z) = z^2$ this is tractable, as it results in a Rayleigh quotient between two class-conditional second moments, \[
w^* = \arg \max_w \frac{w^\top \mathbb{E}[x x^\top | y = 1] w}{w^\top \mathbb{E}[x x^\top | y = 0] w},
\] which can be solved via generalized eigenvalue decomposition. Generalized eigenvalue problems have been extensively studied in machine learning and elsewhere, and the above idea looks very similar to many other proposals (e.g., Fisher LDA), but it is different and empirically more effective. I'll refer you to the paper for a more thorough discussion, but I will mention that after the paper was accepted someone pointed out the similarity to CSP (common spatial patterns), which is a technique from time-series analysis (cf. Ecclesiastes 1:4-11).
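
As a minimal sketch of the two-class case, the following solves the Rayleigh quotient on synthetic data in two ways: with a generalized eigensolver, and by whitening with a Cholesky factor so that an ordinary symmetric eigensolver can be used. The data, dimensions, regularizer value, and variable names are illustrative assumptions, not taken from the paper.

```matlab
% Synthetic two-class data: class 1 has extra energy along the first axis.
d = 5; n = 2000;
x1 = randn(n, d); x1(:, 1) = 3 * x1(:, 1);
x0 = randn(n, d);

A = (x1' * x1) / n;                    % estimate of E[x x' | y = 1]
B = (x0' * x0) / n + 0.1 * eye(d);     % estimate of E[x x' | y = 0], regularized

% Route 1: generalized eigenproblem A*w = lambda*B*w.
[V, D] = eig(A, B);
[lambda, order] = sort(diag(D), 'descend');
w = V(:, order(1));                    % maximizer of the Rayleigh quotient
ratio = (w' * A * w) / (w' * B * w);   % equals lambda(1) up to rounding

% Route 2: with B = l*l', substitute u = l'*w to get an ordinary
% symmetric eigenproblem (l \ A / l') * u = lambda * u.
l = chol(B)';                          % lower-triangular Cholesky factor
C = l \ A / l';
C = 0.5 * (C + C');                    % remove rounding asymmetry
[U, E] = eig(C);
[mu, idx] = sort(diag(E), 'descend');
w2 = l' \ U(:, idx(1));                % map back to the original coordinates

fprintf('top eigenvalue: %g (eig(A,B)) vs %g (whitened)\n', lambda(1), mu(1));
```

Both routes should report the same top eigenvalue, and the leading direction should concentrate on the first coordinate, where the two classes differ in second moment.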

The features that result from this procedure pass the smell test. For example, starting from a raw pixel representation on mnist, the weight vectors can be visualized as images; the first weight vector for discriminating 3 vs. 2 looks like

[Image: first weight vector for discriminating 3 vs. 2]

which looks like a pen stroke, cf. figure 1D of Ranzato et al.

We make several additional observations in the paper. The first is that multiple generalized eigenvectors, not just the top one, are useful if the associated generalized eigenvalues are large, i.e., one can extract multiple features from a Rayleigh quotient. The second is that, for moderate $k$, we can extract features for each class pair independently and use all the resulting features to get good results. The third is that the resulting directions have additional structure which is not completely captured by a squaring non-linearity, which motivates a (univariate) basis function expansion. The fourth is that, once the original representation has been augmented with additional features, the procedure can be repeated, which sometimes yields additional improvements. Finally, we can compose this with randomized feature maps to approximate the corresponding operations in a RKHS, which sometimes yields additional improvements. We also made a throw-away comment in the paper that computing class-conditional second moment matrices is easily done in a map-reduce style distributed framework. This was actually a major motivation for us to explore in this direction; it just didn't fit well into the exposition of the paper, so we de-emphasized it.
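
To illustrate the map-reduce remark: class-conditional second moments are sums over examples, so each shard of the data can compute partial sums and counts independently, and the partial results simply add. The sketch below simulates this on one machine with synthetic data; the shard layout and all variable names are hypothetical.

```matlab
% Synthetic data: n examples, d features, labels in 1..k, split into shards.
n = 1000; d = 4; k = 3; shards = 5;
x = randn(n, d);
y = randi(k, n, 1);

S = zeros(d, d, k);        % running sum of x'*x for each class
cnt = zeros(k, 1);         % running example count for each class
edges = round(linspace(0, n, shards + 1));

for s = 1:shards
    rows = (edges(s) + 1 : edges(s + 1))';    % row indices held by shard s
    for c = 1:k
        xc = x(rows(y(rows) == c), :);        % "map": this shard's class-c rows
        S(:, :, c) = S(:, :, c) + xc' * xc;   % "reduce": partial sums add up
        cnt(c) = cnt(c) + size(xc, 1);
    end
end

% Normalize each class sum by its count to get the second moment estimates.
M = bsxfun(@rdivide, S, reshape(cnt, 1, 1, k));

% Sanity check against a single pass over all class-1 examples.
xc = x(y == 1, :);
err = norm(M(:, :, 1) - (xc' * xc) / size(xc, 1), 'fro');
fprintf('difference from single-pass estimate: %g\n', err);
```

Only the $d \times d$ sums and the counts need to be communicated, independent of the number of examples in each shard.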

Combining the above ideas, along with Nikos' preconditioned gradient learning for multiclass described in a previous post, leads to the following Matlab script, which gets 91 test errors on (permutation invariant) mnist. Note: you'll need to download mnist_all.mat from Sam Roweis' site to run this.
function calgevsquared

more off;
clear all;
close all;

start=tic;
load('mnist_all.mat');
xxt=[train0; train1; train2; train3; train4; train5; ...
     train6; train7; train8; train9];
xxs=[test0; test1; test2; test3; test4; test5; test6; test7; test8; test9];
kt=single(xxt)/255;
ks=single(xxs)/255;
st=[size(train0,1); size(train1,1); size(train2,1); size(train3,1); ...
    size(train4,1); size(train5,1); size(train6,1); size(train7,1); ...
    size(train8,1); size(train9,1)];
ss=[size(test0,1); size(test1,1); size(test2,1); size(test3,1); ... 
    size(test4,1); size(test5,1); size(test6,1); size(test7,1); ...
    size(test8,1); size(test9,1)];
paren = @(x, varargin) x(varargin{:});
yt=zeros(60000,10);
ys=zeros(10000,10);
I10=eye(10);
lst=1;
for i=1:10; yt(lst:lst+st(i)-1,:)=repmat(I10(i,:),st(i),1); lst=lst+st(i); end
lst=1;
for i=1:10; ys(lst:lst+ss(i)-1,:)=repmat(I10(i,:),ss(i),1); lst=lst+ss(i); end

clear i st ss lst
clear xxt xxs
clear train0 train1 train2 train3 train4 train5 train6 train7 train8 train9
clear test0 test1 test2 test3 test4 test5 test6 test7 test8 test9

[n,k]=size(yt);
[m,d]=size(ks);

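gamma=0.1;  % ridge term added to the denominator second moment before chol
top=20;     % generalized eigenvectors requested per ordered class pair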
for i=1:k
    ind=find(yt(:,i)==1);
    kind=kt(ind,:);
    ni=length(ind);
    covs(:,:,i)=double(kind'*kind)/ni;
    clear ind kind;
end
filters=zeros(d,top*k*(k-1),'single');
last=0;
threshold=0;
for j=1:k
    covj=squeeze(covs(:,:,j)); l=chol(covj+gamma*eye(d))';
    for i=1:k
        if j~=i
            covi=squeeze(covs(:,:,i));
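            % whiten by the Cholesky factor of the regularized class-j moment, so the
            % generalized problem becomes an ordinary symmetric eigenproblem for CS
            C=l\covi/l'; CS=0.5*(C+C'); [v,L]=eigs(CS,top); V=l'\v;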
            take=find(diag(L)>=threshold);
            batch=length(take);
            fprintf('%u,%u,%u ', i, j, batch);
            filters(:,last+1:last+batch)=V(:,take);
            last=last+batch;
        end
    end
    fprintf('\n');
end

clear covi covj covs C CS V v L

% NB: augmenting kt/ks with .^2 terms is very slow and doesn't help

filters=filters(:,1:last);
ft=kt*filters;
clear kt;
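% basis expansion: bias column, then sqrt(1+z)-1 applied to the positive and
% negative parts of each projection
kt=[ones(n,1,'single') sqrt(1+max(ft,0))-1 sqrt(1+max(-ft,0))-1];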
clear ft;
fs=ks*filters;
clear ks filters;
ks=[ones(m,1,'single') sqrt(1+max(fs,0))-1 sqrt(1+max(-fs,0))-1];
clear fs;

[n,k]=size(yt);
[m,d]=size(ks);

for i=1:k
    ind=find(yt(:,i)==1);
    kind=kt(ind,:);
    ni=length(ind);
    covs(:,:,i)=double(kind'*kind)/ni;
    clear ind kind;
end

filters=zeros(d,top*k*(k-1),'single');
last=0;
threshold=7.5;
for j=1:k
    covj=squeeze(covs(:,:,j)); l=chol(covj+gamma*eye(d))';
    for i=1:k
        if j~=i
            covi=squeeze(covs(:,:,i));
            C=l\covi/l'; CS=0.5*(C+C'); [v,L]=eigs(CS,top); V=l'\v;
            take=find(diag(L)>=threshold);
            batch=length(take);
            fprintf('%u,%u,%u ', i, j, batch);
            filters(:,last+1:last+batch)=V(:,take);
            last=last+batch;
        end
    end
    fprintf('\n');
end
fprintf('gamma=%g,top=%u,threshold=%g\n',gamma,top,threshold);
fprintf('last=%u filtered=%u\n', last, size(filters,2) - last);

clear covi covj covs C CS V v L

filters=filters(:,1:last);
ft=kt*filters;
clear kt;
kt=[sqrt(1+max(ft,0))-1 sqrt(1+max(-ft,0))-1];
clear ft;
fs=ks*filters;
clear ks filters;
ks=[sqrt(1+max(fs,0))-1 sqrt(1+max(-fs,0))-1];
clear fs;

trainx=[ones(n,1,'single') kt kt.^2];
clear kt;
testx=[ones(m,1,'single') ks ks.^2];
clear ks;

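% regularized least squares on the one-hot labels:
% w = (0.5*trainx'*trainx + sqrt(n)*I) \ (trainx'*yt), via a Cholesky factor
C=chol(0.5*(trainx'*trainx)+sqrt(n)*eye(size(trainx,2)),'lower');
w=C'\(C\(trainx'*yt));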
pt=trainx*w;
ps=testx*w;

[~,trainy]=max(yt,[],2);
[~,testy]=max(ys,[],2);

for i=1:5
        xn=[pt pt.^2/2 pt.^3/6 pt.^4/24];
        xm=[ps ps.^2/2 ps.^3/6 ps.^4/24];
        c=chol(xn'*xn+sqrt(n)*eye(size(xn,2)),'lower');
        ww=c'\(c\(xn'*yt));
        ppt=SimplexProj(xn*ww);
        pps=SimplexProj(xm*ww);
        w=C'\(C\(trainx'*(yt-ppt)));
        pt=ppt+trainx*w;
        ps=pps+testx*w;

        [~,yhatt]=max(pt,[],2);
        [~,yhats]=max(ps,[],2);
        errort=sum(yhatt~=trainy)/n;
        errors=sum(yhats~=testy)/m;
        fprintf('%u,%g,%g\n',i,errort,errors)
end
fprintf('%4s\t', 'pred');
for actual=1:k
        fprintf('%5u', actual-1);
end
fprintf('%5s\n%4s\n', '!=', 'true');
for actual=1:k
        fprintf('%4u\t', actual-1);
        trueidx=find(testy==actual);
        for predicted=1:k
                % test examples of this class that were predicted as `predicted`
                fprintf('%5u', sum(yhats(trueidx)==predicted));
        end
        % misclassified test examples of this class
        fprintf('%5u\n', sum(yhats(trueidx)~=actual));
end

toc(start)

end

% http://arxiv.org/pdf/1309.1541v1.pdf
function X = SimplexProj(Y)
  [N,D] = size(Y);
  X = sort(Y,2,'descend');
  Xtmp = bsxfun(@times,cumsum(X,2)-1,(1./(1:D)));
  X = max(bsxfun(@minus,Y,Xtmp(sub2ind([N,D],(1:N)',sum(X>Xtmp,2)))),0);
end
When I run this on my desktop machine it yields
>> calgevsquared
2,1,20 3,1,20 4,1,20 5,1,20 6,1,20 7,1,20 8,1,20 9,1,20 10,1,20 
1,2,20 3,2,20 4,2,20 5,2,20 6,2,20 7,2,20 8,2,20 9,2,20 10,2,20 
1,3,20 2,3,20 4,3,20 5,3,20 6,3,20 7,3,20 8,3,20 9,3,20 10,3,20 
1,4,20 2,4,20 3,4,20 5,4,20 6,4,20 7,4,20 8,4,20 9,4,20 10,4,20 
1,5,20 2,5,20 3,5,20 4,5,20 6,5,20 7,5,20 8,5,20 9,5,20 10,5,20 
1,6,20 2,6,20 3,6,20 4,6,20 5,6,20 7,6,20 8,6,20 9,6,20 10,6,20 
1,7,20 2,7,20 3,7,20 4,7,20 5,7,20 6,7,20 8,7,20 9,7,20 10,7,20 
1,8,20 2,8,20 3,8,20 4,8,20 5,8,20 6,8,20 7,8,20 9,8,20 10,8,20 
1,9,20 2,9,20 3,9,20 4,9,20 5,9,20 6,9,20 7,9,20 8,9,20 10,9,20 
1,10,20 2,10,20 3,10,20 4,10,20 5,10,20 6,10,20 7,10,20 8,10,20 9,10,20 
2,1,15 3,1,20 4,1,20 5,1,20 6,1,20 7,1,20 8,1,20 9,1,20 10,1,20 
1,2,20 3,2,20 4,2,20 5,2,20 6,2,20 7,2,20 8,2,20 9,2,20 10,2,20 
1,3,20 2,3,11 4,3,17 5,3,20 6,3,20 7,3,19 8,3,18 9,3,18 10,3,19 
1,4,20 2,4,12 3,4,20 5,4,20 6,4,12 7,4,20 8,4,19 9,4,15 10,4,20 
1,5,20 2,5,12 3,5,20 4,5,20 6,5,20 7,5,20 8,5,16 9,5,20 10,5,9 
1,6,18 2,6,13 3,6,20 4,6,12 5,6,20 7,6,18 8,6,20 9,6,13 10,6,18 
1,7,20 2,7,14 3,7,20 4,7,20 5,7,20 6,7,20 8,7,20 9,7,20 10,7,20 
1,8,20 2,8,14 3,8,20 4,8,20 5,8,20 6,8,20 7,8,20 9,8,20 10,8,12 
1,9,20 2,9,9 3,9,20 4,9,15 5,9,18 6,9,11 7,9,20 8,9,17 10,9,16 
1,10,20 2,10,14 3,10,20 4,10,20 5,10,14 6,10,20 7,10,20 8,10,12 9,10,20 
gamma=0.1,top=20,threshold=7.5
last=1630 filtered=170
1,0.0035,0.0097
2,0.00263333,0.0096
3,0.00191667,0.0092
4,0.00156667,0.0093
5,0.00141667,0.0091
pred        0    1    2    3    4    5    6    7    8    9   !=
true
   0      977    0    1    0    0    1    0    1    0    0    3
   1        0 1129    2    1    0    0    1    1    1    0    6
   2        1    1 1020    0    1    0    0    6    3    0   12
   3        0    0    1 1004    0    1    0    2    1    1    6
   4        0    0    0    0  972    0    4    0    2    4   10
   5        1    0    0    5    0  883    2    1    0    0    9
   6        4    2    0    0    2    2  947    0    1    0   11
   7        0    2    5    0    0    0    0 1018    1    2   10
   8        1    0    1    1    1    1    0    1  966    2    8
   9        1    1    0    2    5    2    0    4    1  993   16
Elapsed time is 186.147659 seconds.
That's a pretty good confusion matrix, comparable to state-of-the-art deep learning results on (permutation invariant) mnist. In the paper we report a slightly worse number (96 test errors) because for a paper we have to choose hyperparameters via cross-validation on the training set rather than cherry-pick them as for a blog post.
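
As a quick consistency check, the 91 test errors can be recovered from the confusion matrix printed above: the off-diagonal entries count the mistakes. The matrix below is copied from that output; the variable names are only illustrative.

```matlab
% Rows are true digits 0..9, columns are predicted digits 0..9.
conf = [ 977    0    1    0    0    1    0    1    0    0;
           0 1129    2    1    0    0    1    1    1    0;
           1    1 1020    0    1    0    0    6    3    0;
           0    0    1 1004    0    1    0    2    1    1;
           0    0    0    0  972    0    4    0    2    4;
           1    0    0    5    0  883    2    1    0    0;
           4    2    0    0    2    2  947    0    1    0;
           0    2    5    0    0    0    0 1018    1    2;
           1    0    1    1    1    1    0    1  966    2;
           1    1    0    2    5    2    0    4    1  993];

total = sum(conf(:));                    % number of test examples
wrong = total - trace(conf);             % off-diagonal mass = test errors
perclass = sum(conf, 2) - diag(conf);    % errors per true digit ('!=' column)

fprintf('%u errors out of %u (rate %.4f)\n', wrong, total, wrong / total);
% prints: 91 errors out of 10000 (rate 0.0091)
```

The rate agrees with the test error of 0.0091 reported for the final iteration.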

The technique as stated here is really only useful for tall-thin design matrices (i.e., lots of examples but not too many features): if the original feature dimensionality is too large (e.g., $> 10^4$) then naive use of standard generalized eigensolvers becomes slow or infeasible, and other tricks are required. Furthermore, if the number of classes is too large then solving $O (k^2)$ generalized eigenvalue problems is also not reasonable. We're working on remedying these issues, and we're also excited about extending this strategy to structured prediction. Hopefully we'll have more to say about it at the next few conferences.
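
As a back-of-the-envelope illustration of both limits (using standard costs for dense linear algebra, not figures from the paper): the script above loops over ordered class pairs, and a dense $d \times d$ second moment matrix in double precision needs $8 d^2$ bytes, so \[
k (k - 1) = 10 \cdot 9 = 90 \text{ problems for } k = 10, \qquad 8 d^2 = 8 \times 10^8 \text{ bytes} \approx 0.8 \text{ GB per matrix for } d = 10^4,
\] and the Cholesky factorization and triangular solves in each problem cost $O (d^3)$ time. With 20 directions requested per pair, that is at most $90 \cdot 20 = 1800$ candidate filters per stage, which matches the 1630 kept plus 170 filtered reported in the output above.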