• [Docs] torch.distributions.Categorical


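    Lately, words like "distributed" and "parallel" make me nervous, but this module is `distributions`, as in probability distributions: a collection of classes that represent probability distributions.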

    Signature:

    torch.distributions.categorical.Categorical(probs=None, 
                                                logits=None, 
                                                validate_args=None)

    Creates a categorical distribution parameterized by either probs or logits (but not both).
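    In other words: this is a distribution over K discrete classes, specified by exactly one of `probs` or `logits`. Passing both (or neither) raises a ValueError.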
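    A minimal sketch of the two parameterizations, assuming an arbitrary logits vector chosen for illustration: `logits` are unnormalized log-probabilities, so `Categorical(logits=x)` should describe the same distribution as `Categorical(probs=softmax(x))`.

```python
import torch
from torch.distributions import Categorical

logits = torch.tensor([0.0, 1.0, 2.0])  # illustrative unnormalized log-probabilities

d_logits = Categorical(logits=logits)
d_probs = Categorical(probs=torch.softmax(logits, dim=-1))

# Same distribution: .probs of the logits version equals softmax(logits)
print(torch.allclose(d_logits.probs, d_probs.probs))  # True

# Specifying both parameters is rejected
try:
    Categorical(probs=torch.tensor([0.5, 0.5]), logits=torch.tensor([0.0, 0.0]))
except ValueError as e:
    print("ValueError:", e)
```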

    If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.

    If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.


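    In plain terms: if probs is a 1-D tensor of length K, each element is the relative probability of the class at that index, and sampling follows those probabilities.
    If probs is an N-D tensor, the first N-1 dimensions are treated as batch dimensions. That is hard to put into words, so let's look at examples.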

    First example (from the official docs):

    >>> import torch
    >>> from torch.distributions import Categorical
    >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
    >>> m.sample()  # returns 0, 1, 2, or 3 with equal probability
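    To check "equal probability" empirically, one can draw many samples at once by passing a sample shape to `sample()` and count them with `torch.bincount`. The seed and sample count below are arbitrary choices for this sketch.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)  # arbitrary seed, for reproducibility
m = Categorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))

samples = m.sample((10000,))                  # 10000 independent draws, shape [10000]
counts = torch.bincount(samples, minlength=4)  # how often each class 0..3 was drawn

print(samples.shape)                          # torch.Size([10000])
print(counts.float() / samples.numel())       # each frequency should be close to 0.25
```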

    Second example:

    import torch
    from torch.distributions import Categorical

    probs = torch.tensor([[0.05, 0.1, 0.85], [0.05, 0.05, 0.9]])

    dist = Categorical(probs)
    print(dist)
    # Categorical(probs: torch.Size([2, 3]))

    index = dist.sample()  # one class index per row
    print(index.numpy())
    # most likely [2 2]

    If probs does not sum to 1 along the last dimension (dim -1), Categorical normalizes it internally before sampling.
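    A small sketch of that normalization, using made-up non-negative weights: the normalized values are exposed through the `.probs` attribute, which should equal the input divided by its sum over the last dimension.

```python
import torch
from torch.distributions import Categorical

weights = torch.tensor([[1.0, 1.0, 2.0],
                        [3.0, 1.0, 4.0]])  # illustrative weights; rows sum to 4 and 8
dist = Categorical(weights)

# .probs is weights / weights.sum(-1, keepdim=True)
print(dist.probs)
# tensor([[0.2500, 0.2500, 0.5000],
#         [0.3750, 0.1250, 0.5000]])
```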

    Third example:

    >>> probs = torch.FloatTensor([[[0.05,0.1,0.85],[0.05,0.05,0.9]]])
    >>> probs.shape
    torch.Size([1, 2, 3])
    
    >>> dist = Categorical(probs)
    >>> index = dist.sample()
    >>> index.shape
    torch.Size([1, 2])

    The example above illustrates this sentence:

    If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.

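    If probs.shape is [X, Y, N, s], then index.shape is [X, Y, N]: Categorical samples only along the last dimension (the s classes), and every leading dimension is a batch dimension.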
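    The same rule can be read off the distribution's `batch_shape` and `event_shape` attributes. The sizes below are arbitrary values picked for this sketch; passing a sample shape to `sample()` prepends it to the batch shape.

```python
import torch
from torch.distributions import Categorical

X, Y, N, s = 2, 3, 4, 5          # illustrative sizes; s = number of classes
probs = torch.rand(X, Y, N, s)   # random non-negative weights, normalized internally
dist = Categorical(probs)

print(dist.batch_shape)          # torch.Size([2, 3, 4]): every dim except the last
print(dist.event_shape)          # torch.Size([]): each draw is a single class index
print(dist.sample().shape)       # torch.Size([2, 3, 4])
print(dist.sample((10,)).shape)  # torch.Size([10, 2, 3, 4]): sample shape is prepended
```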


    One last note: a Categorical object has an .entropy() method that returns the entropy of each distribution in the batch.

    >>> probs = torch.FloatTensor([[[0.05,0.1,0.85],[0.05,0.05,0.9]]])
    >>> dist = Categorical(probs)
    >>> dist.entropy()
    tensor([[0.5182, 0.3944]])
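    As a cross-check, the entropy can be computed by hand as H(p) = -Σ p_k ln(p_k) over the class dimension. This sketch assumes natural logarithms (so the result is in nats), which is what reproduces the values above.

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([[[0.05, 0.1, 0.85], [0.05, 0.05, 0.9]]])
dist = Categorical(probs)

# H(p) = -sum_k p_k * ln(p_k), summed over the last (class) dimension
manual = -(dist.probs * torch.log(dist.probs)).sum(-1)

print(manual)                                  # tensor([[0.5182, 0.3944]])
print(torch.allclose(manual, dist.entropy()))  # True
```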
  • Original article: https://blog.csdn.net/HaoZiHuang/article/details/126356668