当前位置: 首页 > news >正文

GAN的损失函数和二元交叉熵损失的对应及代码

以下解释为GPT生成

在这里插入图片描述
在这里插入图片描述在这里插入图片描述
这里有个问题,使用二元交叉熵,的时候生成器的损失如何体现
在这里插入图片描述
看代码

import torch
import torch.nn as nn
import torch.optim as optim# 设置设备为GPU或CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义生成器 (Generator)
class Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(input_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, output_size),nn.Tanh()  # 输出范围在 [-1, 1] 之间)def forward(self, x):return self.model(x)# 定义判别器 (Discriminator)
class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(input_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, output_size),nn.Sigmoid()  # 输出概率 [0, 1])def forward(self, x):return self.model(x)# 超参数设置
input_size = 100  # 生成器输入噪声向量的维度
hidden_size = 128
output_size = 1   # 判别器的输出是一个概率
data_size = 784   # 假设输入数据的维度(例如,MNIST 图片 28x28 展开成 784 维向量)# 初始化生成器和判别器
G = Generator(input_size, hidden_size, data_size).to(device)
D = Discriminator(data_size, hidden_size, output_size).to(device)# 定义损失函数(二元交叉熵损失)
criterion = nn.BCELoss()# 优化器
lr = 0.0002
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)# 生成随机噪声的函数(用于生成器)
def generate_noise(batch_size, input_size):return torch.randn(batch_size, input_size).to(device)# 假设我们有一个简单的生成数据和真实数据的函数
def get_real_data(batch_size):# 这里我们用随机生成的假数据来模拟真实数据return torch.randn(batch_size, data_size).to(device)# 训练步骤
epochs = 1000
batch_size = 64for epoch in range(epochs):# 训练判别器real_data = get_real_data(batch_size)  # 获取真实数据noise = generate_noise(batch_size, input_size)  # 生成噪声fake_data = G(noise)  # 生成数据# 判别器的目标:正确区分真实数据和生成数据real_labels = torch.ones(batch_size, 1).to(device)  # 真实数据的标签为1fake_labels = torch.zeros(batch_size, 1).to(device)  # 生成数据的标签为0# 判别器对真实数据的损失outputs_real = D(real_data)D_loss_real = criterion(outputs_real, real_labels)# 判别器对生成数据的损失outputs_fake = D(fake_data.detach())  # 对生成数据的判别(生成数据不传递梯度给生成器)D_loss_fake = criterion(outputs_fake, fake_labels)# 判别器总损失D_loss = D_loss_real + D_loss_fake# 更新判别器optimizer_D.zero_grad()D_loss.backward()optimizer_D.step()# 训练生成器noise = generate_noise(batch_size, input_size)  # 生成新的噪声fake_data = G(noise)  # 生成新的假数据# 生成器的目标:欺骗判别器,让判别器认为生成的数据是真实的outputs_fake = D(fake_data)G_loss = criterion(outputs_fake, real_labels)  # 生成器希望生成的数据被判为真实数据,因此标签设为1# 更新生成器optimizer_G.zero_grad()G_loss.backward()optimizer_G.step()# 每隔一段时间打印损失if epoch % 100 == 0:print(f"Epoch [{epoch}/{epochs}] | D Loss: {D_loss.item():.4f} | G Loss: {G_loss.item():.4f}")

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • seafaring靶场漏洞测试攻略
  • 外观模式详解:如何为复杂系统构建简洁的接口
  • 【疑难杂症2024-005】docker-compose中设置容器的ip为固定ip后,服务无法启动
  • TCP 拥塞控制:一场网络数据的交通故事
  • 原生C++下模拟.Net平台的 DataTable,DataRow,只有部分功能,以后转Qt版和Python版。
  • 实战案例(5)防火墙通过跨三层MAC识别功能控制三层核心下面的终端
  • Linux(CentOS8)服务器安装RabbitMQ
  • Android DPC模式多开 APP
  • 力扣(leetcode)每日一题 1184 公交站间的距离
  • 为什么Node.js不适合CPU密集型应用?
  • 算法打卡:第十章 单调栈part01
  • 通过adb命令打开手机usb调试
  • Android Studio新建工程(Java语言环境)
  • 【建设方案】固定资产信息系统建设方案(功能清单列表2024word原件)
  • 9.12 TFTP通信
  • canvas绘制圆角头像
  • Gradle 5.0 正式版发布
  • httpie使用详解
  • Java 最常见的 200+ 面试题:面试必备
  • JavaScript异步流程控制的前世今生
  • Mithril.js 入门介绍
  • python_bomb----数据类型总结
  • Rancher如何对接Ceph-RBD块存储
  • 从重复到重用
  • 高性能JavaScript阅读简记(三)
  • 基于Dubbo+ZooKeeper的分布式服务的实现
  • 快速体验 Sentinel 集群限流功能,只需简单几步
  • 山寨一个 Promise
  • 使用 QuickBI 搭建酷炫可视化分析
  • nb
  • 好程序员web前端教程分享CSS不同元素margin的计算 ...
  • ​​​​​​​开发面试“八股文”:助力还是阻力?
  • #绘制圆心_R语言——绘制一个诚意满满的圆 祝你2021圆圆满满
  • (1/2)敏捷实践指南 Agile Practice Guide ([美] Project Management institute 著)
  • (solr系列:一)使用tomcat部署solr服务
  • (分享)自己整理的一些简单awk实用语句
  • (附源码)ssm高校志愿者服务系统 毕业设计 011648
  • (理论篇)httpmoudle和httphandler一览
  • (实战)静默dbca安装创建数据库 --参数说明+举例
  • (算法设计与分析)第一章算法概述-习题
  • (小白学Java)Java简介和基本配置
  • (转)ORM
  • (最完美)小米手机6X的Usb调试模式在哪里打开的流程
  • ... 是什么 ?... 有什么用处?
  • .gitignore文件—git忽略文件
  • .Net Core 中间件验签
  • .net MySql
  • .NET/C# 使用 #if 和 Conditional 特性来按条件编译代码的不同原理和适用场景
  • .NET企业级应用架构设计系列之开场白
  • .NET中分布式服务
  • @component注解的分类
  • @vue-office/excel 解决移动端预览excel文件触发软键盘
  • [12] 使用 CUDA 进行图像处理
  • [240812] X-CMD 发布 v0.4.5:更新 gtb、cd、chat、hashdir 模块功能
  • [AHK] WinHttpRequest.5.1报错 0x80092004 找不到对象或属性