• [Docs] torch.distributions.Categorical


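    Lately I flinch whenever I see words like "distributed" or "parallel", but this module is `distributions`, as in probability distributions: a collection of classes that represent distributions.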

    Signature:

    torch.distributions.categorical.Categorical(probs=None, 
                                                logits=None, 
                                                validate_args=None)
    

    Creates a categorical distribution parameterized by either probs or logits (but not both).
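    That is, it builds a discrete categorical distribution from exactly one of probs (probabilities) or logits (unnormalized log-probabilities).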
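
    The "either probs or logits, but not both" rule can be checked directly. The following is a minimal sketch, assuming a recent PyTorch release; the variable names (from_probs, from_logits, same, both_rejected) are illustrative only.

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([0.1, 0.2, 0.7])

# The same distribution, parameterized two ways.
from_probs = Categorical(probs=probs)
from_logits = Categorical(logits=probs.log())  # log-probabilities are valid logits

# Both objects expose the same normalized probabilities.
same = torch.allclose(from_probs.probs, from_logits.probs)
print(same)

# Passing both arguments at once is rejected by the constructor.
try:
    Categorical(probs=probs, logits=probs.log())
    both_rejected = False
except ValueError:
    both_rejected = True
print(both_rejected)
```

    Logits are convenient when the values come straight out of a network layer, since no explicit softmax is needed before building the distribution.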

    If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.

    If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.


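    Put differently: if probs is a 1-D tensor of length K, entry i is the relative probability of class i, and sampling follows those probabilities.
    If probs is an N-D tensor, only the last dimension indexes classes; the first N-1 dimensions are treated as batch dimensions. This is hard to put into words, so it is easier to look at the examples.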

    First example (official demo):

    >>> import torch
    >>> from torch.distributions import Categorical
    >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
    >>> m.sample()  # returns 0, 1, 2, or 3 with equal probability
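
    To see the "equal probability" claim empirically, one can draw many samples and count them. This is a rough sketch: the sample count of 10000 and the seed are arbitrary choices, and the printed frequencies will only be approximately 0.25 each.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)  # arbitrary seed, only for repeatability

m = Categorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))

# sample() accepts a sample_shape; here we draw 10000 independent class indices.
samples = m.sample((10000,))

# Count how often each class index 0..3 appears and convert to frequencies.
freqs = torch.bincount(samples, minlength=4).float() / samples.numel()
print(freqs)  # each entry should be close to 0.25
```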
    

    Second example:

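    import torch
    from torch.distributions import Categorical

    # Two rows, so this is a batch of two distributions over 3 classes each.
    probs = torch.FloatTensor([[0.05,0.1,0.85],[0.05,0.05,0.9]])

    dist = Categorical(probs)
    print(dist)
    # Categorical(probs: torch.Size([2, 3]))

    index = dist.sample()  # one class index per row
    print(index.numpy())
    # most likely [2 2]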
    

    If probs does not sum to 1 along the last dimension (dim=-1), Categorical normalizes it internally before sampling.
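
    The normalization can be observed through the probs attribute of the distribution object. A small sketch, assuming non-negative weights with a non-zero sum (the values below are made up for illustration):

```python
import torch
from torch.distributions import Categorical

# Relative weights that sum to 4 rather than 1.
weights = torch.tensor([1.0, 1.0, 2.0])

dist = Categorical(probs=weights)

# The stored probabilities are the weights divided by their sum along dim=-1.
normalized = dist.probs
print(normalized)  # tensor([0.2500, 0.2500, 0.5000])
```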

    Third example:

    >>> probs = torch.FloatTensor([[[0.05,0.1,0.85],[0.05,0.05,0.9]]])
    >>> probs.shape
    torch.Size([1, 2, 3])
    
    >>> dist = Categorical(probs)
    >>> index = dist.sample()
    >>> index.shape
    torch.Size([1, 2])
    

    The example above helps in understanding this sentence:

    If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.

    If probs.shape is [X, Y, N, s], then index.shape is [X, Y, N]: Categorical samples only along the last dimension.
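
    The shape rule can be verified on a 4-D tensor. The sketch below uses arbitrary sizes (2, 3, 4, 5) and random positive weights, which are assumptions chosen purely for illustration; the names batch_shape and draws are likewise illustrative.

```python
import torch
from torch.distributions import Categorical

# Random strictly positive weights; Categorical normalizes them along the last dim.
probs = torch.rand(2, 3, 4, 5) + 0.1

dist = Categorical(probs=probs)

# The first N-1 dimensions form the batch; the last one (size 5) is the class axis.
batch_shape = dist.batch_shape
print(batch_shape)   # torch.Size([2, 3, 4])

# One class index in [0, 5) per batch element.
index = dist.sample()
print(index.shape)   # torch.Size([2, 3, 4])

# A sample_shape is prepended to the batch shape.
draws = dist.sample((7,))
print(draws.shape)   # torch.Size([7, 2, 3, 4])
```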


    One last note: a Categorical object has an .entropy() method that computes the entropy of the distribution.

    >>> probs = torch.FloatTensor([[[0.05,0.1,0.85],[0.05,0.05,0.9]]])
    >>> dist = Categorical(probs)
    >>> dist.entropy()
    tensor([[0.5182, 0.3944]])
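
    These values agree with the usual Shannon entropy formula H = -sum(p * log p), taken over the last dimension with the natural logarithm. A short cross-check, written as a sketch rather than a description of how the library computes it internally (the names manual and builtin are illustrative):

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([[[0.05, 0.1, 0.85], [0.05, 0.05, 0.9]]])
dist = Categorical(probs=probs)

# Hand-written entropy in nats: one value per distribution in the batch.
manual = -(probs * probs.log()).sum(dim=-1)

builtin = dist.entropy()

print(manual)
print(torch.allclose(manual, builtin, atol=1e-5))
```

    Note that the hand-written formula would produce NaN if any probability were exactly 0, since 0 * log(0) is undefined in floating point; the probabilities here are all strictly positive.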
    
  • Original article: https://blog.csdn.net/HaoZiHuang/article/details/126356668