Skip to content

Commit d56d0b7

Browse files
committed
add visualize
1 parent 276d625 commit d56d0b7

3 files changed

Lines changed: 88 additions & 0 deletions

File tree

tensorlayer/app/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
# -*- coding: utf-8 -*-
33

44
from .computer_vision_object_detection import *
5+
from .computer_vision import *

tensorlayer/app/computer_vision.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
from tensorlayer.app import YOLOv4
5+
import numpy as np
6+
import tensorflow as tf
7+
8+
9+
class object_detection(object):
    """Load a pretrained object-detection model by name and run inference.

    Parameters
    ----------
    model_name : str
        Identifier of the model to load. Currently only
        ``'yolo4-mscoco'`` is supported.

    Raises
    ------
    ValueError
        If ``model_name`` is not a supported identifier.
    """

    def __init__(self, model_name='yolo4-mscoco'):
        self.model_name = model_name
        if self.model_name == 'yolo4-mscoco':
            # 80 classes corresponds to the MS-COCO label set.
            self.model = YOLOv4(NUM_CLASS=80, pretrained=True)
        else:
            # The original code raised a bare string, which is itself a
            # TypeError in Python 3 -- raise a proper exception instead.
            raise ValueError("The model does not support.")

    def __call__(self, input_data):
        """Run inference on a single image.

        Parameters
        ----------
        input_data : np.ndarray
            One image with pixel values in [0, 255].
            (Assumed HWC layout expected by YOLOv4 -- TODO confirm.)

        Returns
        -------
        The raw model output for a batch of one image.
        """
        if self.model_name == 'yolo4-mscoco':
            # Normalize pixels to [0, 1] and add a leading batch dimension;
            # the single-iteration append loop in the original reduced to this.
            image_data = input_data / 255.
            images_data = np.asarray([image_data], dtype=np.float32)
            batch_data = tf.constant(images_data)
            output = self.model(batch_data, is_train=False)
        else:
            raise NotImplementedError
        return output

    def __repr__(self):
        s = '{classname}(model_name={model_name}, model_structure={model})'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

tensorlayer/app/visualize.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
from tensorlayer.app import get_anchors, decode, filter_boxes, draw_bbox
5+
import numpy as np
6+
import tensorflow as tf
7+
import cv2
8+
from PIL import Image
9+
10+
11+
def yolo4_visualize(original_image, feature_maps, iou_threshold=0.45, score_threshold=0.25, save_path='result.png'):
    """Decode YOLOv4 feature maps, run non-max suppression, and draw/save detections.

    Parameters
    ----------
    original_image : np.ndarray
        The input image the detections were computed from.
    feature_maps : sequence
        The three raw YOLOv4 output feature maps, ordered by stride 8, 16, 32
        (assumed from a 416x416 input -- TODO confirm against the model config).
    iou_threshold : float
        IoU threshold used by non-max suppression. Default 0.45 (unchanged).
    score_threshold : float
        Final confidence threshold used by non-max suppression. Default 0.25
        (unchanged).
    save_path : str
        Path the annotated image is written to. Default ``'result.png'``
        (unchanged).

    Returns
    -------
    np.ndarray
        The annotated image as written to ``save_path``.
    """
    STRIDES = [8, 16, 32]
    # MS-COCO anchor boxes for the three detection scales.
    ANCHORS = get_anchors([12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401])
    NUM_CLASS = 80
    XYSCALE = [1.2, 1.1, 1.05]
    # Loose pre-filter threshold applied before NMS.
    score_thres = 0.2

    bbox_tensors = []
    prob_tensors = []
    for i, fm in enumerate(feature_maps):
        # Feature map i has spatial size 416 // STRIDES[i]; this replaces the
        # original if/elif/else over the three hard-coded strides.
        output_tensors = decode(fm, 416 // STRIDES[i], NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
        bbox_tensors.append(output_tensors[0])
        prob_tensors.append(output_tensors[1])
    pred_bbox = tf.concat(bbox_tensors, axis=1)
    pred_prob = tf.concat(prob_tensors, axis=1)
    boxes, pred_conf = filter_boxes(
        pred_bbox, pred_prob, score_threshold=score_thres, input_shape=tf.constant([416, 416])
    )
    # The original concatenated boxes and pred_conf into a one-entry dict and
    # immediately re-split them in a loop -- a no-op roundtrip, removed here.

    boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
        boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
        scores=tf.reshape(pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
        max_output_size_per_class=50, max_total_size=50, iou_threshold=iou_threshold,
        score_threshold=score_threshold
    )
    pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
    image = draw_bbox(original_image, pred_bbox)
    image = Image.fromarray(image.astype(np.uint8))
    image.show()
    # NOTE(review): this assumes draw_bbox returns an RGB image that needs
    # channel swapping before cv2.imwrite (which expects BGR) -- confirm
    # against draw_bbox; the conversion matrix is its own inverse either way.
    image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    cv2.imwrite(save_path, image)
    return image

0 commit comments

Comments
 (0)