PyTorch Advanced (4): Image Captioning (CNN-RNN)

Reference code

yunjey's pytorch-tutorial series

My remote server doesn't have much of a GUI to look at, so I ported the code to a Jupyter notebook to see the results.

Image Captioning: Learning Resources

Related papers

Show and Tell: A Neural Image Caption Generator

Related blog posts

看图说话的AI小朋友——图像标注趣谈(上)

「Show and Tell」——图像标注(Image Caption)任务技术综述

CVPR2017 Image Caption有关论文总结

Image Captioning: Overview

Image captioning is seriously impressive stuff; even if I can't fully figure it out today, I'll settle for a rough understanding. I expect I'll come back to it later.

Image captioning converts an input image into a natural-language description. The encoder-decoder architecture is widely used for this task.

The image encoder is a convolutional neural network; this code uses a ResNet-152 model pretrained on the ILSVRC-2012-CLS image classification dataset.

The decoder is a long short-term memory (LSTM) network.

Image source: Image Captioning

The Image Captioning Pipeline

The image-captioning pipeline can be seen as an extension of the RNN setup used for machine translation, where input and output are both word sequences connected through an encoder-decoder structure and the encoder produces a sequence of features. For image captioning, the encoder part is simply replaced by an image input plus a CNN that extracts (visual) features; the resulting features are then fed to the decoder, exactly as before.

Of course, in the architecture shown above, the CNN is ResNet-152, and the decoder is the better-performing LSTM.
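As a shape-level sketch of the hand-off between the two halves (a minimal example; embed_size=256 matches the training configuration further below, and the tensors are dummies):

import torch
import torch.nn as nn
import torchvision.models as models

embed_size = 256
resnet = models.resnet152(pretrained=True)
backbone = nn.Sequential(*list(resnet.children())[:-1])  # ResNet-152 minus its final fc layer
project = nn.Linear(resnet.fc.in_features, embed_size)   # map the 2048-d feature to embed_size

images = torch.randn(1, 3, 224, 224)    # dummy image batch
features = backbone(images).flatten(1)  # (1, 2048) visual features
features = project(features)            # (1, 256): the same size as a word embedding,
                                        # so the decoder can treat it as the "first word"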

Training and Testing: Process Overview

Training phase

For the encoder, the pretrained CNN extracts a feature vector from the given input image. The feature vector is then linearly transformed so that it has the same dimension as the input of the LSTM network.

For the decoder, the source and target texts are predefined. For example, if the image is described by "Giraffes standing next to each other", the source sequence is ['<start>', 'Giraffes', 'standing', 'next', 'to', 'each', 'other'] and the target sequence is ['Giraffes', 'standing', 'next', 'to', 'each', 'other', '<end>']. Using these source sequences, target sequences, and feature vectors, the LSTM decoder can be trained as a language model conditioned on the feature vectors.
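Concretely, this is teacher forcing: at each step the decoder sees the previous ground-truth token and must predict the next one. A tiny illustration of the source/target pairing (plain strings here; in the real pipeline each token is mapped to an id by the vocabulary built below):

caption = ['<start>', 'Giraffes', 'standing', 'next', 'to', 'each', 'other', '<end>']
source = caption[:-1]   # what the LSTM sees at each step
target = caption[1:]    # what the LSTM must predict at each step
for src, tgt in zip(source, target):
    print('{:>10} -> {}'.format(src, tgt))
# <start> -> Giraffes, Giraffes -> standing, ..., other -> <end>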

Testing phase

In the test phase, the encoder is almost identical to the training phase. The only difference is that the batch-norm layer uses its moving mean and variance instead of the mini-batch statistics, which is achieved simply by calling encoder.eval().

For the decoder, the key difference from training is that at test time the LSTM decoder cannot see the image's description. To handle this, the LSTM decoder feeds the previously generated word back in as the next input. This can be implemented with a for loop.
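A minimal, self-contained sketch of that feedback loop with toy layer sizes (this is essentially what the DecoderRNN.sample() method defined later does):

import torch
import torch.nn as nn

vocab_size, embed_size, hidden_size = 10, 8, 16   # toy sizes for illustration
embed = nn.Embedding(vocab_size, embed_size)
lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
linear = nn.Linear(hidden_size, vocab_size)

features = torch.randn(1, embed_size)             # pretend image feature from the encoder
inputs, states, sampled_ids = features.unsqueeze(1), None, []
for _ in range(20):                               # max_seq_length
    hiddens, states = lstm(inputs, states)        # one LSTM step
    outputs = linear(hiddens.squeeze(1))          # (batch, vocab_size) logits
    _, predicted = outputs.max(1)                 # greedy: pick the most likely word id
    sampled_ids.append(predicted)
    inputs = embed(predicted).unsqueeze(1)        # feed the word back in as the next input
print(torch.stack(sampled_ids, 1))                # (1, 20) generated word ids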

Installing the COCO Python API and Downloading the Dataset

Install the COCO API
For details, see COCO.

$ git clone https://github.com/pdollar/coco.git
$ cd coco/PythonAPI/
$ make
$ python setup.py build
$ python setup.py install

Download the dataset
Fetch the download script and run it (the data is a dozen-plus GB, so this is slow).

$ wget https://raw.githubusercontent.com/yunjey/pytorch-tutorial/master/tutorials/03-advanced/image_captioning/download.sh
$ chmod +x download.sh
$ ./download.sh

Preprocessing

Build the vocabulary and resize the images so that they work with the dataset loader (DataLoader).

1. Build the Vocabulary

refer: build_vocab.py

# Imports
import nltk
import pickle
import argparse
from collections import Counter
from pycocotools.coco import COCO
# A simple vocabulary wrapper
class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)
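A quick sanity check of how the wrapper behaves (the words here are just examples):

vocab = Vocabulary()
for w in ['<pad>', '<start>', '<end>', '<unk>', 'giraffe']:
    vocab.add_word(w)
print(len(vocab))         # 5
print(vocab('giraffe'))   # 4
print(vocab('zebra'))     # 3 -- an unseen word falls back to the '<unk>' id
print(vocab.idx2word[4])  # 'giraffe'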
# Build the vocabulary from the COCO caption annotations
def build_vocab(json, threshold):
    """Build a simple vocabulary wrapper."""
    coco = COCO(json)
    counter = Counter()
    ids = coco.anns.keys()
    for i, id in enumerate(ids):
        caption = str(coco.anns[id]['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)

        if (i+1) % 1000 == 0:
            print("[{}/{}] Tokenized the captions.".format(i+1, len(ids)))

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab
# Entry point: build the vocabulary and pickle it
def build_vocab_main(args):
    vocab = build_vocab(json=args.caption_path, threshold=args.threshold)
    vocab_path = args.vocab_path
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)
    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved the vocabulary wrapper to '{}'".format(vocab_path))
# Pass the arguments via argparse and run the vocabulary build
# Be sure to point these paths at your copy of the dataset
parser = argparse.ArgumentParser()
parser.add_argument('--caption_path', type=str,
                    default='/home/ubuntu/Datasets/coco/annotations/captions_train2014.json',
                    help='path for train annotation file')
parser.add_argument('--vocab_path', type=str, default='/home/ubuntu/Datasets/coco/vocab.pkl',
                    help='path for saving vocabulary wrapper')
parser.add_argument('--threshold', type=int, default=4,
                    help='minimum word count threshold')
config = parser.parse_args(args=[])
build_vocab_main(config)
loading annotations into memory...
Done (t=0.89s)
creating index...
index created!
[1000/414113] Tokenized the captions.
[2000/414113] Tokenized the captions.
[3000/414113] Tokenized the captions.
[4000/414113] Tokenized the captions.
[5000/414113] Tokenized the captions.

..........................

[409000/414113] Tokenized the captions.
[410000/414113] Tokenized the captions.
[411000/414113] Tokenized the captions.
[412000/414113] Tokenized the captions.
[413000/414113] Tokenized the captions.
[414000/414113] Tokenized the captions.
Total vocabulary size: 9957
Saved the vocabulary wrapper to '/home/ubuntu/Datasets/coco/vocab.pkl'

2. Resize the Images

refer: resize.py

# Imports
import argparse
import os
from PIL import Image
# Resize a single image
def resize_image(image, size):
    """Resize an image to the given size."""
    # Image.ANTIALIAS in the original code is an alias of Image.LANCZOS
    # and was removed in Pillow 10, so use LANCZOS directly.
    return image.resize(size, Image.LANCZOS)
# Resize all images in a directory
def resize_images(image_dir, output_dir, size):
    """Resize the images in 'image_dir' and save into 'output_dir'."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        with open(os.path.join(image_dir, image), 'r+b') as f:
            with Image.open(f) as img:
                img = resize_image(img, size)
                img.save(os.path.join(output_dir, image), img.format)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))
# Entry point
def resize_main(args):
    image_dir = args.image_dir
    output_dir = args.output_dir
    image_size = [args.image_size, args.image_size]
    resize_images(image_dir, output_dir, image_size)
# Pass the arguments via argparse
# Be sure to point these paths at your copy of the dataset
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='/home/ubuntu/Datasets/coco/train2014/',
                    help='directory for train images')
parser.add_argument('--output_dir', type=str, default='/home/ubuntu/Datasets/coco/resized2014/',
                    help='directory for saving resized images')
parser.add_argument('--image_size', type=int, default=256,
                    help='size for image after processing')
config = parser.parse_args(args=[])
resize_main(config)
[100/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[200/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[300/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[400/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[500/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[600/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[700/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[800/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.

............................

[82400/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[82500/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[82600/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.
[82700/82783] Resized the images and saved into '/home/ubuntu/Datasets/coco/resized2014/'.

3. Dataset Loading Functions

refer: data_loader.py

# Imports
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
import nltk
from PIL import Image
# from build_vocab import Vocabulary  # already defined above in this notebook
from pycocotools.coco import COCO
# A CocoDataset class compatible with PyTorch's data loading utilities
class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, root, json, vocab, transform=None):
        """Set the path for images, captions and vocabulary wrapper.

        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']

        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        return len(self.ids)
# Create mini-batches from a list of (image, caption) tuples.
# A custom collate_fn is needed because the default one cannot
# merge variable-length captions (with padding).
def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).

    We should build custom collate_fn rather than using default collate_fn,
    because merging caption (including padding) is not supported in default.
    Args:
        data: list of tuple (image, caption).
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.
    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths
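A small dry run shows what collate_fn produces; the two dummy captions below have lengths 4 and 6, so the batch is sorted by length and zero-padded to 6 (the ids are arbitrary):

batch = [(torch.randn(3, 256, 256), torch.Tensor([1, 4, 5, 2])),
         (torch.randn(3, 256, 256), torch.Tensor([1, 4, 5, 6, 7, 2]))]
images, targets, lengths = collate_fn(batch)
print(images.shape)   # torch.Size([2, 3, 256, 256])
print(targets.shape)  # torch.Size([2, 6])
print(lengths)        # [6, 4] -- descending, as pack_padded_sequence requires
print(targets[1])     # tensor([1, 4, 5, 2, 0, 0]) -- zero-padded with '<pad>' (id 0)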
# Get the data loader
def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    coco = CocoDataset(root=root,
                       json=json,
                       vocab=vocab,
                       transform=transform)

    # Data loader for COCO dataset
    # This will return (images, captions, lengths) for each iteration.
    # images: a tensor of shape (batch_size, 3, 224, 224).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating valid length for each caption. length is (batch_size).
    data_loader = torch.utils.data.DataLoader(dataset=coco,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader

Building the Models (Encoder and Decoder)

# Imports
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence
# Encoder: a ResNet-152 model
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer."""
        # Replace the final fully connected layer
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]  # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Extract feature vectors from input images."""
        # The encoder only extracts features, so no gradients
        # are needed for the ResNet backbone.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features
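A quick shape check (this downloads the pretrained ResNet-152 weights on first use; embed_size=256 matches the training config below):

encoder = EncoderCNN(embed_size=256)
print(encoder(torch.randn(4, 3, 224, 224)).shape)  # torch.Size([4, 256])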
# Decoder: an LSTM model
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        # The hyper-parameters are set via the constructor arguments
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generates captions."""
        # The forward pass decodes the image feature vectors into captions
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)
        outputs = self.linear(hiddens[0])
        return outputs

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for i in range(self.max_seq_length):
            hiddens, states = self.lstm(inputs, states)  # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))    # outputs: (batch_size, vocab_size)
            _, predicted = outputs.max(1)                # predicted: (batch_size)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted)               # inputs: (batch_size, embed_size)
            inputs = inputs.unsqueeze(1)                 # inputs: (batch_size, 1, embed_size)
        sampled_ids = torch.stack(sampled_ids, 1)        # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids
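And the matching check for the decoder: because the forward pass packs the padded batch, it returns one row of logits per valid timestep, i.e. sum(lengths) rows in total (toy vocab size, arbitrary ids):

decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=10, num_layers=1)
features = torch.randn(2, 256)                         # image features from the encoder
captions = torch.tensor([[1, 4, 5, 2], [1, 4, 2, 0]])  # padded word ids: (batch, max_len)
lengths = [5, 4]                                       # valid steps per row (feature + tokens)
print(decoder(features, captions, lengths).shape)      # torch.Size([9, 10]) == (sum(lengths), vocab_size)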

Training the Model

refer: train.py

# Imports
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
# from data_loader import get_loader
# from build_vocab import Vocabulary
# from model import EncoderCNN, DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
# Device configuration
torch.cuda.set_device(1)  # select which GPU PyTorch runs on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The training procedure
def train_main(args):
    # Create a directory for saving the trained models
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing: augmentation and normalization
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load the vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build the data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    # Instantiate the encoder and decoder
    encoder = EncoderCNN(args.embed_size).to(device)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)

    # Set up the loss function and optimizer (only the decoder and the
    # encoder's linear/bn layers are trained; the ResNet backbone stays frozen)
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):

            # Prepare the mini-batch
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            # Forward pass -> backward pass -> optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)

            decoder.zero_grad()  # don't forget to zero the gradients
            encoder.zero_grad()  # don't forget to zero the gradients
            loss.backward()

            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))

            # Save the model checkpoints periodically
            if (i+1) % args.save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
                torch.save(encoder.state_dict(), os.path.join(
                    args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
# Set the arguments
parser = argparse.ArgumentParser()

# Paths
parser.add_argument('--model_path', type=str, default='models/', help='path for saving trained models')
parser.add_argument('--crop_size', type=int, default=224, help='size for randomly cropping images')
parser.add_argument('--vocab_path', type=str, default='/home/ubuntu/Datasets/coco/vocab.pkl', help='path for vocabulary wrapper')
parser.add_argument('--image_dir', type=str, default='/home/ubuntu/Datasets/coco/resized2014', help='directory for resized images')
parser.add_argument('--caption_path', type=str, default='/home/ubuntu/Datasets/coco/annotations/captions_train2014.json', help='path for train annotation json file')

# Logging and checkpointing intervals
parser.add_argument('--log_step', type=int, default=10, help='step size for printing log info')
parser.add_argument('--save_step', type=int, default=1000, help='step size for saving trained models')

# Model parameters
parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')

# Hyper-parameters
parser.add_argument('--num_epochs', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--learning_rate', type=float, default=0.001)

config = parser.parse_args(args=[])

print(config)

train_main(config)
Namespace(batch_size=128, caption_path='/home/ubuntu/Datasets/coco/annotations/captions_train2014.json', crop_size=224, embed_size=256, hidden_size=512, image_dir='/home/ubuntu/Datasets/coco/resized2014', learning_rate=0.001, log_step=10, model_path='models/', num_epochs=5, num_layers=1, num_workers=2, save_step=1000, vocab_path='/home/ubuntu/Datasets/coco/vocab.pkl')
loading annotations into memory...
Done (t=0.75s)
creating index...
index created!
Epoch [0/5], Step [0/3236], Loss: 9.2119, Perplexity: 10015.3099
Epoch [0/5], Step [10/3236], Loss: 5.8402, Perplexity: 343.8538
Epoch [0/5], Step [20/3236], Loss: 5.4097, Perplexity: 223.5633
Epoch [0/5], Step [30/3236], Loss: 4.9667, Perplexity: 143.5454
Epoch [0/5], Step [40/3236], Loss: 4.7254, Perplexity: 112.7781
Epoch [0/5], Step [50/3236], Loss: 4.4457, Perplexity: 85.2637
Epoch [0/5], Step [60/3236], Loss: 4.3398, Perplexity: 76.6949

.........................

Epoch [4/5], Step [3140/3236], Loss: 2.0148, Perplexity: 7.4993
Epoch [4/5], Step [3150/3236], Loss: 1.9162, Perplexity: 6.7949
Epoch [4/5], Step [3160/3236], Loss: 1.8994, Perplexity: 6.6816
Epoch [4/5], Step [3170/3236], Loss: 1.7569, Perplexity: 5.7942
Epoch [4/5], Step [3180/3236], Loss: 1.8736, Perplexity: 6.5118
Epoch [4/5], Step [3190/3236], Loss: 1.9967, Perplexity: 7.3650
Epoch [4/5], Step [3200/3236], Loss: 1.8380, Perplexity: 6.2840
Epoch [4/5], Step [3210/3236], Loss: 1.9305, Perplexity: 6.8927
Epoch [4/5], Step [3220/3236], Loss: 1.9491, Perplexity: 7.0224
Epoch [4/5], Step [3230/3236], Loss: 1.8040, Perplexity: 6.0742

Testing the Model

refer: sample.py

# Imports
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import os
from torchvision import transforms
# from build_vocab import Vocabulary
# from model import EncoderCNN, DecoderRNN
from PIL import Image
# Device configuration
torch.cuda.set_device(1)  # select which GPU PyTorch runs on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load an image and prepare it for the encoder
def load_image(image_path, transform=None):
    image = Image.open(image_path)
    image = image.resize([224, 224], Image.LANCZOS)

    if transform is not None:
        image = transform(image).unsqueeze(0)

    return image
# The test function
def test(args):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load the vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build the two models
    encoder = EncoderCNN(args.embed_size).eval()  # eval mode (batchnorm uses moving mean/variance)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained model parameters
    encoder.load_state_dict(torch.load(args.encoder_path))
    decoder.load_state_dict(torch.load(args.decoder_path))

    # Prepare an image
    image = load_image(args.image, transform)
    image_tensor = image.to(device)

    # Generate a caption from the image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)

    # Convert word ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print the caption and show the image
    print (sentence)
    image = Image.open(args.image)
    plt.imshow(np.asarray(image))
# Set the arguments and run the test
parser = argparse.ArgumentParser()
parser.add_argument('--image', type=str, default='png/football2.jpg', help='input image for generating caption')
parser.add_argument('--encoder_path', type=str, default='models/encoder-5-3000.ckpt', help='path for trained encoder')
parser.add_argument('--decoder_path', type=str, default='models/decoder-5-3000.ckpt', help='path for trained decoder')
parser.add_argument('--vocab_path', type=str, default='/home/ubuntu/Datasets/coco/vocab.pkl', help='path for vocabulary wrapper')

# Model parameters (should be same as parameters in train.py)
parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
config = parser.parse_args(args=[])
test(config)
<start> a soccer player kicking a ball on a field . <end>
