• PyTorch中特殊函数梯度的计算


    PyTorch中特殊函数梯度的计算

    普通函数

    对于简单的多元函数,对自变量求梯度很容易,例如:
    f ( x , y ) = x 2 + y 2 f(x,y)=x^2+y^2 f(x,y)=x2+y2
    则有:
    { ∇ x f ( x , y ) = 2 x ∇ y f ( x , y ) = 2 y \left\{ xf(x,y)=2xyf(x,y)=2y

    \right . {xf(x,y)yf(x,y)=2x=2y

    import torch
    x = torch.tensor([1, 1, 1.0], requires_grad=True)
    y = torch.tensor([2, 2, 2.0], requires_grad=True)
    z = torch.pow(x, 2) + torch.pow(y, 2)
    z.sum().backward()
    x.grad, y.grad
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    (tensor([2., 2., 2.]), tensor([4., 4., 4.]))
    
    • 1
    特殊函数
    1. Max函数

    一般是求几个输入元素的最大值,如何计算梯度呢?
    f ( x 0 , x 1 , … , x n ) = max ⁡ ( x 0 , x 1 , … , x n ) f(x_0,x_1,\ldots,x_n)=\max(x_0,x_1,\ldots,x_n) f(x0,x1,,xn)=max(x0,x1,,xn)

    1. 在数值上求出最大值 a a a

    2. 对函数进行变换
      f ( x 0 , x 1 , … , x n , a ) = max ⁡ ( x 0 , x 1 , … , x n , a ) = { a i f   x < a x i f   x = a f(x_0,x_1,\ldots,x_n,a)=\max(x_0,x_1,\ldots,x_n,a)= \left\{ aif x<axif x=a

      \right. f(x0,x1,,xn,a)=max(x0,x1,,xn,a)={aif x<axif x=a

    3. 变换后就可以求梯度了
      ∇ x f ( x , a ) = { 0 i f   x < a 1 i f   x = a \nabla_x f(x,a)= \left\{ 0if x<a1if x=a

      \right . xf(x,a)={0if x<a1if x=a

    在PyTorch中,如果存在多个相等的最大值,那么它们均分"1":

    import torch
    
    x = torch.tensor([1, 2, 3, 4, 4, 0.], requires_grad=True)
    y = torch.max(x)
    y.backward()
    x.grad
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000])
    
    • 1
    2. Clip函数

    在数据落在一定范围外时,与输入无关
    f ( x ) = { x i f   a < x < b a i f   x < a b i f   x > b f(x)= \left\{ xif a<x<baif x<abif x>b

    \right. f(x)= xif a<x<baif x<abif x>b

    import torch
    
    x = torch.tensor([1, 2, 3, 4, 5, 6, 7.0], requires_grad=True)
    y = torch.clip(x, 1.5, 5.5)
    y.sum().backward()
    x.grad
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    tensor([0., 1., 1., 1., 1., 0., 0.])
    
    • 1
  • 相关阅读:
    (Python入门)函数
    基于HTML和JavaScript的会议室预约管理系统
    kotlin 类
    通过机器视觉对硬盘容器上盖的字符进行视觉识别,判断是否混料
    Solidity智能合约开发 — 3.4-抽象智能合约和接口
    本地开机启动jar
    调整屏幕的宽高比
    [C++]多态
    C语言日记 32 类的对象,this指针
    SpringBoot:SpringBoot集成Druid监控慢SQL
  • 原文地址:https://blog.csdn.net/qq_51352578/article/details/132689928