• 利用torch.nn实现softmax回归Fashion-MNIST数据集上进行训练和测试


    利用torch.nn实现softmax回归Fashion-MNIST数据集上进行训练和测试:

    1)(2)(3)同上

    4)构建模型

    1. num_inputs = 784  
    2. num_outputs = 10  10  
    3.   
    4. 构建模型  
    5. class softmaxnet(torch.nn.Module):  
    6.     def __init__(self, n_features, n_labels):  
    7.         super(softmaxnet, self).__init__()  
    8.         self.linear = torch.nn.Linear(n_features, n_labels)  
    9.   
    10.     def softmax(self, X):  # softmax计算  
    11.         X_exp = X.exp()  对每个元素做指数运算  
    12.         partition = X_exp.sum(dim=1, keepdim=True)  求列和,即对同行元素求和 n*1  
    13.         return X_exp / partition  # broadcast  
    14.   
    15.     def forward(self, x):  
    16.         x_ = x.view((-1, num_inputs))  
    17.         y_ = self.linear(x_)  
    18.         y_hat = self.softmax(y_)  
    19.         return y_hat  

    5)损失函数和优化算法

    1. #损失函数和优化方法  
    2. net = softmaxnet(num_inputs, num_outputs)  
    3. lr = 0.3  
    4. loss = torch.nn.CrossEntropyLoss()  
    5. optimizer = optim.SGD(net.parameters(), lr=lr)  

    6)构建测试集准确率函数

    1. #测试集的准确度与损失  
    2. def get_test_info(data_iter, net):  
    3.     right_count, all_count = 0.0, 0  
    4.     for x, y in data_iter:  
    5.         y_ = net(x)  
    6.         l = loss(y_, y)  
    7.         right_count += (y_.argmax(dim=1)==y).sum().item()  
    8.         all_count += y.shape[0]  
    9.     return right_count/all_count, l.item() 

    7)开始优化并分别输出训练集和测试集的损失和准确率

    1. num_epoch = 20  
    2.   
    3. for epoch in range(num_epoch):  
    4.     train_r_num, train_all_num = 0.0, 0  
    5.     for X, y in train_iter:  
    6.         y_ = net(X)  
    7.         l = loss(y_, y)  
    8.         l.backward()  
    9.         optimizer.step()  
    10.         optimizer.zero_grad()  
    11.         train_r_num += (y_.argmax(dim=1) == y).sum().item()  
    12.         train_all_num += y.shape[0]  
    13.     test_acc, test_ave_loss = get_test_info(test_iter, net)  
    14.     print('epoch %d, train loss %.4f, train acc %.3f' % (epoch+1, l.item(), train_r_num/train_all_num))  
    15.     print('          test loss %.4f, test acc %.3f' % (test_ave_loss, test_acc))  
  • 相关阅读:
    机器学习笔记之指数族分布——指数族分布介绍
    Linux:安装MySQL服务(非docker方式)
    @Configuration注解Full模式和Lite模式
    SpringBoot SpringBoot 基础篇 4 基于 SpringBoot 的SSMP 整合案例 4.6 分页功能
    ArrayList与顺序表
    C++并发与多线程笔记六:单例模式下的数据共享
    kotlin基础知识
    Flood Fill 算法
    开源利器:it-tools 项目介绍
    初学Java,遇错就懵,这类问题到底怎么处理呢?!
  • 原文地址:https://blog.csdn.net/ccyyll1/article/details/126020665