PKaushik committed on
Commit
845f762
1 Parent(s): 9e48a2a
Files changed (1)
  1. inferer.py +238 -0
inferer.py ADDED
@@ -0,0 +1,238 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ import math
+ import os.path as osp
+
+ import cv2
+ import numpy as np
+ import torch
+ from huggingface_hub import hf_hub_download
+ from PIL import Image, ImageFont
+
+ from yolov6.data.data_augment import letterbox
+ from yolov6.layers.common import DetectBackend
+ from yolov6.utils.events import LOGGER, load_yaml
+ from yolov6.utils.nms import non_max_suppression
+
+
+ class Inferer:
+     def __init__(self, model_id, device="cpu", yaml="coco.yaml", img_size=640, half=False):
+         self.__dict__.update(locals())
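+         # shorthand: stores every constructor argument (model_id, device, yaml, img_size, half) as an instance attribute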
+
+         # Init model
+         self.img_size = img_size
+         cuda = device != "cpu" and torch.cuda.is_available()
+         self.device = torch.device("cuda:0" if cuda else "cpu")
+         self.model = DetectBackend(hf_hub_download(model_id, "model.pt"), device=self.device)
+         self.stride = self.model.stride
+         self.class_names = load_yaml(yaml)["names"]
+         self.img_size = self.check_img_size(self.img_size, s=self.stride)  # check image size
+
+         # Half precision
+         if half and self.device.type != "cpu":
+             self.model.model.half()
+         else:
+             self.model.model.float()
+             self.half = False  # keep the attribute in sync so preprocessing uses the right dtype
+
+         if self.device.type != "cpu":
+             self.model(
+                 torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters()))
+             )  # warmup
+
+         # Switch model to deploy status
+         self.model_switch(self.model, self.img_size)
+
+     def model_switch(self, model, img_size):
+         """Switch the model to deploy status (fuse RepVGG branches)."""
+         from yolov6.layers.common import RepVGGBlock
+
+         for layer in model.modules():
+             if isinstance(layer, RepVGGBlock):
+                 layer.switch_to_deploy()
+
+         LOGGER.info("Switched model to deploy mode.")
+
+     def __call__(
+         self,
+         path_or_image,
+         conf_thres=0.25,
+         iou_thres=0.45,
+         classes=None,
+         agnostic_nms=False,
+         max_det=1000,
+         hide_labels=False,
+         hide_conf=False,
+     ):
+         """Run model inference and draw the results on the image."""
+
+         img, img_src = self.process_image(path_or_image, self.img_size, self.stride, self.half)
+         img = img.to(self.device)
+         if len(img.shape) == 3:
+             img = img[None]  # expand for batch dim
+         pred_results = self.model(img)
+         det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
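+         # det has shape (n, 6): x1, y1, x2, y2, confidence, class index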
+
+         img_ori = img_src
+
+         # check image and font
+         assert (
+             img_ori.data.contiguous
+         ), "Image needs to be contiguous. Apply np.ascontiguousarray(im) to the input image first."
+         self.font_check()
+
+         if len(det):
+             det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
+
+             for *xyxy, conf, cls in reversed(det):
+                 class_num = int(cls)  # integer class
+                 label = (
+                     None
+                     if hide_labels
+                     else (self.class_names[class_num] if hide_conf else f"{self.class_names[class_num]} {conf:.2f}")
+                 )
+
+                 self.plot_box_and_label(
+                     img_ori,
+                     max(round(sum(img_ori.shape) / 2 * 0.003), 2),  # line width scales with image size
+                     xyxy,
+                     label,
+                     color=self.generate_colors(class_num, True),
+                 )
+
+             img_src = np.asarray(img_ori)
+
+         return img_src
+
+     @staticmethod
+     def process_image(path_or_image, img_size, stride, half):
+         """Preprocess an image (file path, NumPy BGR array, or PIL image) before inference."""
+         if isinstance(path_or_image, str):
+             img_src = cv2.imread(path_or_image)
+             assert img_src is not None, f"Invalid image: {path_or_image}"
+         elif isinstance(path_or_image, np.ndarray):
+             img_src = path_or_image
+         elif isinstance(path_or_image, Image.Image):
+             img_src = cv2.cvtColor(np.array(path_or_image), cv2.COLOR_RGB2BGR)  # PIL images are RGB; the pipeline expects BGR
+         else:
+             raise TypeError(f"Unsupported input type: {type(path_or_image)}")
+
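+         # letterbox resizes with a preserved aspect ratio and pads to a stride-aligned img_size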
+         image = letterbox(img_src, img_size, stride=stride)[0]
+
+         # Convert
+         image = image.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
+         image = torch.from_numpy(np.ascontiguousarray(image))
+         image = image.half() if half else image.float()  # uint8 to fp16/32
+         image /= 255  # 0 - 255 to 0.0 - 1.0
+
+         return image, img_src
+
+     @staticmethod
+     def rescale(ori_shape, boxes, target_shape):
+         """Rescale the output to the original image shape."""
+         ratio = min(ori_shape[0] / target_shape[0], ori_shape[1] / target_shape[1])
+         padding = (ori_shape[1] - target_shape[1] * ratio) / 2, (ori_shape[0] - target_shape[0] * ratio) / 2
+
+         boxes[:, [0, 2]] -= padding[0]
+         boxes[:, [1, 3]] -= padding[1]
+         boxes[:, :4] /= ratio
+
+         boxes[:, 0].clamp_(0, target_shape[1])  # x1
+         boxes[:, 1].clamp_(0, target_shape[0])  # y1
+         boxes[:, 2].clamp_(0, target_shape[1])  # x2
+         boxes[:, 3].clamp_(0, target_shape[0])  # y2
+
+         return boxes
+
+     def check_img_size(self, img_size, s=32, floor=0):
+         """Make sure the image size is a multiple of stride s in each dimension, and return a new image shape list."""
+         if isinstance(img_size, int):  # integer, i.e. img_size=640
+             new_size = max(self.make_divisible(img_size, int(s)), floor)
+         elif isinstance(img_size, list):  # list, i.e. img_size=[640, 480]
+             new_size = [max(self.make_divisible(x, int(s)), floor) for x in img_size]
+         else:
+             raise TypeError(f"Unsupported type of img_size: {type(img_size)}")
+
+         if new_size != img_size:
+             LOGGER.warning(f"--img-size {img_size} must be a multiple of max stride {s}, updating to {new_size}")
+         return new_size if isinstance(img_size, list) else [new_size] * 2
+
+     def make_divisible(self, x, divisor):
+         # Round x up to the nearest multiple of divisor.
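+         # e.g. make_divisible(638, 32) -> 640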
+         return math.ceil(x / divisor) * divisor
+
+     @staticmethod
+     def plot_box_and_label(image, lw, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)):
+         # Add one xyxy box to the image, with an optional label
+         p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+         cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
+         if label:
+             tf = max(lw - 1, 1)  # font thickness
+             w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
+             outside = p1[1] - h - 3 >= 0  # label fits outside box
+             p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
+             cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)  # filled
+             cv2.putText(
+                 image,
+                 label,
+                 (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
+                 0,
+                 lw / 3,
+                 txt_color,
+                 thickness=tf,
+                 lineType=cv2.LINE_AA,
+             )
+
+     @staticmethod
+     def font_check(font="./yolov6/utils/Arial.ttf", size=10):
+         # Return a PIL TrueType font, verifying that the font file exists first
+         assert osp.exists(font), f"font path does not exist: {font}"
+         return ImageFont.truetype(str(font), size)
+
+     @staticmethod
+     def box_convert(x):
+         # Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h], where x1y1 = top-left and x2y2 = bottom-right
+         y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+         y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
+         y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
+         y[:, 2] = x[:, 2] - x[:, 0]  # width
+         y[:, 3] = x[:, 3] - x[:, 1]  # height
+         return y
+
+     @staticmethod
+     def generate_colors(i, bgr=False):
+         hex_colors = (
+             "FF3838", "FF9D97", "FF701F", "FFB21D", "CFD231", "48F90A", "92CC17",
+             "3DDB86", "1A9334", "00D4BB", "2C99A8", "00C2FF", "344593", "6473FF",
+             "0018EC", "8438FF", "520085", "CB38FF", "FF95C8", "FF37C7",
+         )
+         # parse each "RRGGBB" string into an (R, G, B) tuple
+         palette = [tuple(int(code[j : j + 2], 16) for j in (0, 2, 4)) for code in hex_colors]
+         color = palette[int(i) % len(palette)]
+         return (color[2], color[1], color[0]) if bgr else color
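
Usage, for reference (a minimal sketch; the model id below is a hypothetical placeholder for any Hub repo that hosts a YOLOv6 checkpoint named "model.pt", and coco.yaml must resolve to a dataset yaml with a "names" list):

    inferer = Inferer("your-username/yolov6s", device="cpu", yaml="coco.yaml", img_size=640)
    result = inferer("image.jpg", conf_thres=0.25, iou_thres=0.45)  # annotated BGR array
    cv2.imwrite("result.jpg", result)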