• 深度学习之使用Milvus向量数据库实战图搜图


    1. import torch
    2. from torchvision import models,transforms
    3. from torch.utils.data import Dataset , DataLoader
    4. import os
    5. import pickle
    6. from PIL import Image
    7. from tqdm import tqdm
    8. from pymilvus import (
    9. FieldSchema,
    10. DataType,
    11. db,
    12. connections,
    13. CollectionSchema,
    14. Collection
    15. )
    16. import time
    17. import matplotlib.pyplot as plt
    # Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Preprocessing applied to every image before feature extraction:
    # resize to 256x256, take the central 224x224 crop, convert to a tensor and
    # normalise each channel with the standard ImageNet mean / std.
    transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    image_dir = "./flower_data/train"
    # One sub-directory per class. List every directory instead of zipping with a
    # fixed 102-element list (zip silently truncates if there are more entries),
    # skip stray files, and sort so the resulting order is deterministic.
    # The "{root}/{name}" form is kept because these paths become Milvus primary keys.
    image_dirs = [
        f"{image_dir}/{name}"
        for name in sorted(os.listdir(image_dir))
        if os.path.isdir(os.path.join(image_dir, name))
    ]
    # Flatten all image files of all class directories into one list of paths.
    image_paths = []
    for class_dir in image_dirs:  # `class_dir` avoids shadowing the builtin `dir`
        for file_name in sorted(os.listdir(class_dir)):
            file_path = os.path.join(class_dir, file_name)
            if os.path.isfile(file_path):
                image_paths.append(file_path)
    image_paths
    image_dirs
    
    # Persist the collected path list so ImageDataset can reload it from disk.
    with open("image_paths.pkl", "wb") as path_file:
        path_file.write(pickle.dumps(image_paths))
    class ImageDataset(Dataset):
        """Map-style dataset over the RGB images listed in a pickled path list.

        Each item is a dict with keys ``"idx"`` (position in the dataset),
        ``"image_path"`` (the file path) and ``"img"`` (the image, passed
        through ``transform`` when one is given).

        Args:
            transform: optional callable applied to each loaded PIL image.
            paths_file: pickle file holding the list of image paths. Defaults
                to ``"./image_paths.pkl"``, the location used originally.
        """

        def __init__(self, transform=None, paths_file="./image_paths.pkl"):
            super().__init__()
            self.transform = transform
            # NOTE: pickle.load can execute arbitrary code on a crafted file;
            # only point this at a path list this script wrote itself.
            with open(paths_file, "rb") as fr:
                self.data_paths = pickle.load(fr)
            # Keep only images whose mode is already "RGB"; other modes are dropped,
            # not converted. `with` closes each file instead of leaking the handle.
            self.data = []
            for image_path in self.data_paths:
                with Image.open(image_path) as probe:
                    if probe.mode == "RGB":
                        self.data.append(image_path)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            image_path = self.data[index]
            # convert() reads the pixels into an independent image, so the file
            # handle can be closed before the image is returned / transformed.
            with Image.open(image_path) as raw:
                img = raw.convert("RGB")
            if self.transform:
                img = self.transform(img)
            return {
                "idx": index,
                "image_path": image_path,
                "img": img,
            }
    # Dataset over every RGB image in the pickled path list, preprocessed with `transform`.
    valid_dataset = ImageDataset(transform=transform)

    len(valid_dataset)

    # shuffle=False keeps batches in dataset order.
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=64,
        shuffle=False,
    )
    
    def load_model():
        """Return an ImageNet-pretrained ResNet-18 moved to ``device``, in eval mode.

        Uses the ``weights=`` API on torchvision >= 0.13, where ``pretrained=``
        is deprecated. Falls back to ``pretrained=True`` on older releases that
        lack ``ResNet18_Weights``. ``ResNet18_Weights.DEFAULT`` selects the same
        IMAGENET1K_V1 checkpoint that ``pretrained=True`` loads.
        """
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:  # torchvision < 0.13
            model = models.resnet18(pretrained=True)
        else:
            model = models.resnet18(weights=weights_enum.DEFAULT)
        model.to(device)
        # eval(): BatchNorm layers use their running statistics, not batch statistics.
        model.eval()
        return model


    model = load_model()

    model
    ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (fc): Linear(in_features=512, out_features=1000, bias=True)
    )

    def feature_extract(model, x):
        """Run ``x`` through a ResNet backbone and return the pooled features.

        Applies the stem (conv1, bn1, relu, maxpool), the four residual stages
        and the average pool, skipping the final ``fc`` classifier. The pooled
        output is flattened from dimension 1 onward, giving one row per sample.
        """
        stages = (
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool,
            model.layer1,
            model.layer2,
            model.layer3,
            model.layer4,
            model.avgpool,
        )
        for stage in stages:
            x = stage(x)
        return torch.flatten(x, 1)
    # Extract one pooled feature vector per image, keeping path/index alignment.
    feature_list = []
    feature_index_list = []
    feature_image_path_list = []
    # Inference only: no_grad() stops autograd from recording the forward pass and
    # retaining activations for a backward pass that never happens.
    with torch.no_grad():
        for batch in tqdm(valid_dataloader):
            imgs = batch["img"].to(device)
            features = feature_extract(model, imgs).cpu().numpy()
            feature_list.extend(features)
            # tolist() stores plain ints rather than 0-d tensors.
            feature_index_list.extend(batch["idx"].tolist())
            # Use the batch entry directly so the global `image_paths` list is not
            # overwritten with the last batch's paths.
            feature_image_path_list.extend(batch["image_path"])
    # Column-oriented rows for Milvus: image paths first, then the embedding
    # vectors, matching the field order of the collection schema.
    entities = [feature_image_path_list, feature_list]
    len(feature_list)

    entities[0]
    
    # Collection schema: the image path is a manually assigned VARCHAR primary key,
    # and each row carries one 512-d float vector (ResNet-18's pooled feature size).
    fields = [
        FieldSchema(
            name="image_path",
            dtype=DataType.VARCHAR,
            description="图片路径",
            max_length=512,
            is_primary=True,
            auto_id=False,
        ),
        FieldSchema(
            name="embeddings",
            dtype=DataType.FLOAT_VECTOR,
            description="向量表示图片",
            is_primary=False,
            dim=512,
        ),
    ]
    schema = CollectionSchema(fields, description="用于图生图的表")
    # NOTE(review): host / port / db_name point at the author's own server; adjust for your deployment.
    connections.connect(
        "power_image_search",
        host="ljxwtl.cn",
        port=19530,
        db_name="power_image_search",
    )

    table = Collection(
        "image_to_image",
        schema=schema,
        consistency_level="Strong",
        using="power_image_search",
    )
    
    # Insert in chunks instead of one RPC per image (6552 round trips for this dataset).
    # Each chunk uses the same column-oriented layout as before:
    # [image paths, embedding vectors].
    # NOTE(review): re-running this cell inserts the rows again; confirm how your
    # Milvus version treats duplicate primary keys before re-running.
    insert_batch_size = 1000
    for start in range(0, len(feature_image_path_list), insert_batch_size):
        stop = start + insert_batch_size
        table.insert([
            feature_image_path_list[start:stop],
            feature_list[start:stop],
        ])
    table.flush()
    table.num_entities
    

    6552

    # IVF_FLAT index on the embedding field with L2 distance and nlist=128 clusters.
    index = dict(
        index_type="IVF_FLAT",
        metric_type="L2",
        params=dict(nlist=128),
    )
    table.create_index("embeddings", index_params=index)

    # Load the collection into memory so it can be searched.
    table.load()
    
    # Query with the second extracted feature; slicing keeps it a one-element list of vectors.
    vectors_to_search = entities[-1][1:2]
    # nprobe: number of IVF clusters scanned per query.
    search_params = {
        "metric_type": "L2",
        "params": {"nprobe": 10},
    }
    start_time = time.perf_counter()
    result = table.search(vectors_to_search, "embeddings", search_params, limit=5, output_fields=["image_path"])
    end_time = time.perf_counter()
    # The original measured the search time but never reported it.
    print(f"search took {(end_time - start_time) * 1000:.2f} ms")
    for hits in result:
        for hit in hits:
            print(f"hit: {hit}, image_path field: {hit.entity.get('image_path')}")

     

    # Show the query image: entities[0][1] is the path that belongs to the searched
    # vector entities[-1][1].
    img_data = plt.imread(entities[0][1])
    plt.imshow(img_data)
    plt.show()
    # Show the top-ranked hit from the search results instead of a hard-coded,
    # Windows-only path.
    top_hits = result[0]
    if len(top_hits) > 0:
        img_data = plt.imread(top_hits[0].entity.get("image_path"))
        plt.imshow(img_data)
        plt.show()

  • 相关阅读:
    爬虫爬取mp3文件例子
    2024牛客暑期多校训练营7
    容器 —— 背景知识
    Oracle/PLSQL: Replace Function
    Spring Boot常用注解@ConfigurationProperties、松散绑定、数据校验
    Mysql与Oracle分页查询差异
    SpringCloud中服务间通信方式以及Ribbon、Openfeign组件的使用
    使用QEMU+GDB调试操作系统代码
    在openSUSE-Leap-15.4-DVD-x86_64下安装网易云音乐linux客户端
    nodeJs 实现视频的转换(超详细教程)
  • 原文地址:https://blog.csdn.net/wtl1992/article/details/134493014