Skip to content

Commit a189795

Browse files
committed
cuda.amp fix
1 parent 4ac3c2a commit a189795

3 files changed

Lines changed: 10 additions & 58 deletions

File tree

bioencoder/core/utils.py

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def compute_embeddings(loader, model, scaler=None):
337337
Parameters:
338338
loader (torch.utils.data.DataLoader): DataLoader that provides images and labels.
339339
model (torch.nn.Module): Neural network model used to compute the embeddings.
340-
scaler (torch.cuda.amp.autocast): Autocast context manager used to perform mixed-precision training.
340+
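scaler (optional; presumably the caller's torch.amp.GradScaler or None): Only checked for truthiness. When set, the forward pass runs under torch.amp.autocast("cuda") (mixed precision); otherwise the model is called without autocast.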
341341
342342
Returns:
343343
tuple: A tuple containing:
@@ -350,7 +350,7 @@ def compute_embeddings(loader, model, scaler=None):
350350
for images, labels in loader:
351351
images = images.cuda()
352352
if scaler:
353-
with torch.cuda.amp.autocast():
353+
with torch.amp.autocast("cuda"):
354354
embed = model(images)
355355
else:
356356
embed = model(images)
@@ -379,7 +379,7 @@ def train_epoch_constructive(train_loader, model, criterion, optimizer, scaler,
379379
- model (torch.nn.Module): The model that will be trained.
380380
- criterion (torch.nn.Module): The loss function to be used for training.
381381
- optimizer (torch.optim.Optimizer): The optimization algorithm to be used for training.
382-
- scaler (torch.cuda.amp.GradScaler, optional): The scaler used for gradient scaling in case of mixed precision training.
382+
- scaler (torch.amp.GradScaler, optional): The scaler used for gradient scaling in case of mixed precision training.
383383
- ema (ExponentialMovingAverage, optional): If provided, the exponential moving average to be applied to the model's parameters.
384384
385385
Returns:
@@ -398,7 +398,7 @@ def train_epoch_constructive(train_loader, model, criterion, optimizer, scaler,
398398
bsz = labels.shape[0]
399399

400400
if scaler:
401-
with torch.cuda.amp.autocast():
401+
with torch.amp.autocast("cuda"):
402402
embed = model(images)
403403
if not loss_optimization:
404404
f1, f2 = torch.split(embed, [bsz, bsz], dim=0)
@@ -448,7 +448,7 @@ def validation_constructive(valid_loader, train_loader, model, scaler):
448448
valid_loader (torch.utils.data.DataLoader): DataLoader containing the validation data.
449449
train_loader (torch.utils.data.DataLoader): DataLoader containing the training data.
450450
model (torch.nn.Module): The model being trained.
451-
scaler (torch.cuda.amp.GradScaler): The scaler used for gradient scaling in case of mixed precision training.
451+
scaler (torch.amp.GradScaler): The scaler used for gradient scaling in case of mixed precision training.
452452
453453
Returns:
454454
acc_dict (dict): A dictionary containing the accuracy metrics, computed using the `AccuracyCalculator` class.
@@ -484,7 +484,7 @@ def train_epoch_ce(train_loader, model, criterion, optimizer, scaler, ema):
484484
model (torch.nn.Module): The model to be trained.
485485
criterion (torch.nn.Module): The loss function to be used for training.
486486
optimizer (torch.optim.Optimizer): The optimizer used to update model parameters.
487-
scaler (torch.cuda.amp.GradScaler): The scaler used for gradient scaling in case of mixed precision training.
487+
scaler (torch.amp.GradScaler): The scaler used for gradient scaling in case of mixed precision training.
488488
ema (Optional[torch.nn.Module]): The exponential moving average model.
489489
490490
Returns:
@@ -498,7 +498,7 @@ def train_epoch_ce(train_loader, model, criterion, optimizer, scaler, ema):
498498
data, target = data.cuda(), target.cuda()
499499
optimizer.zero_grad()
500500
if scaler:
501-
with torch.cuda.amp.autocast():
501+
with torch.amp.autocast("cuda"):
502502
output = model(data)
503503
loss = criterion(output, target)
504504
train_loss.append(loss.item())
@@ -533,7 +533,7 @@ def validation_ce(model, criterion, valid_loader, scaler):
533533
with torch.no_grad():
534534
data, target = data.cuda(), target.cuda()
535535
if scaler:
536-
with torch.cuda.amp.autocast():
536+
with torch.amp.autocast("cuda"):
537537
output = model(data)
538538
if criterion:
539539
loss = criterion(output, target)
@@ -562,53 +562,6 @@ def validation_ce(model, criterion, valid_loader, scaler):
562562
return metrics
563563

564564

565-
# def validation_ce(model, criterion, valid_loader, scaler):
566-
# """
567-
# Validates the given model with cross entropy loss and calculates several evaluation metrics such as accuracy, F1 scores and F1 score macro.
568-
569-
# Parameters:
570-
# model (torch.nn.Module): The model to be validated.
571-
# criterion (torch.nn.modules.loss._Loss): The criterion to be used for validation, which is cross entropy loss in this case.
572-
# valid_loader (torch.utils.data.DataLoader): The data loader for validation dataset.
573-
# scaler (torch.cuda.amp.autocast.Autocast): Optional scaler for using automatic mixed precision (AMP).
574-
575-
# Returns:
576-
# dict: A dictionary containing the validation loss, accuracy, F1 scores, and F1 score macro.
577-
578-
# """
579-
# model.eval()
580-
# val_loss = []
581-
# y_pred, y_true = [], []
582-
583-
# for data, target in valid_loader:
584-
# with torch.no_grad():
585-
# data, target = data.cuda(), target.cuda()
586-
# if scaler:
587-
# with torch.cuda.amp.autocast():
588-
# output = model(data)
589-
# else:
590-
# output = model(data)
591-
592-
# if criterion:
593-
# loss = criterion(output, target)
594-
# val_loss.append(loss.item())
595-
596-
# pred = output.argmax(dim=1)
597-
# y_pred.extend(pred.cpu().numpy())
598-
# y_true.extend(target.cpu().numpy())
599-
600-
# del data, target, output
601-
# torch.cuda.empty_cache()
602-
603-
# valid_loss = np.mean(val_loss)
604-
# f1_scores = f1_score(y_true, y_pred, average=None)
605-
# f1_score_macro = f1_score(y_true, y_pred, average='macro')
606-
# acc_score = accuracy_score(y_true, y_pred)
607-
608-
# metrics = {"loss": valid_loss, "accuracy": acc_score, "f1_scores": f1_scores, 'f1_score_macro': f1_score_macro}
609-
# return metrics
610-
611-
612565
def copy_parameters_from_model(model):
613566
"""
614567
Copy parameters from a PyTorch model.

bioencoder/scripts/swa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def swa(
7171
os.remove(os.path.join(weights_dir, "swa"))
7272

7373
## scaler
74-
scaler = torch.cuda.amp.GradScaler()
74+
scaler = torch.amp.GradScaler("cuda")
7575
if not amp:
7676
scaler = None
7777
utils.set_seed()

bioencoder/scripts/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def train(
8989
aug_sample_n = aug_config.get("sample_n", 5)
9090
aug_sample_seed = aug_config.get("sample_seed", 42)
9191

92-
9392
## manage directories and paths
9493
data_dir = os.path.join(root_dir, "data", run_name)
9594
log_dir = os.path.join(root_dir, "logs", run_name, stage)
@@ -166,7 +165,7 @@ def train(
166165
logger.info(f"Hyperparameters:\n{pretty_repr(hyperparams)}")
167166

168167
## scaler
169-
scaler = torch.cuda.amp.GradScaler()
168+
scaler = torch.amp.GradScaler("cuda")
170169
if not amp:
171170
scaler = None
172171

0 commit comments

Comments
 (0)