AIGC | How to Build Efficient Image Search with MyScale


Contents

1. Introduction to MyScale

2. Hands-on Walkthrough

(1) Installing dependencies

(2) Building the dataset

(3) Loading the data into the MyScale database


Image search has become a popular and powerful capability: it lets users find similar images by matching visual content or features. With the rapid progress of computer vision and deep learning, this capability has improved dramatically.

This article walks through how to implement image search on top of the vector database MyScale.

1. Introduction to MyScale

MyScale is a cloud-based database optimized for AI applications and solutions, built on top of the open-source ClickHouse. It manages large volumes of data efficiently for developing powerful AI applications.

• Built for AI applications: manages and supports analytical processing of both structured and vectorized data for AI applications on a single platform.
• Built for performance: an advanced OLAP database architecture that operates on vectorized data with remarkable speed.
• Built for universal accessibility: SQL is the only programming language MyScale requires. This gives MyScale an advantage over custom APIs and makes it approachable for the broader programming community (see the sketch just below this list).
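To make that last point concrete, here is a minimal sketch of what a similarity query looks like in MyScale: plain SQL with a distance() function, issued through the same clickhouse_connect client used later in this article. It assumes the default.myscale_photos table built in the walkthrough below; the credentials and the zero vector are placeholders, not values from the article.

import clickhouse_connect

# Placeholder credentials -- substitute your own cluster details.
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=8443,
                                       username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')

# Placeholder for a real 512-dimensional CLIP embedding.
query_embed = [0.0] * 512

# Plain SQL is enough: rank rows by vector distance and take the closest five.
results = client.query(f"""
    SELECT photo_id, photo_image_url, distance(photo_embed, {query_embed}) AS dist
    FROM default.myscale_photos
    ORDER BY dist
    LIMIT 5
""")
for r in results.named_results():
    print(r['photo_id'], r['dist'])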

2. Hands-on Walkthrough

(1) Installing dependencies

In practice, Python 3.7 is sufficient to run the rest of this walkthrough:

pip install datasets clickhouse-connect
pip install requests transformers torch tqdm

(2) Building the dataset

This step converts the raw data into vector data, ending up as .parquet files. The conversion is time-consuming and hardware-intensive, so you can skip this step and simply download the ready-made, pre-converted dataset later (see section (3)).

// Download and process the data

Download and unzip the data we need to convert:

wget https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip
unzip unsplash-research-dataset-lite-latest.zip -d tmp

Read the downloaded data and convert it into pandas DataFrames:

import numpy as np
import pandas as pd
import glob

documents = ['photos', 'conversions']
datasets = {}

# read every photos.tsv* / conversions.tsv* shard and concatenate each group
for doc in documents:
    files = glob.glob("tmp/" + doc + ".tsv*")
    subsets = []
    for filename in files:
        df = pd.read_csv(filename, sep='\t', header=0)
        subsets.append(df)
    datasets[doc] = pd.concat(subsets, axis=0, ignore_index=True)

df_photos = datasets['photos']
df_conversions = datasets['conversions']

Define a function extract_image_features, select the first 1000 photo IDs from the data frame, download the corresponding images, and call the function to extract their image embeddings:

import torch
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def extract_image_features(image):
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
        outputs = outputs / outputs.norm(dim=-1, keepdim=True)  # L2-normalize
    return outputs.squeeze(0).tolist()

from PIL import Image
import requests
from tqdm.auto import tqdm

# select the first 1000 photo IDs
photo_ids = df_photos['photo_id'][:1000].tolist()

# create a new data frame with only the selected photo IDs
df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index(drop=True)

# keep only the columns 'photo_id' and 'photo_image_url' in the data frame
df_photos = df_photos[['photo_id', 'photo_image_url']]

# add a new column 'photo_embed' to the data frame
df_photos['photo_embed'] = None

# download the images and extract their embeddings using the 'extract_image_features' function
for i, row in tqdm(df_photos.iterrows(), total=len(df_photos)):
    # construct a URL to download an image with a smaller size by modifying the image URL
    url = row['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"
    try:
        res = requests.get(url, stream=True).raw
        image = Image.open(res)
    except Exception:
        # remove photo if image download fails
        photo_ids.remove(row['photo_id'])
        continue
    # extract feature embedding
    df_photos.at[i, 'photo_embed'] = extract_image_features(image)

// Create the dataset

Prepare two data frames: one with the photo information plus embeddings, the other with the conversion information.

df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index().rename(columns={'index': 'id'})
df_conversions = df_conversions[df_conversions['photo_id'].isin(photo_ids)].reset_index(drop=True)
df_conversions = df_conversions[['photo_id', 'keyword']].reset_index().rename(columns={'index': 'id'})

Finally, write the data frames out as Parquet files:

import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

# create a Table object from the data and schema
photos_table = pa.Table.from_pandas(df_photos)
conversion_table = pa.Table.from_pandas(df_conversions)

# write the tables to Parquet files
pq.write_table(photos_table, 'photos.parquet')
pq.write_table(conversion_table, 'conversions.parquet')

(3) Loading the data into the MyScale database

As mentioned earlier, you can skip the dataset-building step and download the already-processed dataset instead: "https://datasets-server.huggingface.co/splits?dataset=myscale%2Funsplash-examples"
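Before wiring up the database, a quick sanity check (a hedged sketch, not part of the original walkthrough) is to query that endpoint from Python and confirm which splits the prepared dataset exposes:

import requests

# Ask the Hugging Face datasets-server (the endpoint linked above)
# which configurations and splits 'myscale/unsplash-examples' provides.
resp = requests.get("https://datasets-server.huggingface.co/splits",
                    params={"dataset": "myscale/unsplash-examples"})
resp.raise_for_status()
print(resp.json())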

// Create the tables

Create two tables in MyScale: one for the photo information and another for the conversion information.

import clickhouse_connect

# initialize client
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=8443, username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')

# drop the tables if they already exist
client.command("DROP TABLE IF EXISTS default.myscale_photos")
client.command("DROP TABLE IF EXISTS default.myscale_conversions")

# create table for photos
client.command("""
CREATE TABLE default.myscale_photos
(
    id UInt64,
    photo_id String,
    photo_image_url String,
    photo_embed Array(Float32),
    CONSTRAINT vector_len CHECK length(photo_embed) = 512
)
ORDER BY id
""")

# create table for conversions
client.command("""
CREATE TABLE default.myscale_conversions
(
    id UInt64,
    photo_id String,
    keyword String
)
ORDER BY id
""")

Upload the data:

from datasets import load_dataset

photos = load_dataset("myscale/unsplash-examples", data_files="photos-all.parquet", split="train")
conversions = load_dataset("myscale/unsplash-examples", data_files="conversions-all.parquet", split="train")

# transform the datasets into pandas DataFrames
photo_df = photos.to_pandas()
conversion_df = conversions.to_pandas()

# convert photo_embed from np array to list
photo_df['photo_embed'] = photo_df['photo_embed'].apply(lambda x: x.tolist())

# initialize client
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=8443, username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')

# upload data from datasets
client.insert("default.myscale_photos", photo_df.to_records(index=False).tolist(),
              column_names=photo_df.columns.tolist())
client.insert("default.myscale_conversions", conversion_df.to_records(index=False).tolist(),
              column_names=conversion_df.columns.tolist())

# check count of inserted data
print(f"photos count: {client.command('SELECT count(*) FROM default.myscale_photos')}")
print(f"conversions count: {client.command('SELECT count(*) FROM default.myscale_conversions')}")

# create vector index with cosine metric
client.command("""
ALTER TABLE default.myscale_photos
ADD VECTOR INDEX photo_embed_index photo_embed
TYPE MSTG
('metric_type=Cosine')
""")

# check the status of the vector index, make sure it is ready ('Built' status)
get_index_status = "SELECT status FROM system.vector_indices WHERE name='photo_embed_index'"
print(f"index build status: {client.command(get_index_status)}")

Find the top K most similar images to a specified local image (here K = 10):

from datasets import load_dataset
import clickhouse_connect
import requests
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import torch
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained(r'C:\Users\16439\Desktop\clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained(r"C:\Users\16439\Desktop\clip-vit-base-patch32")

client = clickhouse_connect.get_client(
    host='msc-cab0c439.us-east-1.aws.myscale.com',
    port=8443,
    username='chenzmn',
    password='<hidden>'  # redacted by the author
)

def show_search(image_embed):
    # download an image from its URL
    def download(url):
        response = requests.get(url)
        return Image.open(BytesIO(response.content))

    # define a method to display an online image with a URL
    def show_image(url, title=None):
        img = download(url)
        fig = plt.figure(figsize=(4, 4))
        plt.imshow(img)
        plt.show()

    # query the database to find the top K similar images to the given image
    top_k = 10
    results = client.query(f"""
        SELECT photo_id, photo_image_url, distance(photo_embed, {image_embed}) as dist
        FROM default.myscale_photos
        ORDER BY dist
        LIMIT {top_k}
        """)
    # WHERE photo_id != '{target_image_id}'

    # download the images and add them to a list
    images_url = []
    for r in results.named_results():
        # construct a URL to download an image with a smaller size by modifying the image URL
        url = r['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"
        images_url.append(download(url))

    # display candidate images
    print("Loading candidate images...")
    for row in range(int(top_k / 5)):
        fig, axs = plt.subplots(1, 5, figsize=(20, 4))
        for i, img in enumerate(images_url[row * 5:row * 5 + 5]):
            axs[i % 5].imshow(img)
        plt.show()

def extract_image_features(image):
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
        outputs = outputs / outputs.norm(dim=-1, keepdim=True)
    return outputs.squeeze(0).tolist()

if __name__ == '__main__':
    image = Image.open(r'C:\Users\16439\Desktop\OIP-C.jpg')
    target_image_embed = extract_image_features(image)
    show_search(target_image_embed)

A photo from my local machine:

The 10 most similar images found:

That is the full demo. If you are interested, give it a try yourself.
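One extension the demo leaves out: the myscale_conversions table uploaded earlier maps photos to keywords, so a structured predicate and a vector ranking can be combined in a single SQL statement, which is exactly the "structured and vectorized data on a single platform" point from the introduction. Below is a hedged sketch; the keyword 'beach', the credentials, and the zero vector are placeholders, not values from the article.

import clickhouse_connect

# Placeholder credentials -- substitute your own cluster details.
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=8443,
                                       username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')

# Placeholder for a real 512-dimensional CLIP embedding of the query image.
target_embed = [0.0] * 512

# Keep only photos tagged with a given keyword, then rank them by distance.
results = client.query(f"""
    SELECT photo_id, photo_image_url, distance(photo_embed, {target_embed}) AS dist
    FROM default.myscale_photos
    WHERE photo_id IN (
        SELECT photo_id FROM default.myscale_conversions WHERE keyword = 'beach'
    )
    ORDER BY dist
    LIMIT 10
""")
for r in results.named_results():
    print(r['photo_id'], r['dist'])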

Author: Chen Zhuomin | Backend Development Engineer

Copyright notice: this article was compiled and written by the Digital China Cloud Base team. Please credit the source when reposting.

Search for "Digital China Cloud Base" (神州数码云基地) on WeChat for more AI-related technical content.

Original article: https://blog.csdn.net/CBGCampus/article/details/133939757