diff --git a/rewrite_training.py b/rewrite_training.py index 8e46a5e..4e78d91 100644 --- a/rewrite_training.py +++ b/rewrite_training.py @@ -4,22 +4,30 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, random_split from concurrent.futures import ProcessPoolExecutor from tqdm import tqdm # --- CONFIGURATION --- BASE_PATH = "asl_kaggle" TARGET_FRAMES = 22 -# Hand landmarks + Lip landmarks (approximate indices for high-value face points) LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95] HANDS = list(range(468, 543)) SELECTED_INDICES = LIPS + HANDS -NUM_FEATS = len(SELECTED_INDICES) * 3 # X, Y, Z for each selected point +NUM_FEATS = len(SELECTED_INDICES) * 3 + +# Training hyperparameters +BATCH_SIZE = 32 +EPOCHS = 50 +LEARNING_RATE = 0.001 +TRAIN_SPLIT = 0.8 +CHECKPOINT_DIR = "checkpoints" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") -# --- DATA PROCESSING --- +# --- DATA PROCESSING --- def load_kaggle_metadata(base_path): return pl.read_csv(os.path.join(base_path, "train.csv")) @@ -28,7 +36,6 @@ def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES): parquet_path = os.path.join(base_path, path) df = pl.read_parquet(parquet_path) - # 1. Spatial Normalization (Nose Anchor) anchors = ( df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0)) .select([pl.col("frame"), pl.col("x").alias("nx"), pl.col("y").alias("ny"), pl.col("z").alias("nz")]) @@ -44,21 +51,28 @@ def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES): .sort(["frame", "type", "landmark_index"]) ) - # 2. Reshape & Feature Selection - # Get unique frames and total landmarks (543) raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3) - - # Slice to keep only Hands and Lips reduced_tensor = raw_tensor[:, SELECTED_INDICES, :] - # 3. Temporal Normalization (Resample to fixed frame count) curr_len = reduced_tensor.shape[0] indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int) return reduced_tensor[indices] -# --- MODEL ARCHITECTURE --- +# --- DATASET CLASS --- +class ASLDataset(Dataset): + def __init__(self, tensors, labels): + self.tensors = tensors + self.labels = labels + def __len__(self): + return len(self.tensors) + + def __getitem__(self, idx): + return self.tensors[idx], self.labels[idx] + + +# --- MODEL ARCHITECTURE --- class ASLClassifier(nn.Module): def __init__(self, num_classes, target_frames=TARGET_FRAMES, num_feats=NUM_FEATS): super().__init__() @@ -71,42 +85,153 @@ class ASLClassifier(nn.Module): self.fc = nn.Linear(512, num_classes) def forward(self, x): - # x: (Batch, Frames, Selected_Landmarks, 3) - x = x.view(x.shape[0], x.shape[1], -1) # Flatten landmarks/coords - x = x.transpose(1, 2) # (Batch, Features, Time) - + x = x.view(x.shape[0], x.shape[1], -1) + x = x.transpose(1, 2) x = F.relu(self.bn1(self.conv1(x))) x = self.pool(x) x = F.relu(self.bn2(self.conv2(x))) x = self.pool(x) - x = F.adaptive_avg_pool1d(x, 1).squeeze(-1) x = self.dropout(x) return self.fc(x) -# --- EXECUTION --- +# --- TRAINING FUNCTIONS --- +def train_epoch(model, dataloader, criterion, optimizer, device): + model.train() + running_loss = 0.0 + correct = 0 + total = 0 + progress_bar = tqdm(dataloader, desc="Training") + for inputs, labels in progress_bar: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + progress_bar.set_postfix({ + 'loss': running_loss / (progress_bar.n + 1), + 'acc': 100 * correct / total + }) + + epoch_loss = running_loss / len(dataloader) + epoch_acc = 100 * correct / total + return epoch_loss, epoch_acc + + +def validate(model, dataloader, criterion, device): + model.eval() + running_loss = 0.0 + correct = 0 + total = 0 + + with torch.no_grad(): + for inputs, labels in tqdm(dataloader, desc="Validation"): + inputs, labels = inputs.to(device), labels.to(device) + outputs = model(inputs) + loss = criterion(outputs, labels) + + running_loss += loss.item() + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + val_loss = running_loss / len(dataloader) + val_acc = 100 * correct / total + return val_loss, val_acc + + +def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, val_acc, checkpoint_dir): + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'train_loss': train_loss, + 'val_loss': val_loss, + 'val_acc': val_acc, + } + path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt') + torch.save(checkpoint, path) + print(f"Checkpoint saved: {path}") + + +# --- EXECUTION --- if __name__ == "__main__": + # Load metadata asl_data = load_kaggle_metadata(BASE_PATH) - # Optimization: Process 100 samples to get a feel for the shape/speed - # Using multiprocessing to avoid the slow single-thread loop - paths = asl_data["path"].to_list() + # Create label mapping + unique_signs = sorted(asl_data["sign"].unique().to_list()) + label_to_idx = {sign: idx for idx, sign in enumerate(unique_signs)} + labels = torch.tensor([label_to_idx[sign] for sign in asl_data["sign"].to_list()]) + print(f"Number of classes: {len(unique_signs)}") + + # Process data in parallel + paths = asl_data["path"].to_list() print(f"Processing {len(paths)} files in parallel...") + with ProcessPoolExecutor() as executor: results = list(tqdm(executor.map(load_and_preprocess, paths), total=len(paths))) - # Stack into one giant Torch tensor dataset_tensor = torch.tensor(np.array(results), dtype=torch.float32) print(f"Final Tensor Shape: {dataset_tensor.shape}") - # Shape: (100, 22, 96, 3) -> (Batch, Time, Landmarks, Coords) - # Initialize Model - num_unique_signs = asl_data["sign"].n_unique() - model = ASLClassifier(num_classes=num_unique_signs) - model.to(device) - # Test pass - output = model(dataset_tensor) - print(f"Model Output Shape: {output.shape}") # (100, 250) \ No newline at end of file + # Create dataset and split + full_dataset = ASLDataset(dataset_tensor, labels) + train_size = int(TRAIN_SPLIT * len(full_dataset)) + val_size = len(full_dataset) - train_size + train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) + + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) + val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0) + + print(f"Train samples: {train_size}, Validation samples: {val_size}") + + # Initialize model, loss, optimizer + model = ASLClassifier(num_classes=len(unique_signs)).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5) + + # Training loop + best_val_acc = 0.0 + + print("\n" + "=" * 50) + print("Starting Training") + print("=" * 50 + "\n") + + for epoch in range(EPOCHS): + print(f"\nEpoch [{epoch + 1}/{EPOCHS}]") + print("-" * 50) + + train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device) + val_loss, val_acc = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + print(f"\nEpoch {epoch + 1} Summary:") + print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%") + print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%") + print(f" Learning Rate: {optimizer.param_groups[0]['lr']:.6f}") + + # Save checkpoint if validation accuracy improved + if val_acc > best_val_acc: + best_val_acc = val_acc + save_checkpoint(model, optimizer, epoch + 1, train_loss, val_loss, val_acc, CHECKPOINT_DIR) + print(f" ✓ New best validation accuracy: {best_val_acc:.2f}%") + + print("\n" + "=" * 50) + print("Training Complete!") + print(f"Best Validation Accuracy: {best_val_acc:.2f}%") + print("=" * 50) \ No newline at end of file