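Image search has become a popular and powerful capability. It lets users find similar images by matching features or visual content, and rapid progress in computer vision and deep learning has greatly improved it.
This article shows how to implement image search on top of MyScale, a vector database.
MyScale is a cloud-based database optimized for AI applications and solutions, built on the open-source ClickHouse. It manages large volumes of data efficiently, which makes it a good base for building AI applications.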
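The demo below was tested with Python 3.7. Install the dependencies:
```shell
pip install datasets clickhouse-connect
pip install requests transformers torch tqdm
# Pillow and matplotlib are also imported by the code below
pip install pillow matplotlib
```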
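This step converts the raw data into vectors and saves the result as `.parquet` files. Building and converting the dataset is time-consuming and hardware-intensive, so you can skip it and later download a ready-made, already-converted dataset instead.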
//Download and process the data
Download and unzip the data to be converted:
```shell
wget https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip
unzip unsplash-research-dataset-lite-latest.zip -d tmp
```
Read the downloaded data into Pandas DataFrames:
```python
import glob

import pandas as pd

documents = ['photos', 'conversions']
datasets = {}

for doc in documents:
    # a table may be split across several files, so match every file starting with "<doc>.tsv"
    files = glob.glob("tmp/" + doc + ".tsv*")
    subsets = []
    for filename in files:
        df = pd.read_csv(filename, sep='\t', header=0)
        subsets.append(df)
    datasets[doc] = pd.concat(subsets, axis=0, ignore_index=True)

df_photos = datasets['photos']
df_conversions = datasets['conversions']
```
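The `"tsv*"` wildcard is there because the Unsplash archive may split a table into numbered files. The sketch below assumes shard names such as `photos.tsv000` (the file list is made up for illustration, not read from the archive) and uses the standard-library `fnmatch` module, which applies the same shell-style matching that `glob` uses, to show which names the pattern picks up.

```python
import fnmatch

# Hypothetical file names; the real archive contents may differ.
candidate_files = [
    "photos.tsv000",
    "photos.tsv001",
    "conversions.tsv000",
    "keywords.tsv000",
    "README.md",
]

def shards_for(doc, names):
    """Return the names matching the same pattern used above: doc + '.tsv*'."""
    return sorted(n for n in names if fnmatch.fnmatch(n, doc + ".tsv*"))

print(shards_for("photos", candidate_files))       # ['photos.tsv000', 'photos.tsv001']
print(shards_for("conversions", candidate_files))  # ['conversions.tsv000']
```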
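Define the function `extract_image_features`, then select 1,000 photo IDs from the DataFrame, download the corresponding images, and call the function to extract their image embeddings: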
```python
import torch
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def extract_image_features(image):
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
    # L2-normalize so that cosine similarity equals the dot product
    outputs = outputs / outputs.norm(dim=-1, keepdim=True)
    return outputs.squeeze(0).tolist()

from PIL import Image
import requests
from tqdm.auto import tqdm

# select the first 1000 photo IDs
photo_ids = df_photos['photo_id'][:1000].tolist()

# create a new data frame with only the selected photo IDs
df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index(drop=True)

# keep only the columns 'photo_id' and 'photo_image_url' (copy to avoid SettingWithCopyWarning)
df_photos = df_photos[['photo_id', 'photo_image_url']].copy()

# add a new column 'photo_embed' to the data frame
df_photos['photo_embed'] = None

# download the images and extract their embeddings using the 'extract_image_features' function
for i, row in tqdm(df_photos.iterrows(), total=len(df_photos)):
    # construct a URL to download an image with a smaller size by modifying the image URL
    url = row['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"

    try:
        res = requests.get(url, stream=True, timeout=30).raw
        image = Image.open(res)
    except Exception:
        # remove photo if image download fails
        photo_ids.remove(row['photo_id'])
        continue

    # extract feature embedding
    df_photos.at[i, 'photo_embed'] = extract_image_features(image)
```
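`extract_image_features` divides each CLIP vector by its L2 norm. One consequence worth seeing: for unit-length vectors, cosine similarity reduces to a plain dot product, so ranking by cosine and ranking by inner product agree. A minimal pure-Python sketch with made-up 3-dimensional vectors (the real CLIP embeddings here have 512 dimensions); the helper names are illustrative only:

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length, like outputs / outputs.norm(dim=-1) above."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine_similarity(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

a = [3.0, 4.0, 0.0]
b = [1.0, 2.0, 2.0]

a_unit, b_unit = l2_normalize(a), l2_normalize(b)
print(round(math.sqrt(dot(a_unit, a_unit)), 6))  # 1.0
# both values equal 11/15, about 0.733333
print(round(cosine_similarity(a, b), 6), round(dot(a_unit, b_unit), 6))
```

Because the stored embeddings are already normalized, the database only has to compare directions, not magnitudes.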
//Create the dataset
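Build two DataFrames, keeping only the photos whose images were downloaded successfully: one holds the photo information with embeddings, the other holds the conversion information.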
```python
# keep only photos with extracted embeddings; the old row index becomes the 'id' column
df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index().rename(columns={'index': 'id'})

df_conversions = df_conversions[df_conversions['photo_id'].isin(photo_ids)].reset_index(drop=True)
df_conversions = df_conversions[['photo_id', 'keyword']].reset_index().rename(columns={'index': 'id'})
```
Finally, write the DataFrames to Parquet files:
```python
import pyarrow as pa
import pyarrow.parquet as pq

# create Table objects from the DataFrames (the schema is inferred from the pandas dtypes)
photos_table = pa.Table.from_pandas(df_photos)
conversion_table = pa.Table.from_pandas(df_conversions)

# write each table to a Parquet file
pq.write_table(photos_table, 'photos.parquet')
pq.write_table(conversion_table, 'conversions.parquet')
```
As mentioned earlier, you can skip building the dataset and use the already-processed dataset instead: https://datasets-server.huggingface.co/splits?dataset=myscale%2Funsplash-examples
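That URL calls the `splits` endpoint of the Hugging Face datasets-server, with the dataset name `myscale/unsplash-examples` percent-encoded (`/` becomes `%2F`). A small standard-library sketch of building the same URL for any dataset name; the helper `splits_url` is hypothetical:

```python
from urllib.parse import quote

BASE = "https://datasets-server.huggingface.co/splits"

def splits_url(dataset_name):
    """Build the datasets-server splits URL, percent-encoding the '/' in the name."""
    return f"{BASE}?dataset={quote(dataset_name, safe='')}"

print(splits_url("myscale/unsplash-examples"))
# https://datasets-server.huggingface.co/splits?dataset=myscale%2Funsplash-examples
```

Requesting that URL returns a JSON description of the available splits; the upload step itself loads the Parquet files with `load_dataset`, so this is only useful for inspecting the dataset.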
//Create tables
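Create two tables in MyScale: one for the photo information and one for the conversion information. Replace the placeholders with your cluster's connection details.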
```python
import clickhouse_connect

# initialize client
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=8443, username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')

# drop tables if they already exist
client.command("DROP TABLE IF EXISTS default.myscale_photos")
client.command("DROP TABLE IF EXISTS default.myscale_conversions")

# create table for photos; the constraint requires 512-dimensional embeddings
client.command("""
CREATE TABLE default.myscale_photos
(
    id UInt64,
    photo_id String,
    photo_image_url String,
    photo_embed Array(Float32),
    CONSTRAINT vector_len CHECK length(photo_embed) = 512
)
ORDER BY id
""")

# create table for conversions
client.command("""
CREATE TABLE default.myscale_conversions
(
    id UInt64,
    photo_id String,
    keyword String
)
ORDER BY id
""")
```
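The `CONSTRAINT vector_len CHECK length(photo_embed) = 512` clause makes the server reject rows whose embedding is not 512-dimensional, matching the output size of `clip-vit-base-patch32`. Checking on the client before calling `client.insert` can give a clearer error than a rejected batch. A minimal sketch; the helper `check_embedding_lengths` is hypothetical and expects rows as `(id, photo_id, photo_image_url, photo_embed)` tuples in the column order of `myscale_photos`:

```python
EMBED_DIM = 512  # must match the CHECK constraint in CREATE TABLE

def check_embedding_lengths(rows, dim=EMBED_DIM):
    """Return the ids of rows whose embedding (4th field) is missing or has the wrong length."""
    bad_ids = []
    for row_id, _photo_id, _url, embed in rows:
        if embed is None or len(embed) != dim:
            bad_ids.append(row_id)
    return bad_ids

rows = [
    (0, "a1", "https://example.com/a1", [0.0] * 512),
    (1, "b2", "https://example.com/b2", [0.0] * 511),  # one value short
    (2, "c3", "https://example.com/c3", None),          # e.g. download failed
]
print(check_embedding_lengths(rows))  # [1, 2]
```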
Upload the data, then create a vector index on the embedding column:
```python
import clickhouse_connect
from datasets import load_dataset

photos = load_dataset("myscale/unsplash-examples", data_files="photos-all.parquet", split="train")
conversions = load_dataset("myscale/unsplash-examples", data_files="conversions-all.parquet", split="train")

# transform datasets to pandas DataFrames
photo_df = photos.to_pandas()
conversion_df = conversions.to_pandas()

# convert photo_embed from np array to list
photo_df['photo_embed'] = photo_df['photo_embed'].apply(lambda x: x.tolist())

# initialize client
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=8443, username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')

# upload data from datasets
client.insert("default.myscale_photos", photo_df.to_records(index=False).tolist(),
              column_names=photo_df.columns.tolist())
client.insert("default.myscale_conversions", conversion_df.to_records(index=False).tolist(),
              column_names=conversion_df.columns.tolist())

# check count of inserted data
print(f"photos count: {client.command('SELECT count(*) FROM default.myscale_photos')}")
print(f"conversions count: {client.command('SELECT count(*) FROM default.myscale_conversions')}")

# create vector index with the cosine metric
client.command("""
ALTER TABLE default.myscale_photos
ADD VECTOR INDEX photo_embed_index photo_embed
TYPE MSTG
('metric_type=Cosine')
""")

# check the status of the vector index; make sure it is ready with 'Built' status
get_index_status = "SELECT status FROM system.vector_indices WHERE name='photo_embed_index'"
print(f"index build status: {client.command(get_index_status)}")
```
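The index is built in the background, so a single status check right after `ALTER TABLE` may not yet report `Built`. One option is to poll until it does. A sketch, assuming only that the status query returns the string `Built` once the index is ready; `wait_for_index` is a hypothetical helper, and the status function is passed in so the logic can be tried without a cluster:

```python
import time

def wait_for_index(get_status, timeout_s=300.0, interval_s=5.0, sleep=time.sleep):
    """Poll get_status() until it returns 'Built'; raise TimeoutError after timeout_s."""
    waited = 0.0
    while True:
        status = get_status()
        if status == "Built":
            return status
        if waited >= timeout_s:
            raise TimeoutError(f"index still '{status}' after {timeout_s}s")
        sleep(interval_s)
        waited += interval_s

# demo with made-up intermediate states and no real sleeping
statuses = iter(["building", "building", "Built"])
print(wait_for_index(lambda: next(statuses), sleep=lambda s: None))  # Built
```

Against a real cluster it would be called as `wait_for_index(lambda: client.command(get_index_status))`.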
Find the top K images most similar to a specified local image (here K = 10):
```python
import clickhouse_connect
import requests
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import torch
from transformers import CLIPProcessor, CLIPModel

# Hub ID of the model; the path to a local copy of the model also works
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

client = clickhouse_connect.get_client(
    host='YOUR_CLUSTER_HOST',
    port=8443,
    username='YOUR_USERNAME',
    password='YOUR_CLUSTER_PASSWORD'
)


def show_search(image_embed):
    # download image with its url
    def download(url):
        response = requests.get(url, timeout=30)
        return Image.open(BytesIO(response.content))

    # define a method to display an online image with a URL
    def show_image(url, title=None):
        img = download(url)
        plt.figure(figsize=(4, 4))
        if title:
            plt.title(title)
        plt.imshow(img)
        plt.show()

    # query the database to find the top K similar images to the given image
    top_k = 10
    results = client.query(f"""
    SELECT photo_id, photo_image_url, distance(photo_embed, {image_embed}) as dist
    FROM default.myscale_photos
    ORDER BY dist
    LIMIT {top_k}
    """)
    # to exclude the query image when it is stored in the same table,
    # add WHERE photo_id != '{target_image_id}' before ORDER BY

    # download the images and add them to a list
    images = []
    for r in results.named_results():
        # construct a URL to download an image with a smaller size by modifying the image URL
        url = r['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"
        images.append(download(url))

    # display candidate images, 5 per row
    print("Loading candidate images...")
    for row in range((len(images) + 4) // 5):
        fig, axs = plt.subplots(1, 5, figsize=(20, 4))
        for i, img in enumerate(images[row * 5:row * 5 + 5]):
            axs[i].imshow(img)
        plt.show()


def extract_image_features(image):
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
    outputs = outputs / outputs.norm(dim=-1, keepdim=True)
    return outputs.squeeze(0).tolist()


if __name__ == '__main__':
    # replace with the path to your local query image
    image = Image.open('YOUR_QUERY_IMAGE.jpg')
    target_image_embed = extract_image_features(image)
    show_search(target_image_embed)
```
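Conceptually, `ORDER BY distance(photo_embed, ...) LIMIT 10` returns the 10 rows closest to the query vector; the MSTG index only makes that search fast. On a small sample you can reproduce the ranking in memory as a sanity check. A brute-force sketch, assuming the Cosine metric's distance is `1 - cosine similarity` (check the MyScale documentation for the exact definition); `top_k_by_cosine` and the toy 2-dimensional vectors are hypothetical:

```python
import heapq
import math

def cosine_distance(a, b):
    """1 - cosine similarity; assumed to match MyScale's Cosine metric."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

def top_k_by_cosine(query, rows, k=10):
    """rows: iterable of (photo_id, embedding). Return the k closest as (photo_id, dist)."""
    scored = ((photo_id, cosine_distance(query, emb)) for photo_id, emb in rows)
    return heapq.nsmallest(k, scored, key=lambda pair: pair[1])

rows = [
    ("p1", [1.0, 0.0]),
    ("p2", [0.0, 1.0]),
    ("p3", [1.0, 1.0]),
    ("p4", [-1.0, 0.0]),
]
for photo_id, dist in top_k_by_cosine([1.0, 0.1], rows, k=2):
    print(photo_id, round(dist, 4))
```

This scans every row, which is exactly the work a vector index avoids on large tables; it is meant only for spot-checking results.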
The local query image:
The 10 most similar images found:
That is the whole demo. If you are interested, give it a try yourself.
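Author: 陈卓敏 (Chen Zhuomin) | Backend Development Engineer
Copyright notice: This article was compiled and written by the 神州数码云基地 (Digital China Cloud Base) team. Please credit the source when reposting.
Search for 神州数码云基地 on WeChat Official Accounts for more AI technical articles.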