码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • pytorch回炉再造笔记--python类中getitem的用法


    目录

    1-- 类中__getitem__的作用

    2-- 实例

    3-- 结合pytorch封装并读取batch数据

    4-- 参考


    1-- 类中__getitem__的作用

    当一个python类中定义了__getitem__函数,则其实例对象能够通过下标来进行索引数据。

    2-- 实例

    代码:

    1. import numpy as np
    2. # 创建类
    3. class Example():
    4. def __getitem__(self, index):
    5. data = np.array([[1,2,3], [4,5,6], [7,8,9]])
    6. return data[index]
    7. # 使用Example类实例对象example1
    8. example1 = Example()
    9. # 索引访问数据
    10. print('example1[0][0]:', example1[0][0])
    11. print('example1[0]:', example1[0])
    12. # 切片访问数据
    13. print('example1[0:2]:\n', example1[0:2])

    输出:

    1. example1[0][0]: 1
    2. example1[0]: [1 2 3]
    3. example1[0:2]:
    4. [[1 2 3]
    5. [4 5 6]]

    3-- 结合pytorch封装并读取batch数据

    代码:

    1. import torch
    2. import numpy as np
    3. from torch.utils.data import Dataset
    4. # 创建MyDataset类
    5. class MyDataset(Dataset):
    6. def __init__(self, x, y):
    7. self.data = torch.from_numpy(x).float()
    8. self.label = torch.LongTensor(y)
    9. def __getitem__(self, idx):
    10. return self.data[idx], self.label[idx], idx
    11. def __len__(self):
    12. return len(self.data)
    13. Train_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    14. Train_label = np.array([10, 11, 12, 13])
    15. TrainDataset = MyDataset(Train_data, Train_label) # 创建实例对象
    16. print('len:', len(TrainDataset))
    17. # 创建DataLoader
    18. loader = torch.utils.data.DataLoader(
    19. dataset=TrainDataset,
    20. batch_size=2,
    21. shuffle=False,
    22. num_workers=0,
    23. drop_last=False)
    24. # 按batchsize打印数据
    25. for batch_idx, (data, label, index) in enumerate(loader):
    26. print('batch_idx:',batch_idx, '\ndata:',data, '\nlabel:',label, '\nindex:',index)
    27. print('---------')

    输出:

    1. len: 4
    2. batch_idx: 0
    3. data: tensor([[1., 2., 3.],
    4. [4., 5., 6.]])
    5. label: tensor([10, 11])
    6. index: tensor([0, 1])
    7. ---------
    8. batch_idx: 1
    9. data: tensor([[ 7., 8., 9.],
    10. [10., 11., 12.]])
    11. label: tensor([12, 13])
    12. index: tensor([2, 3])
    13. ---------

    4-- 参考

    参考链接1

  • 相关阅读:
    Keras深度学习实战——基于Inception v3实现性别分类
    从零开始安装并运行YOLOv5
    Meetup 回顾|Data Infra 研究社第十六期(含资料发布)
    深入浅出 《if的表达式》
    PHP NBA球迷俱乐部系统Dreamweaver开发mysql数据库web结构php编程计算机网页
    ViT:拉开Trasnformer在图像领域正式挑战CNN的序幕 | ICLR 2021
    JavaScript 34 JavaScript 随机
    通过termux tailscale huggingface 来手把手一步一步在手机上部署LLAMA2-7b和LLAMA2-70b大模型
    vue修改子组件中的el-input的placeholder字体颜色
    举一反三刷穿字符串加减法类型题目:牛客BM86 大数加法、LeetCode-445. 两数相加 II、LeetCode-2. 两数相加
  • 原文地址:https://blog.csdn.net/weixin_43863869/article/details/125602643
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号