This month, I’ve been working on a neural network to describe in a sentence what’s happening in a picture, otherwise known as image captioning. My model roughly follows the architecture outlined in the paper “Show and Tell: A Neural Image Caption Generator” by Vinyals et al., 2014.
A high-level overview: the neural network first uses a convolutional neural network to turn the picture into an abstract representation. Then, it uses this representation as the initial hidden state of a recurrent neural network (here, an LSTM), which generates a natural language sentence. This type of neural network is called an encoder-decoder network and is commonly used for NLP tasks like machine translation.
Above: Encoder-decoder image captioning neural network (Figure 1 of paper)
When I first encountered LSTMs, I was really confused about how they worked, and how to train them. If your output is a sequence of words, what is your loss function and how do you backpropagate it? In fact, the training and inference passes of an LSTM are quite different. In this blog post, I’ll try to explain this difference.
Above: Training procedure for caption LSTM, given known image and caption
During training mode, we train the neural network to minimize the perplexity of the image-caption pair. Perplexity measures how surprised the neural network is by the given caption when it sees the given image: the more likely the network is to generate that caption, the lower the perplexity. If we’re training it to output the caption “a cute cat”, the likelihood of the caption is:
P(“a” | image) *
P(“cute” | image, “a”) *
P(“cat” | image, “a”, “cute”) * …
(Note: for numerical stability reasons, we typically work with sums of negative log likelihoods rather than products of probabilities, so the loss is actually the negative log of that whole thing. Perplexity is the exponential of this loss averaged per word, so for a given caption, minimizing one minimizes the other.)
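To make the arithmetic concrete, here is a minimal sketch in plain Python. The three probabilities are made-up numbers for illustration, not outputs of my model, and the sketch uses the standard definitions: the negative log likelihood is the sum of per-word terms, and perplexity is the exponential of its per-word average.

```python
import math

# Hypothetical probabilities the network assigns to each word of the
# ground-truth caption "a cute cat" (made-up numbers, for illustration only):
# P("a" | image), P("cute" | image, "a"), P("cat" | image, "a", "cute")
word_probs = [0.5, 0.1, 0.4]

# Likelihood of the whole caption: the product of the per-word probabilities.
likelihood = math.prod(word_probs)

# Negative log likelihood: a sum of per-word terms, which is numerically
# safer than multiplying many small probabilities together.
nll = -sum(math.log(p) for p in word_probs)

# Perplexity: exponential of the average per-word negative log likelihood.
perplexity = math.exp(nll / len(word_probs))

print(f"likelihood = {likelihood:.4f}")
print(f"nll        = {nll:.4f}")
print(f"perplexity = {perplexity:.4f}")
```

A caption the network finds more likely gives a larger product, a smaller negative log likelihood, and a smaller perplexity, so all three tell the same story.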
After passing the whole sequence through the LSTM one word at a time, we get a single loss value, which we can minimize using backpropagation and gradient descent. As the perplexity gets lower and lower, the LSTM becomes more likely to produce captions similar to the ground truth when it sees a similar image. This is how the network learns to caption images.
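One detail worth spelling out: each probability above is conditioned on the ground-truth words so far, so during training the LSTM is fed the correct previous word at every step rather than its own guess (this is often called teacher forcing). The inputs and targets are then the same caption shifted by one position. A small sketch of that shift, where the `<start>` and `<end>` marker names and the `make_training_pairs` helper are placeholders chosen for illustration, and the image representation that conditions every step is left out:

```python
def make_training_pairs(caption_words):
    """Return (input word, target word) pairs for one caption, shifted by one."""
    # Hypothetical marker tokens; any reserved vocabulary entries would do.
    tokens = ["<start>"] + caption_words + ["<end>"]
    inputs = tokens[:-1]   # what the LSTM is fed at each step
    targets = tokens[1:]   # what it is asked to predict at that step
    return list(zip(inputs, targets))


pairs = make_training_pairs(["a", "cute", "cat"])
for given, expected in pairs:
    print(f"input: {given:8s} -> target: {expected}")
```

Each pair contributes one term to the loss: the negative log of the probability the network assigns to the target word at that step.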
Above: Inference procedure for caption LSTM, given only the image but no caption
During inference mode, we repeatedly query the neural network, one word at a time, to produce a sentence. On each step, the LSTM outputs a probability distribution for the next word, over the entire vocabulary. We pick the highest-probability word (a greedy choice), add it to the caption, and feed it back into the LSTM. This is repeated until the LSTM generates the end marker. Hopefully, if we trained it properly, the resulting sentence will actually describe what’s happening in the picture.
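The loop described above can be sketched in a few lines of plain Python. Here `toy_model` is a hand-written lookup table standing in for the trained LSTM, and the `<end>` token name and the `max_len` safety cap are assumptions added for the example:

```python
def greedy_decode(next_word_probs, end_token="<end>", max_len=20):
    """Build a caption by repeatedly taking the most probable next word."""
    caption = []
    while len(caption) < max_len:           # cap length in case <end> never wins
        probs = next_word_probs(caption)    # dict: word -> probability
        word = max(probs, key=probs.get)    # highest-probability word
        if word == end_token:
            break
        caption.append(word)                # fed back in on the next step
    return caption


def toy_model(prefix):
    """Toy stand-in for the LSTM: next-word distribution given the words so far."""
    table = {
        (): {"a": 0.7, "the": 0.2, "<end>": 0.1},
        ("a",): {"cute": 0.6, "cat": 0.3, "<end>": 0.1},
        ("a", "cute"): {"cat": 0.8, "dog": 0.1, "<end>": 0.1},
    }
    # Any prefix not in the table ends the caption.
    return table.get(tuple(prefix), {"<end>": 1.0})


caption = greedy_decode(toy_model)
print(" ".join(caption))
```

Unlike training, there is no ground-truth caption to feed in, so each step depends on the words the network itself chose earlier; one bad pick early on can derail the rest of the sentence.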
This is the main idea of the paper, and I omitted a lot of details. I encourage you to read the paper for the finer points.
I implemented the model using PyTorch and trained it using the MS COCO dataset, which contains about 80,000 images of common objects and situations, and each image is human annotated with 5 captions.
To speed up training, I used a pretrained VGG16 convnet, and pretrained GloVe word embeddings from spaCy. Using lots of batching, the Adam optimizer, and a Titan X GPU, the neural network trains in about 4 hours. It’s one thing to understand how it works on paper, but watching it actually spit out captions for real images felt like magic.
Above: How I felt when I got this working
How are the results? For some of the images, the neural network does great:
“A train is on the tracks at a station”
“A woman is holding a cat in her arms”
Other times the neural network gets confused, with amusing results:
“A little girl holding a stuffed animal in her hand”
“A baby laying on a bed with a stuffed animal”
“A dog is running with a frisbee in its mouth”
I’d say we needn’t worry about the AI singularity anytime soon 🙂
The original paper has some more examples of correct and incorrect captions that might be generated. Newer models have also made improvements to generate more accurate captions: for example, adding a visual attention mechanism improved the results a bit. However, the state-of-the-art models still fall short of human performance; they often make mistakes when describing pictures with objects in unusual configurations.
This is a work in progress; source code is on GitHub here.