• ​Segment-and-Track Anything——通用智能视频分割、跟踪、编辑算法解读与源码部署


    一、 万物分割

    随着Meta发布的Segment Anything Model (万物分割)的论文并开源了相关的算法,我们可以从中看到,SAM与GPT-4类似,这篇论文的目标是(零样本)分割一切,将自然语言处理(NLP)的提示范式引入了计算机视觉(CV)领域,为CV基础模型提供了更广泛的支持和深度研究的机会。
    Segment Anything与传统的图像分割有两个很大的区别:

    1、数据收集和主动学习的方式。

    对于一个庞大的数据集,例如包含十亿组数据的情况,标注全部数据几乎是不可行的。因此,一个解决方案是采用主动学习的方法。这种方法可以分为以下步骤:
    初步标注: 首先,对数据集的一部分进行手动标注。这可以是一个小样本,但应涵盖多种情况和类别,以确保模型获得足够的多样性。
    半监督学习: 接下来,使用已标注的数据来训练一个初始模型。这个模型可以用来预测未标注数据的标签。
    人工校验与修正: 模型生成的预测标签需要经过人工校验和修正,以确保其准确性。这可以通过专业人员或者众包的方式来完成。
    迭代循环: 重复上述步骤,逐渐扩展已标注数据的数量。每次迭代都会提高模型的性能,因为它可以在更多数据上进行训练。
    通过这种方式,可以逐步提高数据集的标注质量,而不需要手动标注所有数据。当数据集足够大并且模型被训练到一定程度时,其性能将会显著提升。

    2. prompt

    Segment Anything 引入了prompt的概念。Prompt是一种用户输入的提示,用于引导模型生成特定类型的回复。这在像GPT-3和SAM这样的模型中非常有用。用户可以提供一个问题或者描述,以帮助模型理解其意图并生成相关的回答或操作。
    例如,在SAM中,你可以输入一个提示词,如“Cat”或“Dog”,以告诉模型你希望它分割出照片中的猫或狗。模型将自动检测并绘制框,以实现分割。这个提示词可以用来限定模型的任务,使其更专注于特定的信息提取或操作。
    在这里插入图片描述
    这两个概念都是在处理大规模数据和提高模型性能方面非常重要的工程性工作。通过合理的数据收集和主动学习策略,以及通过引导模型的prompt,可以更好地满足用户需求,提高模型的效果,并逐步改进模型的性能。

    二、​Segment-and-Track Anything

    1、算法简介

    SAM的出现统一了分割这个任务很多应用,也表明了在CV领域可能存在大规模模型的潜力。这一突破肯定会对CV领域的研究带来重大变革,许多任务将得到统一处理。这一新的数据集和范式结合了超强的零样本泛化能力,将对CV领域产生深远影响。但缺乏对视频数据的支持。随后,浙江大学ReLER实验室的科研人员在最新开源的SAM-Track项目其中,解锁了SAM的视频分割能力,即:分割并跟踪一切(Segment-and-track anything,SAM-track)。SAM-Track在单卡上即可支持各种时空场景中的目标分割和跟踪,包括街景、AR、细胞、动画、航拍等,可同时追踪超过200个物体,为用户提供了强大的视频编辑能力。“Segment and Track Anything” 利用自动和交互式方法。主要使用的算法包括 SAM(Segment Anything Models)用于自动/交互式关键帧分割,以及 DeAOT(Decoupling features in Associating Objects with Transformers)(NeurIPS2022)用于高效的多目标跟踪和传播。SAM-Track 管道实现了 SAM 的动态自动检测和分割新物体,而 DeAOT 负责跟踪所有识别到的物体。

    2、项目特点

    自动/交互式分割:项目中的 SAM(Segment Anything Models)算法提供了自动和交互式关键帧分割的功能。通过 SAM,用户可以选择使用自动分割算法或与算法进行交互,以实现对视频中任意对象的精确分割。这种灵活性使得该项目适用于不同需求的应用场景。

    高效多目标跟踪:Segment-and-Track-Anything 还引入了 DeAOT 算法,用于实现高效的多目标跟踪和传播。DeAOT 利用先进的跟踪技术,能够准确地跟踪视频中的多个对象,并支持对象之间的传播和关联。这使得项目在处理复杂场景和多目标跟踪任务时表现出色。

    独立和开放性:该项目是一个独立的开源项目,可以直接访问和使用。它提供了丰富的文档和示例代码,帮助用户快速上手并进行定制开发。同时,项目欢迎社区的贡献和扩展,这使得用户能够与其他研究者和开发者共享经验和成果。

    应用广泛性:Segment-and-Track-Anything 的分割和跟踪功能可以应用于各种视频分析任务,包括视频监控、智能交通、行为分析等。它为研究者和开发者提供了一个强大的工具,用于处理和分析具有复杂动态场景的视频数据。
    在这里插入图片描述

    三、项目部署

    项目地址:https://github.com/z-x-yang/Segment-and-Track-Anything

    1.部署环境

    我这里测试部署的系统win 10, cuda 11.8,cudnn 8.5,GPU是RTX 3060, 8G显存,使用conda创建虚拟环境。
    创建并激活一个虚拟环境:

    conda create -n sta python==3.10
    activate sta
    
    • 1
    • 2

    下载项目:

    git clone https://github.com/z-x-yang/Segment-and-Track-Anything.git
    cd Segment-and-Track-Anything
    pip install gradio
    pip install scikit-image
    
    • 1
    • 2
    • 3
    • 4

    因为要使用GPU,这里单独安装pytorch

    conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
    
    • 1

    因为项目的依赖要使用sh脚本进行安装,win下不支持bash,所以要单独安装m2-base:

    conda install m2-base
    
    • 1

    安装项目依赖:

    bash script/install.sh
    
    • 1

    当出现下面提示代表安装成功。
    在这里插入图片描述
    GroundingDINO可能会安装不成功,可以直接从源码安装:

    git clone https://github.com/IDEA-Research/GroundingDINO.git
    cd GroundingDINO/
    pip install -e .
    cd ..
    
    • 1
    • 2
    • 3
    • 4

    下载所需模型:

     bash script/download_ckpt.sh
    
    • 1

    如果模型下载不成功,也可以手动复制这个地址把模型下载了放到指定目录目录.

    2.运行项目

    python app.py
    
    • 1

    然后打开http://127.0.0.1:7860/
    在这里插入图片描述
    导入一个视频,然后只追踪其中一个人,效果如下:
    在这里插入图片描述

    视频目标追踪:

    目标分割与目标追踪

    3.分割与追踪处理代码

    import sys
    sys.path.append("..")
    sys.path.append("./sam")
    from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
    from aot_tracker import get_aot
    import numpy as np
    from tool.segmentor import Segmentor
    from tool.detector import Detector
    from tool.transfer_tools import draw_outline, draw_points
    import cv2
    from seg_track_anything import draw_mask
    
    
    class SegTracker():
        def __init__(self,segtracker_args, sam_args, aot_args) -> None:
            """
             Initialize SAM and AOT.
            """
            self.sam = Segmentor(sam_args)
            self.tracker = get_aot(aot_args)
            self.detector = Detector(self.sam.device)
            self.sam_gap = segtracker_args['sam_gap']
            self.min_area = segtracker_args['min_area']
            self.max_obj_num = segtracker_args['max_obj_num']
            self.min_new_obj_iou = segtracker_args['min_new_obj_iou']
            self.reference_objs_list = []
            self.object_idx = 1
            self.curr_idx = 1
            self.origin_merged_mask = None  # init by segment-everything or update
            self.first_frame_mask = None
    
            # debug
            self.everything_points = []
            self.everything_labels = []
            print("SegTracker has been initialized")
    
        def seg(self,frame):
            '''
            Arguments:
                frame: numpy array (h,w,3)
            Return:
                origin_merged_mask: numpy array (h,w)
            '''
            frame = frame[:, :, ::-1]
            anns = self.sam.everything_generator.generate(frame)
    
            # anns is a list recording all predictions in an image
            if len(anns) == 0:
                return
            # merge all predictions into one mask (h,w)
            # note that the merged mask may lost some objects due to the overlapping
            self.origin_merged_mask = np.zeros(anns[0]['segmentation'].shape,dtype=np.uint8)
            idx = 1
            for ann in anns:
                if ann['area'] > self.min_area:
                    m = ann['segmentation']
                    self.origin_merged_mask[m==1] = idx
                    idx += 1
                    self.everything_points.append(ann["point_coords"][0])
                    self.everything_labels.append(1)
    
            obj_ids = np.unique(self.origin_merged_mask)
            obj_ids = obj_ids[obj_ids!=0]
    
            self.object_idx = 1
            for id in obj_ids:
                if np.sum(self.origin_merged_mask==id) < self.min_area or self.object_idx > self.max_obj_num:
                    self.origin_merged_mask[self.origin_merged_mask==id] = 0
                else:
                    self.origin_merged_mask[self.origin_merged_mask==id] = self.object_idx
                    self.object_idx += 1
    
            self.first_frame_mask = self.origin_merged_mask
            return self.origin_merged_mask
    
        def update_origin_merged_mask(self, updated_merged_mask):
            self.origin_merged_mask = updated_merged_mask
            # obj_ids = np.unique(updated_merged_mask)
            # obj_ids = obj_ids[obj_ids!=0]
            # self.object_idx = int(max(obj_ids)) + 1
    
        def reset_origin_merged_mask(self, mask, id):
            self.origin_merged_mask = mask
            self.curr_idx = id
    
        def add_reference(self,frame,mask,frame_step=0):
            '''
            Add objects in a mask for tracking.
            Arguments:
                frame: numpy array (h,w,3)
                mask: numpy array (h,w)
            '''
            self.reference_objs_list.append(np.unique(mask))
            self.curr_idx = self.get_obj_num()
            self.tracker.add_reference_frame(frame,mask, self.curr_idx, frame_step)
    
        def track(self,frame,update_memory=False):
            '''
            Track all known objects.
            Arguments:
                frame: numpy array (h,w,3)
            Return:
                origin_merged_mask: numpy array (h,w)
            '''
            pred_mask = self.tracker.track(frame)
            if update_memory:
                self.tracker.update_memory(pred_mask)
            return pred_mask.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.uint8)
        
        def get_tracking_objs(self):
            objs = set()
            for ref in self.reference_objs_list:
                objs.update(set(ref))
            objs = list(sorted(list(objs)))
            objs = [i for i in objs if i!=0]
            return objs
        
        def get_obj_num(self):
            objs = self.get_tracking_objs()
            if len(objs) == 0: return 0
            return int(max(objs))
    
        def find_new_objs(self, track_mask, seg_mask):
            '''
            Compare tracked results from AOT with segmented results from SAM. Select objects from background if they are not tracked.
            Arguments:
                track_mask: numpy array (h,w)
                seg_mask: numpy array (h,w)
            Return:
                new_obj_mask: numpy array (h,w)
            '''
            new_obj_mask = (track_mask==0) * seg_mask
            new_obj_ids = np.unique(new_obj_mask)
            new_obj_ids = new_obj_ids[new_obj_ids!=0]
            # obj_num = self.get_obj_num() + 1
            obj_num = self.curr_idx
            for idx in new_obj_ids:
                new_obj_area = np.sum(new_obj_mask==idx)
                obj_area = np.sum(seg_mask==idx)
                if new_obj_area/obj_area < self.min_new_obj_iou or new_obj_area < self.min_area\
                    or obj_num > self.max_obj_num:
                    new_obj_mask[new_obj_mask==idx] = 0
                else:
                    new_obj_mask[new_obj_mask==idx] = obj_num
                    obj_num += 1
            return new_obj_mask
            
        def restart_tracker(self):
            self.tracker.restart()
    
        def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray,):
            ''''
            Use bbox-prompt to get mask
            Parameters:
                origin_frame: H, W, C
                bbox: [[x0, y0], [x1, y1]]
            Return:
                refined_merged_mask: numpy array (h, w)
                masked_frame: numpy array (h, w, c)
            '''
            # get interactive_mask
            interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
            refined_merged_mask = self.add_mask(interactive_mask)
    
            # draw mask
            masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
    
            # draw bbox
            masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))
    
            return refined_merged_mask, masked_frame
    
        def seg_acc_click(self, origin_frame: np.ndarray, coords: np.ndarray, modes: np.ndarray, multimask=True):
            '''
            Use point-prompt to get mask
            Parameters:
                origin_frame: H, W, C
                coords: nd.array [[x, y]]
                modes: nd.array [[1]]
            Return:
                refined_merged_mask: numpy array (h, w)
                masked_frame: numpy array (h, w, c)
            '''
            # get interactive_mask
            interactive_mask = self.sam.segment_with_click(origin_frame, coords, modes, multimask)
    
            refined_merged_mask = self.add_mask(interactive_mask)
    
            # draw mask
            masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
    
            # draw points
            # self.everything_labels = np.array(self.everything_labels).astype(np.int64)
            # self.everything_points = np.array(self.everything_points).astype(np.int64)
    
            masked_frame = draw_points(coords, modes, masked_frame)
    
            # draw outline
            masked_frame = draw_outline(interactive_mask, masked_frame)
    
            return refined_merged_mask, masked_frame
    
        def add_mask(self, interactive_mask: np.ndarray):
            '''
            Merge interactive mask with self.origin_merged_mask
            Parameters:
                interactive_mask: numpy array (h, w)
            Return:
                refined_merged_mask: numpy array (h, w)
            '''
            if self.origin_merged_mask is None:
                self.origin_merged_mask = np.zeros(interactive_mask.shape,dtype=np.uint8)
    
            refined_merged_mask = self.origin_merged_mask.copy()
            refined_merged_mask[interactive_mask > 0] = self.curr_idx
    
            return refined_merged_mask
        
        def detect_and_seg(self, origin_frame: np.ndarray, grounding_caption, box_threshold, text_threshold, box_size_threshold=1, reset_image=False):
            '''
            Using Grounding-DINO to detect object acc Text-prompts
            Retrun:
                refined_merged_mask: numpy array (h, w)
                annotated_frame: numpy array (h, w, 3)
            '''
            # backup id and origin-merged-mask
            bc_id = self.curr_idx
            bc_mask = self.origin_merged_mask
    
            # get annotated_frame and boxes
            annotated_frame, boxes = self.detector.run_grounding(origin_frame, grounding_caption, box_threshold, text_threshold)
            for i in range(len(boxes)):
                bbox = boxes[i]
                if (bbox[1][0] - bbox[0][0]) * (bbox[1][1] - bbox[0][1]) > annotated_frame.shape[0] * annotated_frame.shape[1] * box_size_threshold:
                    continue
                interactive_mask = self.sam.segment_with_box(origin_frame, bbox, reset_image)[0]
                refined_merged_mask = self.add_mask(interactive_mask)
                self.update_origin_merged_mask(refined_merged_mask)
                self.curr_idx += 1
    
            # reset origin_mask
            self.reset_origin_merged_mask(bc_mask, bc_id)
    
            return refined_merged_mask, annotated_frame
    
    if __name__ == '__main__':
        from model_args import segtracker_args,sam_args,aot_args
    
        Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
        
        # ------------------ detect test ----------------------
        
        origin_frame = cv2.imread('/data2/cym/Seg_Tra_any/Segment-and-Track-Anything/debug/point.png')
        origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB)
        grounding_caption = "swan.water"
        box_threshold = 0.25
        text_threshold = 0.25
    
        predicted_mask, annotated_frame = Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
        masked_frame = draw_mask(annotated_frame, predicted_mask)
        origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_RGB2BGR)
    
        cv2.imwrite('./debug/masked_frame.png', masked_frame)
        cv2.imwrite('./debug/x.png', annotated_frame)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264

    四、 报错

    1.下载模型问题

    requests.exceptions.SSLError: (MaxRetryError(“HTTPSConnectionPool(host=‘huggingface.co’, port=443): Max retries exceeded with url: /bert-base-uncased/resolve/main/tokenizer_config.json (Caused by SSLError(SSLEOFError(8, ‘EOF occurred in violation of protocol (_ssl.c:997)’)))”), ‘(Request ID: d4f21f96-45fd-47a1-9afb-b7e4260a6f3b)’)

    https://huggingface.co/bert-base-uncased/tree/main

    在这里插入图片描述
    可以手动从这里下载模型,然后放到指定的目录:
    在这里插入图片描述

    2. imageio版本问题

    TypeError: The keyword fps is no longer supported. Use duration(in ms) instead, e.g. fps=50 == duration=20 (1000 * 1/50).

    pip uninstall imageio
    pip install imageio==2.23.0
    
    • 1
    • 2
  • 相关阅读:
    mysql获取近7天,7周,7月,7年日期,根据当前时间获取近7天,7周,7月,7年日期
    Unity:UI自动布局与多级菜单
    RFID智能锁控系统在物流安全运输中的应用与效益分析
    MySQL中对于事务的理解
    归并排序及其非递归实现
    基于SSM和layUI的汽车租赁系统设计
    Window系统安装JDK8与Maven
    科普:如何应用视觉显著性模型优化远控编码算法?
    QT day1
    一起学docker系列之二深入理解Docker:基本概念、工作原理与架构
  • 原文地址:https://blog.csdn.net/matt45m/article/details/133110802