22# -*- coding: utf-8 -*-
33
44from tensorlayer .app import YOLOv4 , get_anchors , decode , filter_boxes
5+ from tensorlayer .app import CGCNN
56import numpy as np
67import tensorflow as tf
78from tensorlayer import logging
89import cv2
910
1011
1112class object_detection (object ):
13+ """Model encapsulation.
14+
15+ Parameters
16+ ----------
17+ model_name : str
18+ Choose the model to inference.
19+
20+ Methods
21+ ---------
22+ __init__()
23+ Initializing the model.
24+ __call__()
25+ (1)Formatted input and output. (2)Inference model.
26+ list()
27+ Abstract method. Return available a list of model_name.
28+
29+ Examples
30+ ---------
31+ Object Detection detection MSCOCO with YOLOv4, see `tutorial_object_detection_yolov4.py
32+ <https://github.com/tensorlayer/tensorlayer/blob/master/example/app_tutorials/tutorial_object_detection_yolov4.py>`__
33+ With TensorLayer
34+
35+ >>> # get the whole model
36+ >>> net = tl.app.computer_vision.object_detection('yolo4-mscoco')
37+ >>> # use for inferencing
38+ >>> output = net(img)
39+ """
1240
1341 def __init__ (self , model_name = 'yolo4-mscoco' ):
1442 self .model_name = model_name
1543 if self .model_name == 'yolo4-mscoco' :
1644 self .model = YOLOv4 (NUM_CLASS = 80 , pretrained = True )
45+ elif self .model_name == 'lcn' :
46+ self .model = CGCNN (pretrained = True )
1747 else :
1848 raise ("The model does not support." )
1949
@@ -23,6 +53,8 @@ def __call__(self, input_data):
2353 feature_maps = self .model (batch_data , is_train = False )
2454 pred_bbox = yolo4_output_processing (feature_maps )
2555 output = result_to_json (input_data , pred_bbox )
56+ elif self .model_name == 'lcn' :
57+ output = self .model (input_data )
2658 else :
2759 raise NotImplementedError
2860
@@ -35,7 +67,7 @@ def __repr__(self):
3567
3668 @property
3769 def list (self ):
38- logging .info ("The model name list: yolov4-mscoco" )
70+ logging .info ("The model name list: ' yolov4-mscoco', 'lcn' " )
3971
4072
4173def yolo4_input_processing (original_image ):
0 commit comments