• 利用torch.nn实现softmax回归Fashion-MNIST数据集上进行训练和测试


    利用torch.nn实现softmax回归Fashion-MNIST数据集上进行训练和测试:

    1)(2)(3)同上

    4)构建模型

    1. num_inputs = 784  
    2. num_outputs = 10  10  
    3.   
    4. 构建模型  
    5. class softmaxnet(torch.nn.Module):  
    6.     def __init__(self, n_features, n_labels):  
    7.         super(softmaxnet, self).__init__()  
    8.         self.linear = torch.nn.Linear(n_features, n_labels)  
    9.   
    10.     def softmax(self, X):  # softmax计算  
    11.         X_exp = X.exp()  对每个元素做指数运算  
    12.         partition = X_exp.sum(dim=1, keepdim=True)  求列和,即对同行元素求和 n*1  
    13.         return X_exp / partition  # broadcast  
    14.   
    15.     def forward(self, x):  
    16.         x_ = x.view((-1, num_inputs))  
    17.         y_ = self.linear(x_)  
    18.         y_hat = self.softmax(y_)  
    19.         return y_hat  

    5)损失函数和优化算法

    1. #损失函数和优化方法  
    2. net = softmaxnet(num_inputs, num_outputs)  
    3. lr = 0.3  
    4. loss = torch.nn.CrossEntropyLoss()  
    5. optimizer = optim.SGD(net.parameters(), lr=lr)  

    6)构建测试集准确率函数

    1. #测试集的准确度与损失  
    2. def get_test_info(data_iter, net):  
    3.     right_count, all_count = 0.0, 0  
    4.     for x, y in data_iter:  
    5.         y_ = net(x)  
    6.         l = loss(y_, y)  
    7.         right_count += (y_.argmax(dim=1)==y).sum().item()  
    8.         all_count += y.shape[0]  
    9.     return right_count/all_count, l.item() 

    7)开始优化并分别输出训练集和测试集的损失和准确率

    1. num_epoch = 20  
    2.   
    3. for epoch in range(num_epoch):  
    4.     train_r_num, train_all_num = 0.0, 0  
    5.     for X, y in train_iter:  
    6.         y_ = net(X)  
    7.         l = loss(y_, y)  
    8.         l.backward()  
    9.         optimizer.step()  
    10.         optimizer.zero_grad()  
    11.         train_r_num += (y_.argmax(dim=1) == y).sum().item()  
    12.         train_all_num += y.shape[0]  
    13.     test_acc, test_ave_loss = get_test_info(test_iter, net)  
    14.     print('epoch %d, train loss %.4f, train acc %.3f' % (epoch+1, l.item(), train_r_num/train_all_num))  
    15.     print('          test loss %.4f, test acc %.3f' % (test_ave_loss, test_acc))  
  • 相关阅读:
    Day41—— 343. 整数拆分 96.不同的二叉搜索树 (动规)
    文件名批量重命名与翻译的实用指南
    自定义事件之C#设计笔记(十)
    八一书《乡村振兴战略下传统村落文化旅游设计》许少辉瑞博士生辉少许——2023学生开学季许多少年辉光三农
    Text-to-SQL小白入门(五)开源最强代码大模型Code Llama
    Linux操作系统基础 – 正则表达式快速入门
    Go语言适用场景
    SpringCloud - Spring Cloud Alibaba 之 Skywalking 分布式链路跟踪;下载安装,应用(十二)
    什么是自动化测试?如何开展自动化测试你需要知道这些点
    刷题篇(一)
  • 原文地址:https://blog.csdn.net/ccyyll1/article/details/126020665