# Hyperparameters
LEARNING_RATE = 1e-4
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
BATCH_SIZE = 8
NUM_EPOCHS = 5
# Directories
TRAIN_IMG_DIR = "data/train/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val/"
VAL_MASK_DIR = "data/val_masks/"
# Transforms
train_transform = A.Compose([
    A.Resize(height=1280, width=1918),  # Resize takes (height, width); 1280 x 1918 is the original Carvana image size
    A.Rotate(limit=35, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0
    ),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(height=1280, width=1918),  # Same fixed size as training, but no augmentation
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0
    ),
    ToTensorV2()
])
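
# Optional sanity check (illustrative sketch, not part of the training pipeline):
# Albumentations applies the same spatial transforms to the image and mask when
# both are passed in a single call, which keeps masks aligned with their images.
# `sample_image` and `sample_mask` are assumed to be uint8 NumPy arrays of shape
# (H, W, 3) and (H, W) as produced by the dataset before transformation.
def check_transform(sample_image, sample_mask):
    augmented = train_transform(image=sample_image, mask=sample_mask)
    print(augmented["image"].shape)  # torch.Size([3, 1280, 1918]) after ToTensorV2
    print(augmented["mask"].shape)   # torch.Size([1280, 1918]); masks are not normalized
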

def train_one_epoch(loader, model, optimizer, loss_fn, device):
    model.train()
    loop = tqdm(loader)
    running_loss = 0.0
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device)
        targets = targets.float().unsqueeze(1).to(device)
        # forward
        predictions = model(data)
        loss = loss_fn(predictions, targets)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    return running_loss / len(loader)

def evaluate(loader, model, device):
    model.eval()
    dice_score = 0
    num_correct = 0
    num_pixels = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
    accuracy = (num_correct / num_pixels * 100).item()
    dice = (dice_score / len(loader)).item()
    return accuracy, dice
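
# Note: evaluate() averages a per-batch Dice score, so a smaller final batch is
# weighted the same as a full batch. If a dataset-level Dice is preferred, a
# minimal sketch (same 0.5 thresholding as above) could accumulate the sums
# across all batches instead:
def evaluate_global_dice(loader, model, device):
    model.eval()
    intersection = 0.0
    totals = 0.0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = (torch.sigmoid(model(x)) > 0.5).float()
            intersection += (preds * y).sum().item()   # |pred AND target|
            totals += (preds + y).sum().item()         # |pred| + |target|
    return 2 * intersection / (totals + 1e-8)
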

def plot_predictions(model, val_loader, device, num_samples=3):
    model.eval()
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
    with torch.no_grad():
        for idx, (x, y) in enumerate(val_loader):
            if idx >= num_samples:
                break
            x = x.to(device)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            # Convert to numpy and denormalize
            img = x[0].cpu().numpy().transpose(1, 2, 0)
            img = img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
            img = np.clip(img * 255, 0, 255).astype(np.uint8)
            mask = y[0].cpu().numpy() * 255
            pred = preds[0].cpu().numpy()[0] * 255
            axes[idx, 0].imshow(img)
            axes[idx, 0].set_title('Original Image')
            axes[idx, 0].axis('off')
            axes[idx, 1].imshow(mask, cmap='gray')
            axes[idx, 1].set_title('Ground Truth')
            axes[idx, 1].axis('off')
            axes[idx, 2].imshow(pred, cmap='gray')
            axes[idx, 2].set_title('Prediction')
            axes[idx, 2].axis('off')
    plt.tight_layout()
    plt.show()

def plot_training_history(train_losses, val_accuracies, val_dices):
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_losses, 'b-')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.subplot(1, 3, 2)
    plt.plot(epochs, val_accuracies, 'g-')
    plt.title('Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.subplot(1, 3, 3)
    plt.plot(epochs, val_dices, 'r-')
    plt.title('Validation Dice Score')
    plt.xlabel('Epoch')
    plt.ylabel('Dice Score')
    plt.tight_layout()
    plt.show()

def main():
    # Create data loaders
    train_ds = CarvanaDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, train_transform)
    val_ds = CarvanaDataset(VAL_IMG_DIR, VAL_MASK_DIR, val_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model, loss, optimizer
    model = ResNetSegmentation().to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training history
    train_losses = []
    val_accuracies = []
    val_dices = []

    # Training loop
    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")

        # Train
        train_loss = train_one_epoch(train_loader, model, optimizer, loss_fn, DEVICE)
        train_losses.append(train_loss)

        # Evaluate
        accuracy, dice = evaluate(val_loader, model, DEVICE)
        val_accuracies.append(accuracy)
        val_dices.append(dice)

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Accuracy: {accuracy:.2f}%")
        print(f"Val Dice Score: {dice:.4f}")

        # Plot some predictions at the end of each epoch
        plot_predictions(model, val_loader, DEVICE)

    # Plot training history
    plot_training_history(train_losses, val_accuracies, val_dices)

    # Save final model weights plus training history
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'val_dices': val_dices
    }, 'resnet18_segmentation_final.pth')


if __name__ == "__main__":
    main()
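
# To reload the saved checkpoint later for inference (a sketch; assumes the same
# ResNetSegmentation class and DEVICE defined above are available):
#
#   checkpoint = torch.load('resnet18_segmentation_final.pth', map_location=DEVICE)
#   model = ResNetSegmentation().to(DEVICE)
#   model.load_state_dict(checkpoint['model_state_dict'])
#   model.eval()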