11#! /usr/bin/python
22# -*- coding: utf-8 -*-
33
4- from tensorlayer .app import YOLOv4
4+ from tensorlayer .app import YOLOv4 , get_anchors , decode , filter_boxes
55import numpy as np
66import tensorflow as tf
7+ from tensorlayer import logging
8+ import cv2
79
810
911class object_detection (object ):
@@ -17,19 +19,70 @@ def __init__(self, model_name='yolo4-mscoco'):
1719
1820 def __call__ (self , input_data ):
1921 if self .model_name == 'yolo4-mscoco' :
20- image_data = input_data / 255.
21- images_data = []
22- for i in range (1 ):
23- images_data .append (image_data )
24- images_data = np .asarray (images_data ).astype (np .float32 )
25- batch_data = tf .constant (images_data )
26- output = self .model (batch_data , is_train = False )
22+ batch_data = yolo4_input_processing (input_data )
23+ feature_maps = self .model (batch_data , is_train = False )
24+ output = yolo4_output_processing (feature_maps )
2725 else :
2826 raise NotImplementedError
2927
3028 return output
3129
3230 def __repr__ (self ):
33- s = ('{classname} (model_name={model_name}, model_structure={model}' )
31+ s = ('(model_name={model_name}, model_structure={model}' )
3432 s += ')'
3533 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
34+
35+ @property
36+ def list (self ):
37+ logging .info ("The model name list: yolov4-mscoco" )
38+
39+
40+ def yolo4_input_processing (original_image ):
41+ image_data = cv2 .resize (original_image , (416 , 416 ))
42+ image_data = image_data / 255.
43+ images_data = []
44+ for i in range (1 ):
45+ images_data .append (image_data )
46+ images_data = np .asarray (images_data ).astype (np .float32 )
47+ batch_data = tf .constant (images_data )
48+ return batch_data
49+
50+
51+ def yolo4_output_processing (feature_maps ):
52+ STRIDES = [8 , 16 , 32 ]
53+ ANCHORS = get_anchors ([12 , 16 , 19 , 36 , 40 , 28 , 36 , 75 , 76 , 55 , 72 , 146 , 142 , 110 , 192 , 243 , 459 , 401 ])
54+ NUM_CLASS = 80
55+ XYSCALE = [1.2 , 1.1 , 1.05 ]
56+ iou_threshold = 0.45
57+ score_threshold = 0.25
58+
59+ bbox_tensors = []
60+ prob_tensors = []
61+ score_thres = 0.2
62+ for i , fm in enumerate (feature_maps ):
63+ if i == 0 :
64+ output_tensors = decode (fm , 416 // 8 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
65+ elif i == 1 :
66+ output_tensors = decode (fm , 416 // 16 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
67+ else :
68+ output_tensors = decode (fm , 416 // 32 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
69+ bbox_tensors .append (output_tensors [0 ])
70+ prob_tensors .append (output_tensors [1 ])
71+ pred_bbox = tf .concat (bbox_tensors , axis = 1 )
72+ pred_prob = tf .concat (prob_tensors , axis = 1 )
73+ boxes , pred_conf = filter_boxes (
74+ pred_bbox , pred_prob , score_threshold = score_thres , input_shape = tf .constant ([416 , 416 ])
75+ )
76+ pred = {'concat' : tf .concat ([boxes , pred_conf ], axis = - 1 )}
77+
78+ for key , value in pred .items ():
79+ boxes = value [:, :, 0 :4 ]
80+ pred_conf = value [:, :, 4 :]
81+
82+ boxes , scores , classes , valid_detections = tf .image .combined_non_max_suppression (
83+ boxes = tf .reshape (boxes , (tf .shape (boxes )[0 ], - 1 , 1 , 4 )),
84+ scores = tf .reshape (pred_conf , (tf .shape (pred_conf )[0 ], - 1 , tf .shape (pred_conf )[- 1 ])),
85+ max_output_size_per_class = 50 , max_total_size = 50 , iou_threshold = iou_threshold , score_threshold = score_threshold
86+ )
87+ output = [boxes .numpy (), scores .numpy (), classes .numpy (), valid_detections .numpy ()]
88+ return output
0 commit comments