[github优秀AI项目]实现4K60帧视频人体实时抠图 您所在的位置:网站首页 抠像技术 [github优秀AI项目]实现4K60帧视频人体实时抠图

[github优秀AI项目]实现4K60帧视频人体实时抠图

2023-11-14 22:09| 来源: 网络整理| 查看: 265

项目地址:

https://github.com/PeterL1n/RobustVideoMatting

文章:

Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!

PyTorch、TensorFlow、TensorFlow中的强大视频抠图功能。js,ONNX,CoreML!

稳定视频抠像 (RVM)

 论文 Robust High-Resolution Video Matting with Temporal Guidance 的官方 GitHub 库。RVM 专为稳定人物视频抠像设计。不同于现有神经网络将每一帧作为单独图片处理,RVM 使用循环神经网络,在处理视频流时有时间记忆。RVM 可在任意视频上做实时高清抠像。在 Nvidia GTX 1080Ti 上实现 4K 76FPS 和 HD 104FPS。此研究项目来自字节跳动。

展示视频

观看展示视频 (YouTube, Bilibili),了解模型能力。

视频中的所有素材都提供下载,可用于测试模型:Google Drive

Demo 网页: 在浏览器里看摄像头抠像效果,展示模型内部循环记忆值。Colab: 用我们的模型转换你的视频。 下载

推荐在通常情况下使用 MobileNetV3 的模型。ResNet50 的模型大很多,效果稍有提高。我们的模型支持很多框架。详情请阅读推断文档。

框架下载备注PyTorchrvm_mobilenetv3.pthrvm_resnet50.pth官方 PyTorch 模型权值。文档TorchHub无需手动下载。更方便地在你的 PyTorch 项目里使用此模型。文档TorchScriptrvm_mobilenetv3_fp32.torchscriptrvm_mobilenetv3_fp16.torchscriptrvm_resnet50_fp32.torchscriptrvm_resnet50_fp16.torchscript若需在移动端推断,可以考虑自行导出 int8 量化的模型。文档ONNXrvm_mobilenetv3_fp32.onnxrvm_mobilenetv3_fp16.onnxrvm_resnet50_fp32.onnxrvm_resnet50_fp16.onnx在 ONNX Runtime 的 CPU 和 CUDA backend 上测试过。提供的模型用 opset 12。文档,导出TensorFlowrvm_mobilenetv3_tf.ziprvm_resnet50_tf.zipTensorFlow 2 SavedModel 格式。文档TensorFlow.jsrvm_mobilenetv3_tfjs_int8.zip在网页上跑模型。展示,示范代码CoreMLrvm_mobilenetv3_1280x720_s0.375_fp16.mlmodelrvm_mobilenetv3_1280x720_s0.375_int8.mlmodelrvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodelrvm_mobilenetv3_1920x1080_s0.25_int8.mlmodelCoreML 只能导出固定分辨率,其他分辨率可自行导出。支持 iOS 13+。s 代表下采样比。文档,导出

所有模型可在 Google Drive 或百度网盘(密码: gym7)上下载。

PyTorch 范例 1 安装 Python 库: pip install -r requirements_inference.txt 2 加载模型: import torch from model import MattingNetwork model = MattingNetwork('mobilenetv3').eval().cuda() # 或 "resnet50" model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) 3 若只需要做视频抠像处理,我们提供简单的 API: from inference import convert_video convert_video( model, # 模型,可以加载到任何设备(cpu 或 cuda) input_source='input.mp4', # 视频文件,或图片序列文件夹 output_type='video', # 可选 "video"(视频)或 "png_sequence"(PNG 序列) output_composition='com.mp4', # 若导出视频,提供文件路径。若导出 PNG 序列,提供文件夹路径 output_alpha="pha.mp4", # [可选项] 输出透明度预测 output_foreground="fgr.mp4", # [可选项] 输出前景预测 output_video_mbps=4, # 若导出视频,提供视频码率 downsample_ratio=None, # 下采样比,可根据具体视频调节,或 None 选择自动 seq_chunk=12, # 设置多帧并行计算 ) 4 或自己写推断逻辑: from torch.utils.data import DataLoader from torchvision.transforms import ToTensor from inference_utils import VideoReader, VideoWriter reader = VideoReader('input.mp4', transform=ToTensor()) writer = VideoWriter('output.mp4', frame_rate=30) bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # 绿背景 rec = [None] * 4 # 初始循环记忆(Recurrent States) downsample_ratio = 0.25 # 下采样比,根据视频调节 with torch.no_grad(): for src in DataLoader(reader): # 输入张量,RGB通道,范围为 0~1 fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # 将上一帧的记忆给下一帧 com = fgr * pha + bgr * (1 - pha) # 将前景合成到绿色背景 writer.write(com) # 输出帧 5 模型和 API 也可通过 TorchHub 快速载入。 # 加载模型 model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # 或 "resnet50" # 转换 API convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")

推断文档里有对 downsample_ratio 参数,API 使用,和高阶使用的讲解。

训练和评估

请参照训练文档(英文)。

速度

速度用 inference_speed_test.py 测量以供参考。

GPUdTypeHD (1920x1080)4K (3840x2160)RTX 3090FP16172 FPS154 FPSRTX 2060 SuperFP16134 FPS108 FPSGTX 1080 TiFP32104 FPS74 FPS 注释1:HD 使用 downsample_ratio=0.25,4K 使用 downsample_ratio=0.125。 所有测试都使用 batch size 1 和 frame chunk 1。注释2:图灵架构之前的 GPU 不支持 FP16 推理,所以 GTX 1080 Ti 使用 FP32。注释3:我们只测量张量吞吐量(tensor throughput)。 提供的视频转换脚本会慢得多,因为它不使用硬件视频编码/解码,也没有在并行线程上完成张量传输。如果您有兴趣在 Python 中实现硬件视频编码/解码,请参考 PyNvCodec。 复现使用

知乎有个大佬把它分别用python和C++复现了RobustVideoMatting🔥2021 ONNXRuntime C++工程化记录-实现篇 - 知乎

python代码:

import cv2 import time import argparse import numpy as np import onnxruntime as ort def normalize(frame: np.ndarray) -> np.ndarray: """ Args: frame: BGR Returns: normalized 0~1 BCHW RGB """ img = frame.astype(np.float32).copy() / 255.0 img = img[:, :, ::-1] # RGB img = np.transpose(img, (2, 0, 1)) # (C,H,W) img = np.expand_dims(img, axis=0) # (B=1,C,H,W) return img def infer_rvm_frame(weight: str = "rvm_resnet50_fp32.onnx", img_path: str = "test.jpg", output_path: str = "test_onnx.jpg"): sess = ort.InferenceSession(f'./checkpoint/{weight}') print(f"Load checkpoint/{weight} done!") for _ in sess.get_inputs(): print("Input: ", _) for _ in sess.get_outputs(): print("Input: ", _) frame = cv2.imread(img_path) src = normalize(frame) rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4 # 必须用模型一样的 dtype downsample_ratio = np.array([0.25], dtype=np.float32) # 必须是 FP32 bgr = np.array([0.47, 1., 0.6]).reshape((3, 1, 1)) fgr, pha, *rec = sess.run([], { 'src': src, 'r1i': rec[0], 'r2i': rec[1], 'r3i': rec[2], 'r4i': rec[3], 'downsample_ratio': downsample_ratio }) merge_frame = fgr * pha + bgr * (1. - pha) # (1,3,H,W) merge_frame = merge_frame[0] * 255. # (3,H,W) merge_frame = merge_frame.astype(np.uint8) # RGB merge_frame = np.transpose(merge_frame, (1, 2, 0)) # (H,W,3) merge_frame = cv2.cvtColor(merge_frame, cv2.COLOR_BGR2RGB) cv2.imwrite(output_path, merge_frame) print(f"infer done! saved {output_path}") def infer_rvm_video(weight: str = "rvm_resnet50_fp32.onnx", video_path: str = "./demo/1917.mp4", output_path: str = "./demo/1917_onnx.mp4"): sess = ort.InferenceSession(f'./checkpoint/{weight}') print(f"Load checkpoint/{weight} done!") for _ in sess.get_inputs(): print("Input: ", _) for _ in sess.get_outputs(): print("Input: ", _) # 读取视频 video_capture = cv2.VideoCapture(video_path) width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) print(f"Video Caputer: Height: {height}, Width: {width}, Frame Count: {frame_count}") # 写出视频 fps = 25 fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) print(f"Create Video Writer: {output_path}") i = 0 rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4 # 必须用模型一样的 dtype downsample_ratio = np.array([0.25], dtype=np.float32) # 必须是 FP32 bgr = np.array([0.47, 1., 0.6]).reshape((3, 1, 1)) print(f"Infer {video_path} start ...") while video_capture.isOpened(): success, frame = video_capture.read() if success: i += 1 src = normalize(frame) # src 张量是 [B, C, H, W] 形状,必须用模型一样的 dtype t1 = time.time() fgr, pha, *rec = sess.run([], { 'src': src, 'r1i': rec[0], 'r2i': rec[1], 'r3i': rec[2], 'r4i': rec[3], 'downsample_ratio': downsample_ratio }) t2 = time.time() print(f"Infer {i}/{frame_count} done! -> cost {(t2 - t1) * 1000} ms", end=" ") merge_frame = fgr * pha + bgr * (1. - pha) # (1,3,H,W) merge_frame = merge_frame[0] * 255. # (3,H,W) merge_frame = merge_frame.astype(np.uint8) # RGB merge_frame = np.transpose(merge_frame, (1, 2, 0)) # (H,W,3) merge_frame = cv2.cvtColor(merge_frame, cv2.COLOR_BGR2RGB) merge_frame = cv2.resize(merge_frame, (width, height)) video_writer.write(merge_frame) print(f"write {i}/{frame_count} done.") else: print("can not read video! skip!") break video_capture.release() video_writer.release() print(f"Infer {video_path} done!") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--mode", type=str, default="video") parser.add_argument("--weight", type=str, default="rvm_resnet50_fp32.onnx") parser.add_argument("--input", type=str, default="./demo/1917.mp4") parser.add_argument("--output", type=str, default="./demo/1917_onnx.mp4") args = parser.parse_args() if args.mode == "video": infer_rvm_video(weight=args.weight, video_path=args.input, output_path=args.output) else: infer_rvm_frame(weight=args.weight, img_path=args.input, output_path=args.output) """ rvm_resnet50_fp32.onnx rvm_mobilenetv3_fp32.onnx PYTHONPATH=. python3 ./inference_onnx.py --input ./demo/1917.mp4 --output ./demo/1917_onnx.mp4 PYTHONPATH=. python3 ./inference_onnx.py --mode img --input test.jpg --output test_onnx.jpg python inference_onnx.py --input ./demo/1917.mp4 --output ./demo/1917_onnx.mp4 """



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有