
RobustVideoMatting: running inference on images

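The script has been changed to run inference on a folder of images instead of a video. All images in the folder must have the same dimensions; otherwise the script raises an error.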
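Since the same-size requirement is easy to violate by accident, it can help to check the input folder before running the model. The sketch below is not part of the original script: the helper names `png_size` and `group_by_size` are hypothetical, and it assumes the inputs are PNG files so that the dimensions can be read from the file header with the standard library alone (for JPEG inputs an image library would be needed instead).

```python
import struct
from collections import defaultdict
from pathlib import Path

# Every PNG file starts with this 8-byte signature.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path):
    """Return (width, height) read from the IHDR chunk of a PNG file."""
    with open(path, "rb") as f:
        header = f.read(24)
    # Layout: 8-byte signature, 4-byte chunk length, b"IHDR", width, height.
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} is not a valid PNG file")
    width, height = struct.unpack(">II", header[16:24])
    return width, height


def group_by_size(folder):
    """Map each (width, height) to the sorted PNG file names with that size."""
    groups = defaultdict(list)
    for path in sorted(Path(folder).glob("*.png")):
        groups[png_size(path)].append(path.name)
    return dict(groups)
```

If `group_by_size(folder)` returns more than one key, the folder mixes sizes and the run would be expected to fail. One likely reason is that the model carries recurrent state (`rec` in the script) from one frame to the next, so every frame must match the first one. The script's own `--input-resize W H` option resizes every image to a common size and is one way around this.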

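Results on complex scenes are also poor, for example when the subject is partly occluded by another person, has a scarf hanging in front, is carrying a bag, or is holding a child.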

"""
python inference.py \
    --variant mobilenetv3 \
    --checkpoint "CHECKPOINT" \
    --device cuda \
    --input-source "input.mp4" \
    --output-type video \
    --output-composition "composition.mp4" \
    --output-alpha "alpha.mp4" \
    --output-foreground "foreground.mp4" \
    --output-video-mbps 4 \
    --seq-chunk 1
"""

import torch
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from typing import Optional, Tuple
from tqdm.auto import tqdm

from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter


def convert_video(model,
                  input_source: str,
                  input_resize: Optional[Tuple[int, int]] = None,
                  downsample_ratio: Optional[float] = None,
                  output_type: str = 'video',
                  output_composition: Optional[str] = None,
                  output_alpha: Optional[str] = None,
                  output_foreground: Optional[str] = None,
                  output_video_mbps: Optional[float] = None,
                  seq_chunk: int = 1,
                  num_workers: int = 0,
                  progress: bool = True,
                  device: Optional[str] = None,
                  dtype: Optional[torch.dtype] = None):
    assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
    assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
    assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
    assert seq_chunk >= 1, 'Sequence chunk must be >= 1'
    assert num_workers >= 0, 'Number of workers must be >= 0'

    # Initialize transform
    if input_resize is not None:
        transform = transforms.Compose([
            transforms.Resize(input_resize[::-1]),
            transforms.ToTensor()
        ])
    else:
        transform = transforms.ToTensor()

    # Initialize reader
    if os.path.isfile(input_source):
        source = VideoReader(input_source, transform)
    else:
        source = ImageSequenceReader(input_source, transform)
    reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)

    # Initialize writers
    if output_type == 'video':
        frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
        output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
        if output_composition is not None:
            writer_com = VideoWriter(
                path=output_composition,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_alpha is not None:
            writer_pha = VideoWriter(
                path=output_alpha,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_foreground is not None:
            writer_fgr = VideoWriter(
                path=output_foreground,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
    else:
        if output_composition is not None:
            writer_com = ImageSequenceWriter(output_composition, 'png')
        if output_alpha is not None:
            writer_pha = ImageSequenceWriter(output_alpha, 'png')
        if output_foreground is not None:
            writer_fgr = ImageSequenceWriter(output_foreground, 'png')

    # Inference
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        dtype = param.dtype
        device = param.device

    if (output_composition is not None) and (output_type == 'video'):
        bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)

    try:
        with torch.no_grad():
            bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
            rec = [None] * 4
            for src in reader:

                if downsample_ratio is None:
                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])

                src = src.to(device, dtype, non_blocking=True).unsqueeze(0)  # [B, T, C, H, W]
                fgr, pha, *rec = model(src, *rec, downsample_ratio)

                if output_foreground is not None:
                    writer_fgr.write(fgr[0])
                if output_alpha is not None:
                    writer_pha.write(pha[0])
                if output_composition is not None:
                    if output_type == 'video':
                        com = fgr * pha + bgr * (1 - pha)
                    else:
                        fgr = fgr * pha.gt(0)
                        com = torch.cat([fgr, pha], dim=-3)
                    writer_com.write(com[0])

                bar.update(src.size(1))

    finally:
        # Clean up
        if output_composition is not None:
            writer_com.close()
        if output_alpha is not None:
            writer_pha.close()
        if output_foreground is not None:
            writer_fgr.close()


def auto_downsample_ratio(h, w):
    """
    Automatically find a downsample ratio so that the largest side of the resolution be 512px.
    """
    return min(512 / max(h, w), 1)


class Converter:
    def __init__(self, variant: str, checkpoint: str, device: str):
        self.model = MattingNetwork(variant).eval().to(device)
        self.model.load_state_dict(torch.load(checkpoint, map_location=device))
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)
        self.device = device

    def convert(self, *args, **kwargs):
        convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs)


if __name__ == '__main__':
    import argparse
    from model import MattingNetwork

    parser = argparse.ArgumentParser()
    parser.add_argument('--variant', type=str, default='resnet50', choices=['mobilenetv3', 'resnet50'])
    parser.add_argument('--checkpoint', type=str, default=r'D:\project\fenge\jacke121-rvm_128_json\model_a\rvm_resnet50.pth')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--input-source', type=str, default=r'C:\Users\Administrator\Documents\WeChat Files\libanggeng\FileStorage\File\2023-11\koutu\weilanliandai\aa')
    parser.add_argument('--input-resize', type=int, default=None, nargs=2)
    parser.add_argument('--downsample-ratio', type=float)
    parser.add_argument('--output-composition', type=str, default='output-composition')
    parser.add_argument('--output-alpha', type=str, default='output-alpha')
    parser.add_argument('--output-foreground', type=str, default='output-foreground')
    parser.add_argument('--output-type', type=str, default='png_sequence', choices=['video', 'png_sequence'])
    parser.add_argument('--output-video-mbps', type=int, default=1)
    parser.add_argument('--seq-chunk', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--disable-progress', action='store_true')
    args = parser.parse_args()

    converter = Converter(args.variant, args.checkpoint, args.device)
    converter.convert(
        input_source=args.input_source,
        input_resize=args.input_resize,
        downsample_ratio=args.downsample_ratio,
        output_type=args.output_type,
        output_composition=args.output_composition,
        output_alpha=args.output_alpha,
        output_foreground=args.output_foreground,
        output_video_mbps=args.output_video_mbps,
        seq_chunk=args.seq_chunk,
        num_workers=args.num_workers,
        progress=not args.disable_progress
    )
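When `--downsample-ratio` is not given, the script derives it from the first batch with `auto_downsample_ratio`. As a quick worked example, the snippet below copies that one-line formula and prints the ratio for a few sample resolutions. The resolutions are arbitrary, and the "scaled size" column is only the input size multiplied by the ratio, shown for intuition.

```python
def auto_downsample_ratio(h, w):
    """Same formula as in the script: cap the longer side at 512 px, never upscale."""
    return min(512 / max(h, w), 1)


# (height, width) pairs chosen only for illustration.
for h, w in [(1080, 1920), (2160, 3840), (480, 640), (256, 256)]:
    ratio = auto_downsample_ratio(h, w)
    scaled = (round(w * ratio), round(h * ratio))
    print(f"{w}x{h}: ratio={ratio:.4f}, scaled size={scaled[0]}x{scaled[1]}")
```

Because the ratio is computed once from the first batch and then reused, it is another value that implicitly assumes every image in the folder has the same size.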
