当前位置: 首页 > news >正文

使用PyTorch进行图像风格迁移:基于VGG19实现

图像风格迁移(Neural Style Transfer, NST)是深度学习中一个令人着迷的应用,它能够将一张图像的风格应用到另一张图像上。例如,能够将梵高的画风应用到一张普通照片上。本文将详细解释如何使用PyTorch进行风格迁移,逐步分析代码,并讲解其中的关键技术。

1. 环境准备

在开始之前,确保安装了必要的库:

pip install torch torchvision pillow

2. 模型缓存目录设置

为了加速模型的加载,我们可以通过设置环境变量TORCH_HOME来指定模型缓存目录,避免每次运行代码时重新下载模型:

os.environ['TORCH_HOME'] = './model_directory'  # 你可以根据需要自定义目录

3. 加载图像

加载图像并进行预处理是风格迁移中的重要步骤。我们需要将图像转换为张量并进行归一化处理,以便与预训练的VGG19模型匹配:

def load_image(image_path, max_size=400):image = Image.open(image_path).convert('RGB')size = min(max_size, max(image.size))transform = transforms.Compose([transforms.Resize((size, size)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])return transform(image).unsqueeze(0)

在这里,我们将图像调整为不大于400像素的正方形,并将其转换为适合VGG19模型输入的格式。

4. VGG19模型的特征提取

风格迁移的核心思想是将内容图像的高层次特征与风格图像的低层次特征结合。我们使用VGG19模型的前21层来提取图像的特征:

class VGG(nn.Module):def __init__(self):super(VGG, self).__init__()self.features = vgg19(pretrained=True).features[:21].eval()def forward(self, x):features = []for i, layer in enumerate(self.features):x = layer(x)if i in {0, 5, 10, 19, 21}:features.append(x)return features

5. 内容与风格损失

内容损失衡量生成图像与内容图像的特征差异,而风格损失则是基于Gram矩阵来衡量生成图像与风格图像的差异。

  • 内容损失:
class ContentLoss(nn.Module):def __init__(self, target):super(ContentLoss, self).__init__()self.target = target.detach()def forward(self, input):return nn.functional.mse_loss(input, self.target)
  • 风格损失:
class StyleLoss(nn.Module):def __init__(self, target):super(StyleLoss, self).__init__()self.target = self.gram_matrix(target).detach()def gram_matrix(self, input):batch_size, channels, height, width = input.size()features = input.view(batch_size * channels, height * width)G = torch.mm(features, features.t())return G.div(batch_size * channels * height * width)def forward(self, input):G = self.gram_matrix(input)return nn.functional.mse_loss(G, self.target)

6. 图像风格迁移算法

核心算法将内容图像初始化为输入图像,并通过多次迭代优化,使其逐步接近目标风格图像,同时保持内容的完整性。我们使用LBFGS优化器来实现这一过程:

def style_transfer(content_img, style_img, num_steps=1000, style_weight=1e9, content_weight=1):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")content_img = content_img.to(device)style_img = style_img.to(device)model = VGG().to(device)style_features = model(style_img)content_features = model(content_img)input_img = content_img.clone().requires_grad_(True).to(device)optimizer = optim.LBFGS([input_img])style_losses = []content_losses = []for sf, cf in zip(style_features, content_features):content_losses.append(ContentLoss(cf))style_losses.append(StyleLoss(sf))run = [0]while run[0] <= num_steps:def closure():optimizer.zero_grad()input_features = model(input_img)content_loss = 0style_loss = 0for cl, input_f in zip(content_losses, input_features):content_loss += content_weight * cl(input_f)for sl, input_f in zip(style_losses, input_features):style_loss += style_weight * sl(input_f)loss = content_loss + style_lossloss.backward()run[0] += 1if run[0] % 50 == 0:print(f'Step {run[0]}, Content Loss: {content_loss.item():4f}, Style Loss: {style_loss.item():4f}')return lossoptimizer.step(closure)return input_img

7. 结果保存

生成的图像需要去除归一化并保存为常规图片格式:

def save_image(tensor, path):image = tensor.clone().detach()image = image.squeeze(0)image = transforms.ToPILImage()(image)image.save(path)

8. 主函数执行

整个过程可以通过主函数来执行,加载图像、进行风格迁移并保存结果:

if __name__ == '__main__':content_image_path = 'content_image.png'style_image_path = 'style_image.png'output_image_path = 'output_image.jpg'content_img = load_image(content_image_path)style_img = load_image(style_image_path)result = style_transfer(content_img, style_img)save_image(result, output_image_path)print(f"风格迁移完成,图像已保存为 {output_image_path}")

总结

本文展示了如何使用PyTorch和VGG19模型实现图像风格迁移。通过合理设置内容和风格损失的权重,我们可以生成既保留内容图像结构又具有风格图像艺术风格的全新图像。

完整代码

github:https://github.com/Yolumia/Image_style_transfer_base_vgg19/

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 多目标优化算法求解LSMOP(Large-Scale Multi-Objective Optimization Problem)测试集,MATLAB代码
  • Windows和Mac命令窗快速打开文件夹
  • vue 项目自适应 配置 px转rem 的插件postcss-pxtorem
  • 数据中台建设(六)—— 数据开发-提取数据价值
  • Java实现建造者模式和源码中的应用
  • 大棚分割数据集,40765对影像,16.9g数据量,0.8米高分二,纯手工标注(arcgis标注)的大规模农业大棚分割数据集。
  • 使用Flux以文生图
  • 【QT】常用类
  • php AEAD_AES_256_GCM算法 解密
  • 38. 如何在Spring Boot项目中集成MyBatis-Plus?
  • 读构建可扩展分布式系统:方法与实践04应用服务
  • 低功耗蓝牙模块在健身器材中的应用,让健身体验更智能
  • 【GoMate框架案例】讯飞大模型RAG智能问答挑战赛top10 Baseline
  • vue3常见的bug 修复bug
  • 代码随想录算法训练营day36
  • 《深入 React 技术栈》
  • 8年软件测试工程师感悟——写给还在迷茫中的朋友
  • angular2开源库收集
  • Apache的基本使用
  • js正则,这点儿就够用了
  • Mysql数据库的条件查询语句
  • Netty 4.1 源代码学习:线程模型
  • Next.js之基础概念(二)
  • Phpstorm怎样批量删除空行?
  • scrapy学习之路4(itemloder的使用)
  • seaborn 安装成功 + ImportError: DLL load failed: 找不到指定的模块 问题解决
  • Vue2.x学习三:事件处理生命周期钩子
  • vue-cli在webpack的配置文件探究
  • 从tcpdump抓包看TCP/IP协议
  • 关于for循环的简单归纳
  • 前嗅ForeSpider中数据浏览界面介绍
  • 我建了一个叫Hello World的项目
  • 硬币翻转问题,区间操作
  • ​草莓熊python turtle绘图代码(玫瑰花版)附源代码
  • #Spring-boot高级
  • #数据结构 笔记三
  • (1)安装hadoop之虚拟机准备(配置IP与主机名)
  • (bean配置类的注解开发)学习Spring的第十三天
  • (C语言)共用体union的用法举例
  • (附源码)python房屋租赁管理系统 毕业设计 745613
  • (附源码)计算机毕业设计SSM智慧停车系统
  • (欧拉)openEuler系统添加网卡文件配置流程、(欧拉)openEuler系统手动配置ipv6地址流程、(欧拉)openEuler系统网络管理说明
  • (转)nsfocus-绿盟科技笔试题目
  • (转)可以带来幸福的一本书
  • **CI中自动类加载的用法总结
  • **PHP二维数组遍历时同时赋值
  • ... fatal error LINK1120:1个无法解析的外部命令 的解决办法
  • .gitattributes 文件
  • .jks文件(JAVA KeyStore)
  • .NET 8 中引入新的 IHostedLifecycleService 接口 实现定时任务
  • .Net CF下精确的计时器
  • .net core webapi Startup 注入ConfigurePrimaryHttpMessageHandler
  • .NET 命令行参数包含应用程序路径吗?
  • .NET命令行(CLI)常用命令
  • .NET下的多线程编程—1-线程机制概述