From 3b91ca3df66c42c7d4bfd8c38598827804b7e64e Mon Sep 17 00:00:00 2001 From: Yoshiki Obinata Date: Thu, 25 Jan 2024 17:27:22 +0900 Subject: [PATCH] Add web camera demo --- webcam.py | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 webcam.py diff --git a/webcam.py b/webcam.py new file mode 100644 index 00000000..38aabe49 --- /dev/null +++ b/webcam.py @@ -0,0 +1,82 @@ +import argparse +import cv2 +import numpy as np +import os +import torch +import torch.nn.functional as F +from torchvision.transforms import Compose +from tqdm import tqdm + +from depth_anything.dpt import DepthAnything +from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl']) + parser.add_argument('--frame_width', type=int, default=640) + parser.add_argument('--frame_height', type=int, default=480) + parser.add_argument('--fps', type=int, default=30) + + args = parser.parse_args() + + margin_width = 50 + caption_height = 60 + + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 1 + font_thickness = 2 + + DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + + depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_{}14'.format(args.encoder)).to(DEVICE) + + total_params = sum(param.numel() for param in depth_anything.parameters()) + print('Total parameters: {:.2f}M'.format(total_params / 1e6)) + + depth_anything.eval() + + transform = Compose([ + Resize( + width=518, + height=518, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + cap = cv2.VideoCapture(0) + cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.frame_width) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.frame_height) + cap.set(cv2.CAP_PROP_FPS, args.fps) + + while True: + ret, raw_image = cap.read() + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + h, w = image.shape[:2] + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0).to(DEVICE) + + with torch.no_grad(): + depth = depth_anything(image) + + depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0] + depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + + depth = depth.cpu().numpy().astype(np.uint8) + depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO) + + cv2.namedWindow('rgb', cv2.WINDOW_NORMAL) + cv2.imshow('rgb', raw_image) + + cv2.namedWindow('depth', cv2.WINDOW_NORMAL) + cv2.imshow('depth', depth_color) + + cv2.waitKey(1)