diff --git a/rewrite_training.py b/rewrite_training.py index 8a11a07..27bbe2f 100644 --- a/rewrite_training.py +++ b/rewrite_training.py @@ -4,33 +4,27 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import TensorDataset, DataLoader -from concurrent.futures import ProcessPoolExecutor -from tqdm import tqdm +from torch.utils.data import Dataset, DataLoader from sklearn.model_selection import train_test_split +from tqdm import tqdm # --- CONFIGURATION --- BASE_PATH = "asl_kaggle" TARGET_FRAMES = 22 -# Hand landmarks + Lip landmarks (approximate indices for high-value face points) LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95] HANDS = list(range(468, 543)) SELECTED_INDICES = LIPS + HANDS -NUM_FEATS = len(SELECTED_INDICES) * 3 # X, Y, Z for each selected point +NUM_FEATS = len(SELECTED_INDICES) * 3 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# --- DATA PROCESSING --- - -def load_kaggle_metadata(base_path): - return pl.read_csv(os.path.join(base_path, "train.csv")) - +# --- DATASET ENGINE --- def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES): parquet_path = os.path.join(base_path, path) df = pl.read_parquet(parquet_path) - # 1. Spatial Normalization (Nose Anchor) + # 1. Spatial Normalization anchors = ( df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0)) .select([pl.col("frame"), pl.col("x").alias("nx"), pl.col("y").alias("ny"), pl.col("z").alias("nz")]) @@ -46,25 +40,41 @@ def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES): .sort(["frame", "type", "landmark_index"]) ) - # 2. Reshape & Feature Selection - # Get unique frames and total landmarks (543) + # 2. Slice and Reshape raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3) - - # Slice to keep only Hands and Lips reduced_tensor = raw_tensor[:, SELECTED_INDICES, :] - # 3. Temporal Normalization (Resample to fixed frame count) + # 3. Temporal Resampling curr_len = reduced_tensor.shape[0] indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int) return reduced_tensor[indices] -# --- MODEL ARCHITECTURE --- +class ASLDataset(Dataset): + def __init__(self, paths, labels): + self.paths = paths + self.labels = labels + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + try: + x = load_and_preprocess(self.paths[idx]) + y = self.labels[idx] + return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long) + except Exception as e: + # Return a zero tensor if a file is corrupted to prevent crash + return torch.zeros((TARGET_FRAMES, len(SELECTED_INDICES), 3)), torch.tensor(self.labels[idx], + dtype=torch.long) + + +# --- MODEL --- class ASLClassifier(nn.Module): - def __init__(self, num_classes, target_frames=TARGET_FRAMES, num_feats=NUM_FEATS): + def __init__(self, num_classes): super().__init__() - self.conv1 = nn.Conv1d(num_feats, 256, kernel_size=3, padding=1) + self.conv1 = nn.Conv1d(NUM_FEATS, 256, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm1d(256) self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm1d(512) @@ -73,9 +83,9 @@ class ASLClassifier(nn.Module): self.fc = nn.Linear(512, num_classes) def forward(self, x): - # x: (Batch, Frames, Selected_Landmarks, 3) - x = x.view(x.shape[0], x.shape[1], -1) # Flatten landmarks/coords - x = x.transpose(1, 2) # (Batch, Features, Time) + # x shape: (Batch, 22, 96, 3) + b, t, l, c = x.shape + x = x.view(b, t, -1).transpose(1, 2) # (Batch, Features, Time) x = F.relu(self.bn1(self.conv1(x))) x = self.pool(x) @@ -87,123 +97,73 @@ class ASLClassifier(nn.Module): return self.fc(x) -# --- TRAINING FUNCTION --- +# --- TRAINING LOOP --- -def train_model(model, train_loader, val_loader, epochs=20, lr=0.001): - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(model.parameters(), lr=lr) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5) +def run_training(): + # 1. Prepare Metadata + train_df = pl.read_csv(os.path.join(BASE_PATH, "train.csv")) + unique_signs = sorted(train_df["sign"].unique().to_list()) + sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)} - best_val_acc = 0.0 + paths = train_df["path"].to_list() + labels = [sign_to_idx[s] for s in train_df["sign"].to_list()] - for epoch in range(epochs): - # Training phase + # 2. Split + p_train, p_val, l_train, l_val = train_test_split(paths, labels, test_size=0.15, stratify=labels) + + # 3. Loaders + train_ds = ASLDataset(p_train, l_train) + val_ds = ASLDataset(p_val, l_val) + + # num_workers=4 allows the CPU to preprocess the next batch while GPU trains + train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, pin_memory=True) + val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4, pin_memory=True) + + # 4. Init Model & Optim + model = ASLClassifier(len(unique_signs)).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + criterion = nn.CrossEntropyLoss(label_smoothing=0.1) + scaler = torch.amp.GradScaler(enabled=(device.type == 'cuda')) + + # 5. Loop + for epoch in range(30): model.train() - train_loss = 0.0 - train_correct = 0 - train_total = 0 + t_correct, t_total = 0, 0 - pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]") - for inputs, labels in pbar: - inputs, labels = inputs.to(device), labels.to(device) + pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}") + for x, y in pbar: + x, y = x.to(device), y.to(device) optimizer.zero_grad() - outputs = model(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - train_loss += loss.item() - _, predicted = torch.max(outputs, 1) - train_total += labels.size(0) - train_correct += (predicted == labels).sum().item() + # Use Mixed Precision for speed + with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')): + outputs = model(x) + loss = criterion(outputs, y) - pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100 * train_correct / train_total:.2f}%'}) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() - train_acc = 100 * train_correct / train_total - avg_train_loss = train_loss / len(train_loader) + _, pred = torch.max(outputs, 1) + t_total += y.size(0) + t_correct += (pred == y).sum().item() + pbar.set_postfix(acc=f"{(t_correct / t_total) * 100:.1f}%") - # Validation phase + # Validation model.eval() - val_loss = 0.0 - val_correct = 0 - val_total = 0 - + v_correct, v_total = 0, 0 with torch.no_grad(): - for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{epochs} [Val]"): - inputs, labels = inputs.to(device), labels.to(device) - outputs = model(inputs) - loss = criterion(outputs, labels) + for x, y in val_loader: + x, y = x.to(device), y.to(device) + outputs = model(x) + _, pred = torch.max(outputs, 1) + v_total += y.size(0) + v_correct += (pred == y).sum().item() - val_loss += loss.item() - _, predicted = torch.max(outputs, 1) - val_total += labels.size(0) - val_correct += (predicted == labels).sum().item() + print(f"Validation Accuracy: {(v_correct / v_total) * 100:.2f}%") + torch.save(model.state_dict(), f"asl_model_epoch_{epoch}.pth") - val_acc = 100 * val_correct / val_total - avg_val_loss = val_loss / len(val_loader) - - scheduler.step(avg_val_loss) - - print(f"\nEpoch {epoch + 1}/{epochs}:") - print(f" Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%") - print(f" Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%") - - # Save best model - if val_acc > best_val_acc: - best_val_acc = val_acc - torch.save(model.state_dict(), 'best_asl_model.pth') - print(f" ✓ New best model saved! (Val Acc: {val_acc:.2f}%)") - - print() - - print(f"Training complete! Best validation accuracy: {best_val_acc:.2f}%") - - -# --- EXECUTION --- if __name__ == "__main__": - asl_data = load_kaggle_metadata(BASE_PATH) - - # Process all files - paths = asl_data["path"].to_list() - labels = asl_data["sign"].to_list() - - # Create label mapping - unique_signs = sorted(set(labels)) - sign_to_idx = {sign: idx for idx, sign in enumerate(unique_signs)} - label_indices = [sign_to_idx[sign] for sign in labels] - - print(f"Processing {len(paths)} files in parallel...") - with ProcessPoolExecutor() as executor: - results = list(tqdm(executor.map(load_and_preprocess, paths), total=len(paths))) - - # Create tensors - X = torch.tensor(np.array(results), dtype=torch.float32) - y = torch.tensor(label_indices, dtype=torch.long) - - print(f"Dataset Tensor Shape: {X.shape}") - print(f"Labels Tensor Shape: {y.shape}") - print(f"Number of unique signs: {len(unique_signs)}") - - # Train/Val split - X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y) - - # Create DataLoaders - train_dataset = TensorDataset(X_train, y_train) - val_dataset = TensorDataset(X_val, y_val) - - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) - - print(f"Train samples: {len(train_dataset)}") - print(f"Val samples: {len(val_dataset)}") - - # Initialize and train model - model = ASLClassifier(num_classes=len(unique_signs)) - model.to(device) - - print(f"\nModel initialized with {sum(p.numel() for p in model.parameters()):,} parameters") - print("Starting training...\n") - - train_model(model, train_loader, val_loader, epochs=20, lr=0.001) \ No newline at end of file + run_training() \ No newline at end of file