torch.clamp(x, min, max)
该函数是用来做截断处理的,通常被使用在需要比较大小的地方。
该函数的截断规则为:
y i = { m i n , x i < m i n x i , m i n < = x i < = m a x m a x , x i > m a x y_i={min,xi<minxi,min<=xi<=maxmax,xi>max yi=⎩⎪⎨⎪⎧min,xi<minxi,min<=xi<=maxmax,xi>max
示例:
>>>x = torch.arange(12)
>>>print(x)
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
>>>torch.clamp(x, 2, 10)
tensor([ 2, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10])
我们可以看到小于2的位置的元素全部被替换为最小值2,大于10的位置全部被替换为最大值10,中间位置的数值保持不变。