• deepfm内容理解


    对于CTR问题,被证明的最有效的提升任务表现的策略是特征组合(Feature Interaction);

    两个问题:

    如何更好地学习特征组合,进而更加精确地描述数据的特点;

    如何更高效的学习特征组合。

    DNN局限 :当我们使用DNN网络解决推荐问题的时候存在网络参数过于庞大的问题,这是因为在进行特征处理的时候我们需要使用one-hot编码来处理离散特征,这会导致输入的维度猛增。

    为了解决DNN参数量过大的局限性,可以采用非常经典的Field思想,将OneHot特征转换为Dense Vector,通过增加全连接层就可以实现高阶的特征组合。

    黑色的线 和 红色的线 进行concat

    self定义 

    deep_features = deep_features
    fm_features = fm_features  #稀疏的特征
    deep_dims = sum([fea.embed_dim for fea in deep_features])  #8
    fm_dims = sum([fea.embed_dim for fea in fm_features])  #368   = 23*16           #稀疏的特征embedding化
    linear = LR(fm_dims)  # 1-odrder interaction   低阶信息   (fc): Linear(in_features=368, out_features=1, bias=True)
    fm = FM(reduce_sum=True)  # 2-odrder interaction    #FM将一阶特征和二阶特征cancat
    embedding = EmbeddingLayer(deep_features + fm_features)
    mlp = MLP(deep_dims, **mlp_params)

     forward


    input_deep = embedding(x, deep_features, squeeze_dim=True)  #[batch_size, deep_dims]    torch.Size([10, 8])
    input_fm = embedding(x, fm_features, squeeze_dim=False)  #[batch_size, num_fields, embed_dim]   torch.Size([10, 23, 16])
    y_linear = linear(input_fm.flatten(start_dim=1))  #torch.Size([10, 1])  对应的稀疏特征 经过线性层变为1
    y_fm = fm(input_fm)  #torch.Size([10, 1])    #对稀疏特征做一阶 二阶处理 
    y_deep = mlp(input_deep)  #[batch_size, 1]  #torch.Size([10, 1])
    y = y_linear + y_fm + y_deep          
    # return torch.sigmoid(y.squeeze(1))

    定义的一些函数: 

    import torch.nn as nn
    class LR(nn.Module):
        """Logistic Regression Module. It is the one Non-linear 
        transformation for input feature.

        Args:
            input_dim (int): input size of Linear module.
            sigmoid (bool): whether to add sigmoid function before output.

        Shape:
            - Input: `(batch_size, input_dim)`
            - Output: `(batch_size, 1)`
        """

        def __init__(self, input_dim, sigmoid=False):
            super().__init__()
            self.sigmoid = sigmoid
            self.fc = nn.Linear(input_dim, 1, bias=True)

        def forward(self, x):
            if self.sigmoid:
                return torch.sigmoid(self.fc(x))
            else:
                return self.fc(x)
            

    class FM(nn.Module):
        """The Factorization Machine module, mentioned in the `DeepFM paper
        `. It is used to learn 2nd-order 
        feature interactions.

        Args:
            reduce_sum (bool): whether to sum in embed_dim (default = `True`).

        Shape:
            - Input: `(batch_size, num_features, embed_dim)`
            - Output: `(batch_size, 1)`` or ``(batch_size, embed_dim)`
        """

        def __init__(self, reduce_sum=True):
            super().__init__()
            self.reduce_sum = reduce_sum

        def forward(self, x):
            square_of_sum = torch.sum(x, dim=1)**2
            sum_of_square = torch.sum(x**2, dim=1)
            ix = square_of_sum - sum_of_square
            if self.reduce_sum:
                ix = torch.sum(ix, dim=1, keepdim=True)
            return 0.5 * ix

    参考资料:

    推荐系统遇上深度学习(三)--DeepFM模型理论和实践 - 简书 (jianshu.com)

    DeepFM (datawhalechina.github.io)

  • 相关阅读:
    Java异常机制
    开放领域问答机器人1
    【数据结构】优先级队列
    这份阿里巴巴 Java 架构六大专题面试宝典值得你刷一刷
    腾讯待办是什么?关停之后如何继续提醒待办事项?
    IDEA clion + vim =neovim
    分享25个JSP源码,总有一款适合您
    API响应状态
    sql的模糊查询
    Java 加载、编辑和保存WPS表格文件(.et/.ett)
  • 原文地址:https://blog.csdn.net/Ajdidfj/article/details/132746930