• Building an RNN Joint Embedding Model (LSTM/GRU) in PyTorch for Visual Question Answering (VQA): A Hands-On Guide (Very Detailed, with Dataset and Source Code)


    If you need the source code and dataset, please like, follow, and bookmark, then leave a comment or send a private message~~~

    I. Introduction to Visual Question Answering

    Visual Question Answering (VQA) is a learning task that involves both computer vision and natural language processing. Put simply, VQA means answering questions about a given image: a VQA system takes as input an image and a free-form, open-ended natural language question about that image, and produces a natural language answer as output. Building a VQA system draws on current computer vision and natural language processing techniques and covers model design, experiments, and visualization.

    A typical approach to VQA is the joint embedding model. This method first learns to embed the two modalities, visual features and natural language features, into a common feature space, and then produces an answer from that joint representation; a conceptual sketch of the pattern is given right below.
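
    As a rough illustration of this idea (not the project's actual code; the module names here are placeholders), a joint embedding VQA model can be sketched as follows. The concrete Baseline model shown later in this post follows the same pattern.

    import torch.nn as nn

    class JointEmbeddingVQA(nn.Module):
        """Conceptual skeleton of a joint embedding VQA model (hypothetical names)."""
        def __init__(self, question_encoder, image_encoder, classifier):
            super(JointEmbeddingVQA, self).__init__()
            self.question_encoder = question_encoder  # e.g. word embedding + GRU
            self.image_encoder = image_encoder        # e.g. projection of Faster R-CNN region features
            self.classifier = classifier              # MLP over the fused representation

        def forward(self, question, image):
            q = self.question_encoder(question)  # [B, d] question embedding
            v = self.image_encoder(image)        # [B, d] visual embedding in the same space
            joint = q * v                        # element-wise fusion in the common space
            return self.classifier(joint)        # answer logits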

    II. Preparing the Dataset

    1: Downloading the data

    Training and validation use the VQA 2.0 dataset, which is widely regarded as a challenging benchmark in which language bias (answering from the question alone) is effectively controlled.

    The images used here are the train2014 and val2014 subsets of the MSCOCO dataset; they can be downloaded from the official website.

    Dataset website

    The image features used in this project were detected and extracted by the Faster R-CNN object detection network; you can leave a comment or send a private message to request them.

    2: Installing dependencies

    Make sure PyTorch is installed, then run pip install -r requirements.txt in the project directory to install the remaining dependencies.

    III. Key Modules

    1: The FCNet module

    FCNet is a stack of fully connected layers; the input and output size of each layer is specified when the module is constructed. By default every linear layer in it has a bias, uses ReLU as its activation function, and applies weight normalization. A sketch of such a module is given below.
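
    The actual implementation lives in lib/module/fc.py and is not reproduced in this post; a minimal sketch, assuming only the behavior described above (weight-normalized linear layers with bias followed by ReLU), might look like this:

    import torch.nn as nn
    from torch.nn.utils import weight_norm

    class FCNet(nn.Module):
        """A stack of weight-normalized fully connected layers with ReLU activations."""
        def __init__(self, dims):
            # dims lists the layer sizes, e.g. [in_dim, hid_dim, out_dim]
            super(FCNet, self).__init__()
            layers = []
            for in_dim, out_dim in zip(dims[:-1], dims[1:]):
                layers.append(weight_norm(nn.Linear(in_dim, out_dim, bias=True), dim=None))
                layers.append(nn.ReLU())
            self.main = nn.Sequential(*layers)

        def forward(self, x):
            return self.main(x)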

    2: The SimpleClassifier module

    It sits at the end of the visual question answering system and produces the final answer from the fused features; a sketch follows.
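
    Again, the real class is in lib/module/classifier.py; assuming the constructor signature used in build_from_config below (in_dim, hid_dim, out_dim, dropout), a plausible sketch is:

    import torch.nn as nn
    from torch.nn.utils import weight_norm

    class SimpleClassifier(nn.Module):
        """Small MLP that maps the fused feature to one logit per candidate answer."""
        def __init__(self, in_dim, hid_dim, out_dim, dropout):
            super(SimpleClassifier, self).__init__()
            self.main = nn.Sequential(
                weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
                nn.ReLU(),
                nn.Dropout(dropout),
                weight_norm(nn.Linear(hid_dim, out_dim), dim=None),
            )

        def forward(self, x):
            return self.main(x)  # [B, n_answers] logits (no softmax; the loss handles that)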

    3: The question embedding module

    In the joint embedding model, an RNN encodes the input question into a vector. LSTM and GRU are two representative RNN variants; since in practice a GRU performs comparably to an LSTM while using less GPU memory, a GRU is chosen here. A simplified sketch of this encoder is given below.
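
    The project's question encoder lives in lib/module/language_model.py; a simplified sketch, assuming the constructor arguments used in build_from_config below (input dim, hidden dim, number of layers, bidirectional flag, dropout, RNN type), could be:

    import torch.nn as nn

    class QuestionEmbedding(nn.Module):
        """Encodes a sequence of word embeddings into a fixed-size question vector with a GRU/LSTM."""
        def __init__(self, in_dim, hid_dim, n_layers, bidirectional, dropout, rnn_type="GRU"):
            super(QuestionEmbedding, self).__init__()
            rnn_cls = nn.GRU if rnn_type == "GRU" else nn.LSTM
            self.rnn = rnn_cls(in_dim, hid_dim, n_layers, bidirectional=bidirectional,
                               dropout=dropout, batch_first=True)

        def forward(self, w_emb):
            # w_emb: [B, seq_len, word_emb_dim]
            output, _ = self.rnn(w_emb)
            return output[:, -1]  # hidden state at the last time step, [B, hid_dim]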

    4: Word embedding

    To obtain an embedding of the question sentence, we first need word embeddings: each word is assigned a unique integer index, which is then looked up in an embedding table, as sketched below.
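
    In PyTorch this is typically an nn.Embedding lookup table. A minimal sketch matching the WordEmbedding(n_tokens, emb_dim, dropout) call used below (the actual class in lib/module/language_model.py may do more, e.g. load pretrained vectors):

    import torch.nn as nn

    class WordEmbedding(nn.Module):
        """Maps each word index to a dense vector via a learnable lookup table."""
        def __init__(self, n_tokens, emb_dim, dropout=0.0):
            super(WordEmbedding, self).__init__()
            self.emb = nn.Embedding(n_tokens, emb_dim)
            self.dropout = nn.Dropout(dropout)

        def forward(self, token_ids):
            # token_ids: [B, seq_len] integer index of each word
            return self.dropout(self.emb(token_ids))  # [B, seq_len, emb_dim]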

    The baseline model code is as follows:

    import torch
    import torch.nn as nn
    from lib.module import topdown_attention
    from lib.module.language_model import WordEmbedding, QuestionEmbedding
    from lib.module.classifier import SimpleClassifier
    from lib.module.fc import FCNet


    class Baseline(nn.Module):
        def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier, need_internals=False):
            super(Baseline, self).__init__()
            self.need_internals = need_internals
            self.w_emb = w_emb            # word embedding lookup
            self.q_emb = q_emb            # question (RNN) encoder
            self.v_att = v_att            # top-down attention over image regions
            self.q_net = q_net            # projects the question embedding into the joint space
            self.v_net = v_net            # projects the visual embedding into the joint space
            self.classifier = classifier  # maps the joint representation to answer logits

        def forward(self, q_tokens, ent_features):
            w_emb = self.w_emb(q_tokens)
            q_emb = self.q_emb(w_emb)
            att = self.v_att(q_emb, ent_features)  # [B, n_ent, 1]
            v_emb = (att * ent_features).sum(1)    # [B, hid_dim]
            internals = [att.squeeze()] if self.need_internals else None
            q_repr = self.q_net(q_emb)
            v_repr = self.v_net(v_emb)
            joint_repr = q_repr * v_repr           # element-wise fusion
            logits = self.classifier(joint_repr)
            return logits, internals

        @classmethod
        def build_from_config(cls, cfg, dataset, need_internals):
            w_emb = WordEmbedding(dataset.word_dict.n_tokens, cfg.lm.word_emb_dim, 0.0)
            q_emb = QuestionEmbedding(cfg.lm.word_emb_dim, cfg.hid_dim, cfg.lm.n_layers,
                                      cfg.lm.bidirectional, cfg.lm.dropout, cfg.lm.rnn_type)
            q_dim = cfg.hid_dim
            att_cls = topdown_attention.classes[cfg.topdown_att.type]
            v_att = att_cls(1, q_dim, cfg.ent_dim, cfg.topdown_att.hid_dim, cfg.topdown_att.dropout)
            q_net = FCNet([q_dim, cfg.hid_dim])
            v_net = FCNet([cfg.ent_dim, cfg.hid_dim])
            classifier = SimpleClassifier(cfg.hid_dim, cfg.mlp.hid_dim, dataset.ans_dict.n_tokens, cfg.mlp.dropout)
            return cls(w_emb, q_emb, v_att, q_net, v_net, classifier, need_internals)

    The dataset directory layout is as follows:

    IV. Visualizing the Results

    After loading a previously trained model and using the val split specified in the configuration file, the results can be visualized as follows once the program has finished running.

    For each given image, the model outputs the corresponding question answering result.

     

    V. Code

    Part of the code is shown below.

    Training code

    import os
    import time
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.optim.lr_scheduler import LambdaLR
    from torch.nn.utils import clip_grad_norm_
    from bisect import bisect
    from tqdm import tqdm


    def bce_with_logits(logits, labels):
        # binary cross-entropy over all candidate answers (soft VQA scores as targets)
        assert logits.dim() == 2
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        loss *= labels.size(1)  # multiply by number of candidate answers
        return loss


    def sce_with_logits(logits, labels):
        # softmax cross-entropy using the column indices of the non-zero target answers
        assert logits.dim() == 2
        loss = F.cross_entropy(logits, labels.nonzero()[:, 1])
        loss *= labels.size(1)
        return loss


    def compute_score_with_logits(logits, labels):
        # VQA accuracy: the soft score of the predicted (argmax) answer
        with torch.no_grad():
            logits = torch.max(logits, 1)[1]  # argmax
            one_hots = torch.zeros(*labels.size()).cuda()
            one_hots.scatter_(1, logits.view(-1, 1), 1)
            scores = one_hots * labels
        return scores


    def lr_schedule_func_builder(cfg):
        # linear warmup for the first warmup_steps, then stepwise decay at cfg.train.lr_steps
        def func(step_idx):
            if step_idx <= cfg.train.warmup_steps:
                alpha = float(step_idx) / float(cfg.train.warmup_steps)
                return cfg.train.warmup_factor * (1. - alpha) + alpha
            else:
                idx = bisect(cfg.train.lr_steps, step_idx)
                return pow(cfg.train.lr_ratio, idx)
        return func


    def train(model, cfg, train_loader, val_loader, n_epochs, val_freq, out_dir):
        os.makedirs(out_dir, exist_ok=True)
        optim = torch.optim.Adamax(model.parameters(), **cfg.train.optim)
        n_train_batches = len(train_loader)
        loss_fn = bce_with_logits if cfg.model.loss == "logistic" else sce_with_logits
        for epoch in range(n_epochs):
            epoch_loss = 0.0
            train_score = 0.0  # reset the accuracy accumulator at the start of every epoch
            tic_0 = time.time()
            for i, data in enumerate(train_loader):
                tic_1 = time.time()
                q_tokens = data[2].cuda()
                a_targets = data[3].cuda()
                v_features = [_.cuda() for _ in data[4:]]
                tic_2 = time.time()
                optim.zero_grad()
                logits, _ = model(q_tokens, *v_features)
                loss = loss_fn(logits, a_targets)
                tic_3 = time.time()
                loss.backward()
                if cfg.train.clip_grad:
                    clip_grad_norm_(model.parameters(), cfg.train.max_grad_norm)
                optim.step()
                tic_4 = time.time()
                batch_score = compute_score_with_logits(logits, a_targets).sum()
                epoch_loss += float(loss.data.item() * logits.size(0))
                train_score += float(batch_score)
                del loss
                logstr = "epoch %2d batch %4d/%4d | ^ %4dms | => %4dms | <= %4dms" % \
                    (epoch + 1, i + 1, n_train_batches,
                     1000 * (tic_2 - tic_0), 1000 * (tic_3 - tic_2), 1000 * (tic_4 - tic_3))
                print("%-80s" % logstr, end="\r")
                tic_0 = time.time()
            epoch_loss /= len(train_loader.dataset)
            train_score = 100 * train_score / len(train_loader.dataset)
            logstr = "epoch %2d | train_loss: %5.2f train_score: %5.2f" % (epoch + 1, epoch_loss, train_score)
            if (epoch + 1) % val_freq == 0:
                model.eval()
                val_score, upper_bound = validate(model, val_loader)
                model.train()
                logstr += " | val_score: %5.2f (%5.2f)" % (100 * val_score, 100 * upper_bound)
            print("%-80s" % logstr)
            model_path = os.path.join(out_dir, 'model_%d.pth' % (epoch + 1))
            torch.save(model.state_dict(), model_path)


    def validate(model, loader):
        score = 0
        upper_bound = 0  # best score achievable given the annotated answers
        n_qas = 0
        with torch.no_grad():
            for i, data in enumerate(loader):
                q_tokens = data[2].cuda()
                a_targets = data[3].cuda()
                v_features = [_.cuda() for _ in data[4:]]
                logits, _ = model(q_tokens, *v_features)
                batch_score = compute_score_with_logits(logits, a_targets)
                score += batch_score.sum()
                upper_bound += (a_targets.max(1)[0]).sum()
                n_qas += logits.size(0)
                logstr = "val batch %5d/%5d" % (i + 1, len(loader))
                print("%-80s" % logstr, end='\r')
        score = score / n_qas
        upper_bound = upper_bound / n_qas
        return score, upper_bound
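
    Note that LambdaLR is imported above, but the snippet does not show where the warmup/decay schedule produced by lr_schedule_func_builder is attached to the optimizer. A hedged sketch of how that wiring presumably looks (optim and cfg being the same objects used in train()):

    from torch.optim.lr_scheduler import LambdaLR

    def build_scheduler(optim, cfg):
        # lr_schedule_func_builder (defined above) returns linear warmup followed by stepwise decay
        lr_lambda = lr_schedule_func_builder(cfg)
        return LambdaLR(optim, lr_lambda=lr_lambda)

    # scheduler.step() would then be called once per training step (or per epoch,
    # depending on how cfg.train.warmup_steps and cfg.train.lr_steps are expressed).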

    Inference code

    import os
    import time
    import json
    import torch
    import cv2
    import shutil
    import numpy as np
    import torch.nn as nn
    import torch.nn.functional as F
    from tqdm import tqdm

    colors = [(175, 84, 65), (68, 194, 246), (136, 147, 65), (92, 192, 151)]


    def attention_map(im, boxes, atts, p=0.8, bgc=1.0, compress=0.85, box_color=(65, 81, 226)):
        # Overlay the top-down attention weights on the image: attended regions stay bright,
        # and the most attended box is outlined with a rectangle.
        height, width, channel = im.shape
        im = im / 255.0
        att_map = np.zeros([height, width])
        boxes = boxes.astype(np.int32)  # np.int was removed in recent NumPy versions
        for box, att in zip(boxes, atts):
            x1, y1, x2, y2 = box
            roi = att_map[y1:y2, x1:x2]
            roi[roi < att] = att  # keep the maximum attention value per pixel
        att_map /= att_map.max()
        att_map = att_map ** p
        att_map = att_map * compress + (1 - compress)
        att_map = cv2.resize(att_map, (int(width / 16), int(height / 16)))  # smooth the map
        att_map = cv2.resize(att_map, (width, height))
        att_map = np.expand_dims(att_map, axis=2)
        bg = np.ones_like(att_map) * bgc
        att_im = im * att_map + bg * (1 - att_map)
        att_im = (att_im * 255).astype(np.uint8)
        center = np.argmax(atts)
        x1, y1, x2, y2 = boxes[center]
        cv2.rectangle(att_im, (x1, y1), (x2, y2), box_color, 5)
        return att_im


    def infer_visualize(model, args, cfg, ans_dict, loader):
        _, ckpt = os.path.split(args.checkpoint)
        ckpt, _ = os.path.splitext(ckpt)
        out_dir = os.path.join(args.out_dir, "%s_%s_%s_visualization" % (args.cfg_name, ckpt, args.data))
        os.makedirs(out_dir, exist_ok=True)
        model.eval()
        questions_path = cfg.data[args.data].composition[0].q_jsons[0]
        questions = json.load(open(questions_path))
        pbar = tqdm(total=args.n_batches * loader.batch_size)
        with torch.no_grad():
            for i, data in enumerate(loader):
                if i == args.n_batches:
                    break
                question_ids = data[0]
                image_ids = data[1]
                q_tokens = data[2].cuda()
                obj_features = data[4].cuda()
                batch_boxes = data[5].numpy()
                logits, internals = model(q_tokens, obj_features)
                topdown_atts = internals[0]
                topdown_atts = topdown_atts.data.cpu().numpy()
                _, predictions = logits.max(dim=1)
                for idx in range(len(question_ids)):
                    question_id = question_ids[idx]
                    image_id = image_ids[idx]
                    boxes = batch_boxes[idx]
                    answer = ans_dict.idx2ans[predictions[idx]]
                    q_entry = questions[question_id]
                    topdown_att = topdown_atts[idx]
                    question = q_entry["question"]
                    gts = list(q_entry["answers"].items())
                    gts = sorted(gts, reverse=True, key=lambda x: x[1])
                    gt = gts[0][0]  # ground-truth answer with the highest score
                    q_out_dir = os.path.join(out_dir, question_id)
                    os.makedirs(q_out_dir, exist_ok=True)
                    q_str = question + "\n" + "gt: %s\n" % gt + "answer: %s\n" % answer
                    with open(os.path.join(q_out_dir, "qa.txt"), "w") as f:
                        f.write(q_str)
                    image_path = os.path.join(args.images_dir, "%s.jpg" % image_id)
                    shutil.copy(image_path, os.path.join(q_out_dir, "original.jpg"))
                    im = cv2.imread(image_path)
                    att_map = attention_map(im.copy(), boxes, topdown_att)
                    cv2.imwrite(os.path.join(q_out_dir, "topdown_att.jpg"), att_map)
                    pbar.update(1)

    Writing these posts takes real effort; if you found this helpful, please like, follow, and bookmark~~~

  • Original article: https://blog.csdn.net/jiebaoshayebuhui/article/details/128024996