|
| 1 | +#! /usr/bin/python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +from tensorlayer.app import get_anchors, decode, filter_boxes, draw_bbox |
| 5 | +import numpy as np |
| 6 | +import tensorflow as tf |
| 7 | +import cv2 |
| 8 | +from PIL import Image |
| 9 | + |
| 10 | + |
| 11 | +def yolo4_visualize(original_image, feature_maps): |
| 12 | + STRIDES = [8, 16, 32] |
| 13 | + ANCHORS = get_anchors([12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401]) |
| 14 | + NUM_CLASS = 80 |
| 15 | + XYSCALE = [1.2, 1.1, 1.05] |
| 16 | + iou_threshold = 0.45 |
| 17 | + score_threshold = 0.25 |
| 18 | + |
| 19 | + bbox_tensors = [] |
| 20 | + prob_tensors = [] |
| 21 | + score_thres = 0.2 |
| 22 | + for i, fm in enumerate(feature_maps): |
| 23 | + if i == 0: |
| 24 | + output_tensors = decode(fm, 416 // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE) |
| 25 | + elif i == 1: |
| 26 | + output_tensors = decode(fm, 416 // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE) |
| 27 | + else: |
| 28 | + output_tensors = decode(fm, 416 // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE) |
| 29 | + bbox_tensors.append(output_tensors[0]) |
| 30 | + prob_tensors.append(output_tensors[1]) |
| 31 | + pred_bbox = tf.concat(bbox_tensors, axis=1) |
| 32 | + pred_prob = tf.concat(prob_tensors, axis=1) |
| 33 | + boxes, pred_conf = filter_boxes( |
| 34 | + pred_bbox, pred_prob, score_threshold=score_thres, input_shape=tf.constant([416, 416]) |
| 35 | + ) |
| 36 | + pred = {'concat': tf.concat([boxes, pred_conf], axis=-1)} |
| 37 | + |
| 38 | + for key, value in pred.items(): |
| 39 | + boxes = value[:, :, 0:4] |
| 40 | + pred_conf = value[:, :, 4:] |
| 41 | + |
| 42 | + boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression( |
| 43 | + boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)), |
| 44 | + scores=tf.reshape(pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])), |
| 45 | + max_output_size_per_class=50, max_total_size=50, iou_threshold=iou_threshold, score_threshold=score_threshold |
| 46 | + ) |
| 47 | + pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()] |
| 48 | + image = draw_bbox(original_image, pred_bbox) |
| 49 | + image = Image.fromarray(image.astype(np.uint8)) |
| 50 | + image.show() |
| 51 | + image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB) |
| 52 | + cv2.imwrite('result.png', image) |
0 commit comments