Skip to content

Commit b56ef5c

Browse files
committed
fix hbar on headless / inference wo. config
1 parent 5872b6d commit b56ef5c

2 files changed

Lines changed: 18 additions & 20 deletions

File tree

bioencoder/core/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,14 @@ def set_seed(seed=42):
108108

109109
return seed
110110

111-
112111
def pprint_fill_hbar(message, symbol="-", ret=True):
113-
terminal_width = os.get_terminal_size()[0] - len("%Y-%m-%d %H:%M:%S")
112+
try:
113+
# Try to get the terminal width
114+
terminal_width = os.get_terminal_size()[0] - len("%Y-%m-%d %H:%M:%S")
115+
except OSError:
116+
# Fallback width for headless environments
117+
terminal_width = 80 # Default width if terminal size can't be determined
118+
114119
message_length = len(message)
115120

116121
if message_length >= terminal_width:

bioencoder/scripts/inference.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,12 @@ def inference(
7575
ckpt_pretrained = checkpoint_path
7676
else:
7777
ckpt_pretrained = os.path.join(config.root_dir, "weights", run_name, stage, checkpoint)
78-
78+
79+
## load from config
80+
if root_dir and run_name:
81+
train_dir = os.path.join(root_dir,"data", run_name, "train")
82+
labels_sorted = ImageFolder(root=train_dir).classes
83+
7984
## set random seed
8085
utils.set_seed()
8186

@@ -99,26 +104,14 @@ def inference(
99104

100105
## set to eval
101106
model.eval()
102-
103-
## get labels
104-
train_dir = os.path.join(root_dir,"data", run_name, "train")
105-
labels_sorted = ImageFolder(root=train_dir).classes
106107

107108
## load file
108109
if isinstance(image, str):
109-
if os.path.isfile(image):
110-
image = Image.open(image)
111-
image = np.asarray(image)
112-
else:
113-
print("File does not exist")
114-
return
115-
elif isinstance(image, (np.ndarray, np.generic)):
116-
print("image shape:" + str(image.shape))
117-
# Input is already a numpy array or an instance of np.generic (which np.ndarray inherits from)
118-
pass
119-
else:
120-
print("Wrong format - need either image path or array type")
121-
return
110+
if not os.path.isfile(image):
111+
raise FileNotFoundError(f"File does not exist: {image}")
112+
image = np.asarray(Image.open(image))
113+
elif not isinstance(image, (np.ndarray, np.generic)):
114+
raise TypeError("Input must be either an image path (str) or a NumPy array.")
122115

123116
## transform image and move to GPU
124117
image = transform(image=image)["image"]

0 commit comments

Comments
 (0)