@@ -75,7 +75,12 @@ def inference(
7575 ckpt_pretrained = checkpoint_path
7676 else :
7777 ckpt_pretrained = os .path .join (config .root_dir , "weights" , run_name , stage , checkpoint )
78-
78+
79+ ## load from config
80+ if root_dir and run_name :
81+ train_dir = os .path .join (root_dir ,"data" , run_name , "train" )
82+ labels_sorted = ImageFolder (root = train_dir ).classes
83+
7984 ## set random seed
8085 utils .set_seed ()
8186
@@ -99,26 +104,14 @@ def inference(
99104
100105 ## set to eval
101106 model .eval ()
102-
103- ## get labels
104- train_dir = os .path .join (root_dir ,"data" , run_name , "train" )
105- labels_sorted = ImageFolder (root = train_dir ).classes
106107
107108 ## load file
108109 if isinstance (image , str ):
109- if os .path .isfile (image ):
110- image = Image .open (image )
111- image = np .asarray (image )
112- else :
113- print ("File does not exist" )
114- return
115- elif isinstance (image , (np .ndarray , np .generic )):
116- print ("image shape:" + str (image .shape ))
117- # Input is already a numpy array or an instance of np.generic (which np.ndarray inherits from)
118- pass
119- else :
120- print ("Wrong format - need either image path or array type" )
121- return
110+ if not os .path .isfile (image ):
111+ raise FileNotFoundError (f"File does not exist: { image } " )
112+ image = np .asarray (Image .open (image ))
113+ elif not isinstance (image , (np .ndarray , np .generic )):
114+ raise TypeError ("Input must be either an image path (str) or a NumPy array." )
122115
123116 ## transform image and move to GPU
124117 image = transform (image = image )["image" ]
0 commit comments