• 关于#Python#生成对抗网络代码的问题,如何解决?(相关搜索:python代码|数据集|训练集)


    关注 码龄 粉丝数 原力等级 -- 被采纳 被点赞 采纳率 SilyaSophie 2024-04-01 22:15 采纳率: 60% 浏览 6 首页/ 编程语言 / 关于#Python#生成对抗网络代码的问题,如何解决?(相关搜索:python代码|数据集|训练集) pythonkeras生成对抗网络 代码如下: from __future__ import print_function, division from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Conv1D,GRU,Dropout,InputLayer from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam import matplotlib.pyplot as plt import sys import numpy as np import pandas as pd import import2023 class GAN: def __init__(self): self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 optimizer = Adam(0.0002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() # The generator takes noise as input and generates imgs z = Input(shape=(self.latent_dim,)) img = self.generator(z) # For the combined model we will only train the generator self.discriminator.trainable = False # The discriminator takes generated images as input and determines validity validity = self.discriminator(img) # The combined model (stacked generator and discriminator) # Trains the generator to fool the discriminator self.combined = Model(z, validity) self.combined.compile(loss='mse', optimizer=optimizer) def build_generator(self): model = Sequential() model.add(GRU(256, input_shape=self.img_shape, activation='relu')) #model.add(Dense(256, input_dim=self.latent_dim)) model.add(Dense(512, activation='relu')) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(Dense(512, activation='relu')) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(Dense(1024, activation='relu')) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) model.summary() noise = Input(shape=(self.latent_dim,1,1)) img = model(noise) return Model(noise,img) def build_discriminator(self): model = Sequential() #model.add(Flatten(input_shape=self.img_shape)) model.add(InputLayer(input_shape=(32, 64))) #model.add(Dense(256,input_shape=self.img_shape,activation='relu')) #Conv2D(filters, kernel_size, data_format='NHWC') model.add(Conv1D(1024, kernel_size=3, strides=2, padding='same', data_format='channels_first',activation='relu')) model.add(Dense(512)) model.add(Dense(512, activation='relu')) model.add(Dense(256)) model.add(Dense(64, activation='relu')) model.add(Dense(1, activation='sigmoid')) model.summary() img = Input(shape=self.img_shape) validity = model(img) return Model(img,validity) def train(self, epochs, batch_size=128, sample_interval=50): # Load the dataset # csv文件路径 csv_path_train = 'E:/dataset/CICIoT2023/benign.csv' # 读取数据 X_train= pd.read_csv(csv_path_train) # Rescale -1 to 1 X_train = X_train / 127.5 - 1. #X_train = np.expand_dims(X_train, axis=3) X_train = np.reshape(X_train, (-1, 100, 46)) print(X_train.shape) # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) print(valid.shape) print(fake.shape) for epoch in range(epochs): # --------------------- # Train Discriminator # --------------------- # Select a random batch of images idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] print(imgs.shape) noise = np.random.normal(0, 1, (batch_size, self.latent_dim,1)) print(noise.shape) # Generate a batch of new images gen_imgs = self.generator.predict(noise) # Train the discriminator d_loss_real = self.discriminator.train_on_batch(imgs, valid) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- noise = np.random.normal(0, 1, (batch_size, self.latent_dim,1)) # Train the generator (to have the discriminator label samples as valid) g_loss = self.combined.train_on_batch(noise, valid) # Plot the progress print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss)) # # If at save interval => save generated image samples # if epoch % sample_interval == 0: # self.sample_images(epoch) if __name__ == '__main__': gan = GAN() gan.train(epochs=30000, batch_size=32, sample_interval=200) 运行结果及报错: 展开全部 收起 写回答 好问题 0 提建议 追加酬金 关注问题 微信扫一扫 点击复制链接 分享 邀请回答 编辑 收藏 删除 结题 收藏 举报 追加酬金 (90%的用户在追加酬金后获得了解决方案) 当前问题酬金 ¥ 0 (可追加 ¥500) ¥ 15¥ 20¥ 50¥ 100¥ 200 支付方式 扫码支付 二维码出错 点击刷新 支付金额 15 元 提供问题酬金的用户不参与问题酬金结算和分配 支付即为同意 《付费问题酬金结算规则》 结题 再想想 删除 再等等

    3条回答 默认 最新

    • 码农阿豪 Java领域优质创作者 2024-04-01 22:30
      关注
      让阿豪来帮你解答,本回答参考chatgpt3.5编写提供,如果还有疑问可以评论或留言
      该代码实现的是一个基于深度学习的生成对抗网络(GAN)模型,用于生成28x28像素的手写数字图像。
      1. 问题:这段代码所实现的功能是什么? 回答: 该段代码实现的是一个基于深度学习的生成对抗网络(GAN)模型,用于生成28x28像素的手写数字图像。
      2. 问题:这段代码中所使用的算法或技术是什么? 回答: 该段代码中使用了生成对抗网络(GAN)算法,包括生成器和判别器两个部分。
      3. 问题:这段代码中的每个函数或方法都是用来做什么的? 回答:
      4. __init__(self):初始化生成对抗网络模型。
      5. build_generator(self):构建生成器模型,生成28x28的图像。
      6. build_discriminator(self):构建判别器模型,判别输入的图像是真实的还是虚假的。
      7. GAN类中的其他函数和方法包括训练模型所需用到的定义和编译模型、导入所需的库、以及实现模型训练、评估等功能的代码。
      8. 问题:这段代码在实现功能时存在哪些缺陷或不足?如何改进? 回答: 该段代码在实现功能时没有明显的缺陷或不足。但是,它只实现了生成28x28像素的手写数字图像,不能应用于其他类型的图像生成。可以改进的方向包括增加模型的泛化能力,使其能够生成其他类型的图像,以及提高模型的训练速度和准确性。可以探索使用更先进的神经网络和优化算法,如增加卷积层、使用批量归一化、使用更优秀的优化算法等。
      评论
      SilyaSophie 2024-04-02 19:32

      谢谢,不过第24行代码有定义:self.img_shape = (self.img_rows, self.img_cols, self.channels)

      回复
    • GISer Liu 2024-04-01 22:35
      关注

      该回答引用自GPT-3.5,由博主GISer Liu编写:

      根据您提供的代码和问题描述,有几个需要注意的地方:

      1. 数据集问题 :您的代码中使用了mnist数据集,但是您希望使用的是自己的数据集。您需要替换mnist数据集加载的部分,改为加载您自己的数据集。根据您的代码,您的数据集似乎是一个CSV文件,您可以使用pd.read_csv来加载数据集,并对数据进行预处理,确保数据的维度和范围符合模型的要求。
      2. 输入形状问题 :您的生成器模型和判别器模型的输入形状不一致。生成器模型的输入形状为(self.latent_dim,1,1),而判别器模型的输入形状为self.img_shape。您需要确保两个模型的输入形状一致,以便正确地连接它们。
      3. 生成器模型问题 :生成器模型的输出形状应该与图像形状一致,即(self.img_rows, self.img_cols, self.channels),但是您的生成器模型最后一层输出形状为self.img_shape,这可能导致模型无法正确生成图像。您需要修改生成器模型的最后一层,使其输出形状符合图像形状。
      4. 训练过程问题 :在训练过程中,您需要确保生成器和判别器的输入形状一致,并且正确地传递真实图像和生成图像。您还需要注意确保输入数据的范围正确,通常在训练前会对数据进行归一化处理。
        基于以上问题,您可以按照以下步骤修改您的代码:
      5. 替换数据集加载部分,确保加载自己的数据集并进行预处理。
      6. 修改生成器模型的最后一层,使其输出形状为(self.img_rows, self.img_cols, self.channels)
      7. 修改生成器和判别器模型的输入形状,使其一致。
      8. 确保在训练过程中传递正确的输入数据。
        这些修改应该能够解决您遇到的问题。如果您需要进一步的帮助或有其他问题,请随时提出。

      如果该回答解决了您的问题,请采纳!如果没有,请私信联系或评论您的疑惑

      评论
      SilyaSophie 2024-04-02 19:32

      谢谢,不过第24行代码有定义:self.img_shape = (self.img_rows, self.img_cols, self.channels)

      回复
      GISer Liu 2024-04-02 22:13

      根据您提供的代码和问题描述,有几个需要注意的地方:

      1. 数据集问题 :您的代码中使用了mnist数据集,但是您希望使用的是自己的数据集。您需要替换mnist数据集加载的部分,改为加载您自己的数据集。根据您的代码,您的数据集似乎是一个CSV文件,您可以使用pd.read_csv来加载数据集,并对数据进行预处理,确保数据的维度和范围符合模型的要求。
      2. 输入形状问题 :您的生成器模型和判别器模型的输入形状不一致。生成器模型的输入形状为(self.latent_dim, 1, 1),而判别器模型的输入形状为(batch_size, 100, 46)。您需要调整生成器模型的输入形状,使其与判别器模型的输入形状一致。
        针对这两个问题,您可以按照以下步骤进行修改:
      3. 数据集加载和预处理 :首先,您需要使用pd.read_csv加载您的CSV文件,并对数据进行预处理,确保数据的维度和范围符合模型的要求。例如,您可以使用以下代码加载和预处理数据集:
        pythonCopy code
        csv_path_train = 'E:/dataset/CICIoT2023/benign.csv'
        X_train = pd.read_csv(csv_path_train)
        X_train = X_train / 127.5 - 1.
        X_train = np.reshape(X_train, (-1, 100, 46, 1))
        
      4. 调整生成器模型的输入形状 :您需要调整生成器模型的输入形状,使其与判别器模型的输入形状一致。修改生成器模型的输入形状如下:
        pythonCopy code
        noise = Input(shape=(self.latent_dim, 1))
        
        经过以上修改,您的代码应该能够正确加载自己的数据集,并且生成器模型和判别器模型的输入形状也符合要求。

      回复
      SilyaSophie 回复 GISer Liu 2024-04-04 18:10

      你好,按照您的建议修改代码之后,还是同样的报错。

      回复
      展开全部4条评论
    • CSDN-Ada助手 CSDN-AI 官方账号 2024-04-03 02:27
      关注

      【相关推荐】




      如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^

      展开全部

      评论
    编辑
    预览

    报告相同问题?

  • 相关阅读:
    DNS域名解析
    【技术】Spring Boot 将 Word 转换为 PDF 2.0 版本
    跨域请求的方法
    【开发篇】三、web下单元测试与mock数据
    Java代码审计——XML 外部实体注入(XXE)
    基于JAVA中小学生错题管理系统计算机毕业设计源码+系统+mysql数据库+lw文档+部署
    性能测试,如何做压力测试?压力测试实施,避免背锅提升效率(二)
    WebShell 木马免杀过WAF
    【一起读源码】1. Java 中元组 Tuple
    HTML5使用html2canvas转化为图片,然后再转为base64.
  • 原文地址:https://ask.csdn.net/questions/8082499