- import torch
- from torchvision import models,transforms
- from torch.utils.data import Dataset , DataLoader
- import os
- import pickle
- from PIL import Image
- from tqdm import tqdm
- from pymilvus import (
- FieldSchema,
- DataType,
- db,
- connections,
- CollectionSchema,
- Collection
- )
- import time
- import matplotlib.pyplot as plt
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel normalisation constants (these look like the usual ImageNet
# mean/std used with torchvision's pretrained models -- confirm if changed).
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to 256x256, centre-crop to 224x224,
# convert to a tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
transform = transforms.Compose(_preprocess_steps)
# Root folder of the training images; each sub-directory is scanned for
# image files (presumably one sub-directory per flower class).
image_dir = "./flower_data/train"

# Enumerate every sub-directory instead of assuming exactly 102 of them,
# and skip stray files (e.g. OS metadata) that cannot be listed.
image_dirs = [
    f"{image_dir}/{entry}"
    for entry in os.listdir(image_dir)
    if os.path.isdir(f"{image_dir}/{entry}")
]

# Flat list of every file path found inside the sub-directories.
image_paths = [
    os.path.join(class_dir, file_name)
    for class_dir in image_dirs
    for file_name in os.listdir(class_dir)
]
image_paths
image_dirs

# Persist the path list so ImageDataset can reload it from disk.
with open("image_paths.pkl", "wb") as fw:
    pickle.dump(image_paths, fw)
class ImageDataset(Dataset):
    """Dataset over the RGB images listed in a pickled list of file paths.

    Each item is a dict with the keys ``"idx"`` (position in the dataset),
    ``"image_path"`` (the file path) and ``"img"`` (the image, passed
    through ``transform`` when one is given).
    """

    def __init__(self, transform=None, paths_file="./image_paths.pkl"):
        """Load the path list and keep only images whose mode is RGB.

        Args:
            transform: Optional callable applied to every loaded image.
            paths_file: Pickle file holding the list of image paths.
                Defaults to the file written earlier in this script.
        """
        super().__init__()
        self.transform = transform
        # NOTE: pickle must only be used on trusted files; this one is
        # produced by this same script.
        with open(paths_file, "rb") as fr:
            self.data_paths = pickle.load(fr)

        self.data = []

        for image_path in self.data_paths:
            # Only the header is needed to read the mode; the context
            # manager closes the file handle right away instead of leaving
            # one open per scanned image.
            with Image.open(image_path) as img:
                is_rgb = img.mode == "RGB"
            if is_rgb:
                self.data.append(image_path)

    def __len__(self):
        """Return the number of RGB images kept by the constructor."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``index``-th sample as a dict (idx, image_path, img)."""
        image_path = self.data[index]
        with Image.open(image_path) as img:
            # Force the pixel data into memory so the image stays usable
            # after the underlying file is closed.
            img.load()

        if self.transform:
            img = self.transform(img)

        return {
            "idx": index,
            "image_path": image_path,
            "img": img,
        }
# Build the dataset with the preprocessing pipeline defined above.
valid_dataset = ImageDataset(transform=transform)
# Notebook echo: number of images kept by the dataset.
len(valid_dataset)
# Batches of 64 with shuffle=False, so samples are served in dataset order.
valid_dataloader = DataLoader(valid_dataset , batch_size=64, shuffle=False)
def load_model():
    """Return a pretrained ResNet-18 on ``device``, switched to eval mode.

    Uses the ``weights=`` API when the installed torchvision provides it
    (``pretrained=True`` is deprecated there) and falls back to the legacy
    keyword on older versions.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` selected.
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)
    model.to(device)
    model.eval()
    return model
# Instantiate the network once; it is reused for every batch below.
model = load_model()
# Notebook echo: displays the module structure.
model
ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=512, out_features=1000, bias=True) )
def feature_extract(model, x):
    """Run ``x`` through the ResNet backbone and return flattened features.

    The stem, the four residual stages and the average pool are applied in
    order; the final ``fc`` classifier is deliberately not called, so the
    result is the pooled embedding rather than class logits.

    Args:
        model: Module exposing ``conv1``, ``bn1``, ``relu``, ``maxpool``,
            ``layer1``..``layer4`` and ``avgpool`` as callables.
        x: Input batch tensor.

    Returns:
        The pooled activations flattened from dimension 1 onwards.
    """
    backbone_stages = (
        "conv1",
        "bn1",
        "relu",
        "maxpool",
        "layer1",
        "layer2",
        "layer3",
        "layer4",
        "avgpool",
    )
    for stage_name in backbone_stages:
        x = getattr(model, stage_name)(x)
    return torch.flatten(x, 1)
# Accumulators, kept aligned by position: embedding, dataset index, path.
feature_list = []
feature_index_list = []
feature_image_path_list = []

# Inference only: disabling autograd avoids building a graph for every
# batch, which would otherwise waste memory for no benefit.
with torch.no_grad():
    for batch in tqdm(valid_dataloader):
        batch_features = feature_extract(model, batch["img"].to(device))
        # Extending with a 2-D array appends one 1-D embedding per image.
        feature_list.extend(batch_features.detach().cpu().numpy())
        # Store plain ints rather than one 0-d tensor per sample.
        feature_index_list.extend(batch["idx"].tolist())
        # Read straight from the batch so the module-level `image_paths`
        # list is no longer overwritten by the loop.
        feature_image_path_list.extend(batch["image_path"])

# Column-oriented payload: [primary keys (paths), embedding vectors].
entities = [
    feature_image_path_list,
    feature_list,
]
len(feature_list)
entities[0]
# Primary key: the image file path, stored as a VARCHAR of up to 512 chars.
image_path_field = FieldSchema(
    name="image_path",
    dtype=DataType.VARCHAR,
    description="图片路径",
    max_length=512,
    is_primary=True,
    auto_id=False,
)
# 512-dimensional float vector holding the image embedding.
embedding_field = FieldSchema(
    name="embeddings",
    dtype=DataType.FLOAT_VECTOR,
    description="向量表示图片",
    is_primary=False,
    dim=512,
)
fields = [image_path_field, embedding_field]
schema = CollectionSchema(fields, description="用于图生图的表")

# Register a connection under the alias "power_image_search" and open the
# "image_to_image" collection through that alias.
connections.connect(
    "power_image_search",
    host="ljxwtl.cn",
    port=19530,
    db_name="power_image_search",
)
table = Collection(
    "image_to_image",
    schema=schema,
    consistency_level="Strong",
    using="power_image_search",
)
# Insert in column-oriented batches rather than one RPC per row; the row
# count here is in the thousands, so per-row inserts are needlessly slow.
INSERT_BATCH_SIZE = 1000
for batch_start in range(0, len(feature_image_path_list), INSERT_BATCH_SIZE):
    batch_stop = batch_start + INSERT_BATCH_SIZE
    table.insert([
        feature_image_path_list[batch_start:batch_stop],
        feature_list[batch_start:batch_stop],
    ])
# Flush the inserted rows before reading the entity count.
table.flush()
table.num_entities
6552
# IVF_FLAT index over the embedding field, L2 distance, 128 clusters.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
table.create_index("embeddings", index_params=index)
# Load the collection; done here before the search that follows.
table.load()
# Query with a single stored embedding: the second entry of the vector column.
vectors_to_search = entities[-1][1:2]
search_params = dict(
    metric_type="L2",
    params=dict(nprobe=10),
)

# Wall-clock timestamps taken around the search call.
start_time = time.time()
result = table.search(
    data=vectors_to_search,
    anns_field="embeddings",
    param=search_params,
    limit=5,
    output_fields=["image_path"],
)
end_time = time.time()

# One `hits` collection per query vector; print every match with its path.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}, image_path field: {hit.entity.get('image_path')}")
# Display the query image (second stored path) and then one fixed image
# from the dataset, each in its own figure.
for preview_path in (entities[0][1], "./flower_data/train/1\\image_06766.jpg"):
    img_data = plt.imread(preview_path)
    plt.imshow(img_data)
    plt.show()