测试时增强(Test-Time Augmentation,TTA)是一种在深度学习模型的测试阶段应用数据增强的技术手段。它是通过对测试样本进行多次随机变换或扰动,产生多个增强的样本,并使用这些样本进行预测的多数投票或平均来得出最终预测结果。
为了直观理解TTA执行的过程,这里我绘制了流程示意图如下所示:

TTA的过程如下:
数据增强:
多次预测:
预测结果集成:
接下来针对性地对比分析下使用TTA带来的优点和缺点:
优点:
缺点:
TTA是一种常用的技术手段,通过应用数据增强和集成预测结果,可以提高深度学习模型在测试阶段的性能和鲁棒性。然而,TTA的应用需要平衡计算开销和预测准确性,并谨慎处理可能导致模型过拟合的问题。根据具体任务和需求,可以灵活选择合适的增强操作和集成策略来使用TTA。
下面是demo代码实现,如下所示:
- import numpy as np
- import torch
- import torchvision.transforms as transforms
-
- def test_time_augmentation(model, image, n_augmentations):
- # 定义数据增强的变换
- transform = transforms.Compose([
- transforms.ToTensor(),
- # 在此添加你需要的任何其他数据增强操作
- ])
-
- # 存储多次预测结果的列表
- predictions = []
-
- # 对图像应用多次增强和预测
- for _ in range(n_augmentations):
- augmented_image = transform(image)
- augmented_image = augmented_image.unsqueeze(0) # 增加一个维度作为批次
- with torch.no_grad():
- # 切换模型为评估模式,确保不执行梯度计算
- model.eval()
- # 使用增强的图像进行预测
- output = model(augmented_image)
- _, predicted = torch.max(output.data, 1)
- predictions.append(predicted.item())
-
- # 执行多数投票并返回最终预测结果
- final_prediction = np.bincount(predictions).argmax()
-
- return final_prediction
在前文鸟类细粒度识别项目实验中测试发现,应用TTA技术后,对应的评估指标上有明显的涨点,但是很明显地可以发现:在整个测试过程中资源消耗增加明显,且耗时显著增长,这也是TTA无法避免的劣势,在对精度要求较高的场景下可以有限考虑引入TTA,但是对于计算时耗要求较高的场景则不推荐使用TTA。
开源社区里面也有一些优秀的实现,这里推荐一个,地址在这里,如下所示:

目前有将近1k的star量,还是蛮不错的。
安装方法如下所示:
- pip安装:
- pip install ttach
-
-
- 源码安装:
- pip install git+https://github.com/qubvel/ttach
- Input
- | # input batch of images
- / / /|\ \ \ # apply augmentations (flips, rotation, scale, etc.)
- | | | | | | | # pass augmented batches through model
- | | | | | | | # reverse transformations for each batch of masks/labels
- \ \ \ / / / # merge predictions (mean, max, gmean, etc.)
- | # output batch of masks/labels
- Output
目前支持分割、分类、关键点检测三种任务,实例使用如下所示:
- Segmentation model wrapping [docstring]:
- import ttach as tta
- tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
-
-
- Classification model wrapping [docstring]:
- tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
-
-
- Keypoints model wrapping [docstring]:
- tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
data transforms 实例实现如下所示:
- # defined 2 * 2 * 3 * 3 = 36 augmentations !
- transforms = tta.Compose(
- [
- tta.HorizontalFlip(),
- tta.Rotate90(angles=[0, 180]),
- tta.Scale(scales=[1, 2, 4]),
- tta.Multiply(factors=[0.9, 1, 1.1]),
- ]
- )
-
- tta_model = tta.SegmentationTTAWrapper(model, transforms)
Custom model (multi-input / multi-output)实现如下所示:
- # Example how to process ONE batch on images with TTA
- # Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)
-
- for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform()
-
- # augment image
- augmented_image = transformer.augment_image(image)
-
- # pass to model
- model_output = model(augmented_image, another_input_data)
-
- # reverse augmentation for mask and label
- deaug_mask = transformer.deaugment_mask(model_output['mask'])
- deaug_label = transformer.deaugment_label(model_output['label'])
-
- # save results
- labels.append(deaug_mask)
- masks.append(deaug_label)
-
- # reduce results as you want, e.g mean/max/min
- label = mean(labels)
- mask = mean(masks)
Transforms详情如下所示:
| Transform | Parameters | Values |
|---|---|---|
| HorizontalFlip | - | - |
| VerticalFlip | - | - |
| Rotate90 | angles | List[0, 90, 180, 270] |
| Scale | scales interpolation | List[float] "nearest"/"linear" |
| Resize | sizes original_size interpolation | List[Tuple[int, int]] Tuple[int,int] "nearest"/"linear" |
| Add | values | List[float] |
| Multiply | factors | List[float] |
| FiveCrops | crop_height crop_width | int int |
支持的结果融合方法如下:
- mean
- gmean (geometric mean)
- sum
- max
- min
- tsharpen (temperature sharpen with t=0.5)