## Tuesday, March 6, 2018

### pytorch PU learning trick

I use positive-unlabeled (PU) learning a lot nowadays. In particular, for observational dialog modeling, next utterance classification is a standard technique for training and evaluating models. In this setup the observed continuation of the conversation is treated as a positive (since a human said it, it is presumed to be a reasonable thing to say at that point in the conversation), and other randomly chosen utterances are treated as unlabeled (they might also be reasonable things to say at that point).
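
To make the setup concrete, here is a minimal sketch of how next-utterance-classification examples could be assembled from a toy corpus of (context, response) pairs. The helper `make_nuc_examples`, the number `k` of random candidates, and the toy data are hypothetical illustrations, not from any particular library. The observed response gets label 1 and each sampled response is marked unlabeled (`None`) rather than negative, because it might also be a reasonable reply.

```python
import random

def make_nuc_examples(pairs, k, seed=0):
    """Hypothetical helper: build next-utterance-classification examples.

    pairs: list of (context, response) tuples from observed conversations.
    Returns (context, candidate, label) triples: label 1 for the observed
    continuation, None (unlabeled) for randomly chosen other utterances.
    """
    rng = random.Random(seed)
    responses = [r for _, r in pairs]
    examples = []
    for i, (context, response) in enumerate(pairs):
        # the observed continuation is the positive
        examples.append((context, response, 1))
        # other randomly chosen utterances are unlabeled, not known negatives
        others = [j for j in range(len(pairs)) if j != i]
        for j in rng.sample(others, k):
            examples.append((context, responses[j], None))
    return examples

pairs = [("hi", "hello"), ("how are you?", "fine, thanks"),
         ("bye", "see you"), ("what time is it?", "noon")]
examples = make_nuc_examples(pairs, k=2)
```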

Suppose you have a model whose final layer is a dot product between a vector produced only from the context and a vector produced only from the response. I use models of this form as “level 1” models because they facilitate precomputing a fast serving index; note, however, that the following trick does not apply to architectures such as bidirectional attention, where the two vectors depend on each other. For these dot-product models you can train more efficiently by drawing the negatives from the same mini-batch: each context is scored against every response in the batch with a single matrix multiply, and the responses belonging to the other contexts serve as the unlabeled examples. This is a well-known trick, but I couldn't find anybody explaining explicitly how to do it in PyTorch.
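
The reason a single matrix multiply suffices can be checked directly. In this minimal sketch, random vectors stand in for the outputs of the two towers (`batch_size` and `dim` are arbitrary choices). The diagonal of the score matrix holds the observed (positive) pairs and matches the row-wise dot products a pairwise model would compute; the off-diagonal entries of row `i` are the other responses in the batch, reused as unlabeled candidates for context `i` without any extra forward passes.

```python
import torch

torch.manual_seed(0)
batch_size, dim = 4, 8

# stand-ins for leftforward(contexts) and rightforward(responses)
leftvec = torch.randn(batch_size, dim)
rightvec = torch.randn(batch_size, dim)

# scores[i, j] = <leftvec[i], rightvec[j]>: every pair from one matmul
scores = leftvec @ rightvec.t()

# the diagonal holds the observed (context, response) pairs
positive_scores = scores.diagonal()
rowwise = (leftvec * rightvec).sum(dim=1)

# off-diagonal entries of row i: other responses in the batch,
# treated as unlabeled candidates for context i
offdiag_mask = ~torch.eye(batch_size, dtype=torch.bool)
unlabeled_scores = scores[offdiag_mask].view(batch_size, batch_size - 1)
```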

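Structure your model to have separate `leftforward` and `rightforward` methods, like this:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    ...

    def forward(self, leftinput, rightinput):
        leftvec = self.leftforward(leftinput)
        rightvec = self.rightforward(rightinput)
        # score each (left, right) row pair by its dot product
        return (leftvec * rightvec).sum(dim=1, keepdim=True)
```
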
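For concreteness, here is one hypothetical way to fill in the `...`: each tower is a mean-pooled `nn.EmbeddingBag` followed by a linear projection. The layer choices, vocabulary size, and dimensions are assumptions for illustration; any encoders work as long as the context vector depends only on the context and the response vector only on the response.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, vocab_size=1000, emb_dim=32, out_dim=16):
        super().__init__()
        # separate encoders: the context never sees the response and vice versa
        self.left_emb = nn.EmbeddingBag(vocab_size, emb_dim, mode="mean")
        self.right_emb = nn.EmbeddingBag(vocab_size, emb_dim, mode="mean")
        self.left_proj = nn.Linear(emb_dim, out_dim)
        self.right_proj = nn.Linear(emb_dim, out_dim)

    def leftforward(self, leftinput):
        # leftinput: LongTensor (batch, seq_len) of context token ids
        return self.left_proj(self.left_emb(leftinput))

    def rightforward(self, rightinput):
        # rightinput: LongTensor (batch, seq_len) of response token ids
        return self.right_proj(self.right_emb(rightinput))

    def forward(self, leftinput, rightinput):
        leftvec = self.leftforward(leftinput)
        rightvec = self.rightforward(rightinput)
        # one score per (context, response) row
        return (leftvec * rightvec).sum(dim=1, keepdim=True)

torch.manual_seed(0)
model = MyModel()
contexts = torch.randint(0, 1000, (4, 10))
responses = torch.randint(0, 1000, (4, 7))
pair_scores = model(contexts, responses)  # shape (4, 1)
```

Because the response tower never sees the context, `rightforward` can be run once over all candidate responses ahead of time, which is what makes the precomputed serving index possible.
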
At training time, compute `leftforward` and `rightforward` separately for your mini-batch:

```python
...
criterion = BatchPULoss()
model = MyModel()
...

leftvec = model.leftforward(batch.leftinput)
rightvec = model.rightforward(batch.rightinput)

(loss, preds) = criterion.fortraining(leftvec, rightvec)
loss.backward()
# "preds" contains the index of the highest-scoring right for each left,
# so, for instance, you can calculate "mini-batch precision at 1"
gold_labels = torch.arange(batch.batch_size, device=preds.device)
n_correct += (preds == gold_labels).sum().item()
...
```
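
Putting the pieces together, here is a minimal runnable sketch of the training loop on random data. It assumes the loss is cross-entropy over the in-batch score matrix with the diagonal as the target class, which is consistent with the `CrossEntropyLoss` and `torch.mm(left, right.t())` in the loss below. The `in_batch_loss` helper, the linear towers, the fake data, and the optimizer settings are hypothetical stand-ins, not the author's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, in_dim, dim = 8, 20, 16

# hypothetical towers: context and response are encoded independently
left_tower = nn.Linear(in_dim, dim)
right_tower = nn.Linear(in_dim, dim)
params = list(left_tower.parameters()) + list(right_tower.parameters())
optimizer = torch.optim.Adam(params, lr=0.01)

def in_batch_loss(leftvec, rightvec):
    # scores[i, j] = context i vs response j; the observed response is j == i
    scores = leftvec @ rightvec.t()
    labels = torch.arange(scores.shape[0], device=scores.device)
    loss = F.cross_entropy(scores, labels)
    preds = scores.argmax(dim=1)  # highest-scoring right for each left
    return loss, preds

# fake data: each response is a noisy copy of its context
contexts = torch.randn(batch_size, in_dim)
responses = contexts + 0.1 * torch.randn(batch_size, in_dim)
gold_labels = torch.arange(batch_size)

losses, n_correct, n_seen = [], 0, 0
for _ in range(100):
    optimizer.zero_grad()
    loss, preds = in_batch_loss(left_tower(contexts), right_tower(responses))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    # running mini-batch precision at 1
    n_correct += (preds == gold_labels).sum().item()
    n_seen += batch_size

precision_at_1 = n_correct / n_seen
```

Only `batch_size` context vectors and `batch_size` response vectors are computed per step, yet each context is compared against `batch_size - 1` other responses.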

Finally, use this loss:

    import torch

    class BatchPULoss():
        def __init__(self):
            self.loss = torch.nn.CrossEntropyLoss()

        def fortraining(self, left, right):
            outer = torch.mm(left, right.t())