diff --git a/rewrite_training.py b/rewrite_training.py index df84352..3d6e1a7 100644 --- a/rewrite_training.py +++ b/rewrite_training.py @@ -4,25 +4,33 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import TensorDataset, DataLoader -from concurrent.futures import ProcessPoolExecutor +from torch.utils.data import Dataset, DataLoader from tqdm import tqdm from sklearn.model_selection import train_test_split +from concurrent.futures import ProcessPoolExecutor # --- CONFIG --- BASE_PATH = "asl_kaggle" +CACHE_DIR = "asl_cache" TARGET_FRAMES = 22 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# --- DATA LOADING WITH RELATIVE FEATURES --- +# --- PREPROCESSING (RUN ONCE) --- + +def process_single_file(args): + """Process a single file - designed for multiprocessing""" + i, path, base_path, cache_dir = args + cache_path = os.path.join(cache_dir, f"sample_{i}.npy") + + if os.path.exists(cache_path): + return # Skip if already cached -def load_file_to_memory(path, base_path=BASE_PATH): try: parquet_path = os.path.join(base_path, path) df = pl.read_parquet(parquet_path) - # 1. Global Anchor (Nose) + # Global Anchor (Nose) anchors = ( df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0)) .select([ @@ -33,8 +41,7 @@ def load_file_to_memory(path, base_path=BASE_PATH): ]) ) - # 2. Local Anchors (Wrists) - # Left: 468, Right: 522 + # Local Anchors (Wrists) wrists = ( df.filter(pl.col("landmark_index").is_in([468, 522])) .select([ @@ -47,40 +54,83 @@ def load_file_to_memory(path, base_path=BASE_PATH): processed = df.join(anchors, on="frame", how="left") - # Join wrist data to the main frame - # We use a left join on frame and landmark_index to align wrist coords with their rows processed = ( processed.join(wrists, on=["frame", "landmark_index"], how="left") .with_columns([ - # Global (Nose-relative) (pl.col("x") - pl.col("nx")).alias("x_g"), (pl.col("y") - pl.col("ny")).alias("y_g"), (pl.col("z") - pl.col("nz")).alias("z_g"), - # Local (Wrist-relative - defaults to global if not a hand point) (pl.col("x") - pl.col("wx")).fill_null(pl.col("x") - pl.col("nx")).alias("x_l"), (pl.col("y") - pl.col("wy")).fill_null(pl.col("y") - pl.col("ny")).alias("y_l"), ]) .sort(["frame", "type", "landmark_index"]) ) - # We now have 5 channels: (x_g, y_g, z_g, x_l, y_l) n_frames = processed["frame"].n_unique() - # Reshape to (Frames, 543 landmarks, 5 features) tensor = processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"]).to_numpy().reshape(n_frames, 543, 5) # Temporal Resampling indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int) - return tensor[indices] + result = tensor[indices] + + # Save to cache + np.save(cache_path, result) except Exception: - return np.zeros((TARGET_FRAMES, 543, 5)) + # Save zero tensor for failed files + np.save(cache_path, np.zeros((TARGET_FRAMES, 543, 5))) -# --- DUAL-STREAM MODEL --- +def preprocess_and_cache(paths, base_path=BASE_PATH, cache_dir=CACHE_DIR): + """Preprocess all files in parallel and save as numpy arrays""" + os.makedirs(cache_dir, exist_ok=True) + + # Check if already cached + all_cached = all(os.path.exists(os.path.join(cache_dir, f"sample_{i}.npy")) for i in range(len(paths))) + if all_cached: + print("All files already cached, skipping preprocessing...") + return + + print(f"Preprocessing {len(paths)} files in parallel...") + + # Create arguments for each file + args_list = [(i, path, base_path, cache_dir) for i, path in enumerate(paths)] + + # Process in parallel + with ProcessPoolExecutor() as executor: + list(tqdm(executor.map(process_single_file, args_list), total=len(args_list))) + + print("Preprocessing complete!") + + +# --- FAST DATASET (LOADS FROM CACHE) --- + +class CachedASLDataset(Dataset): + """Fast dataset that loads from preprocessed numpy files""" + + def __init__(self, indices, labels, cache_dir=CACHE_DIR): + self.indices = indices + self.labels = labels + self.cache_dir = cache_dir + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx): + sample_idx = self.indices[idx] + cache_path = os.path.join(self.cache_dir, f"sample_{sample_idx}.npy") + + # Fast numpy load + data = np.load(cache_path) + label = self.labels[idx] + + return torch.tensor(data, dtype=torch.float32), torch.tensor(label, dtype=torch.long) + + +# --- MODEL --- class ASLClassifier(nn.Module): def __init__(self, num_classes): super().__init__() - # 543 landmarks * 5 features per landmark = 2715 self.feat_dim = 543 * 5 self.conv1 = nn.Conv1d(self.feat_dim, 512, kernel_size=3, padding=1) @@ -99,10 +149,8 @@ class ASLClassifier(nn.Module): ) def forward(self, x): - # x shape: (Batch, 22, 543, 5) b, t, l, f = x.shape - # Flatten landmarks and features into one vector, then transpose for Conv1d - x = x.view(b, t, -1).transpose(1, 2) # (Batch, 2715, 22) + x = x.view(b, t, -1).transpose(1, 2) x = F.relu(self.bn1(self.conv1(x))) x = self.pool(x) @@ -110,7 +158,6 @@ class ASLClassifier(nn.Module): x = F.relu(self.bn2(self.conv2(x))) x = self.pool(x) - # Global Average Pool across the time dimension x = F.adaptive_avg_pool1d(x, 1).squeeze(-1) return self.fc(self.dropout(x)) @@ -119,38 +166,51 @@ class ASLClassifier(nn.Module): # --- EXECUTION --- if __name__ == "__main__": - # 1. Setup Data + # 1. Setup Metadata metadata = pl.read_csv(os.path.join(BASE_PATH, "train.csv")) unique_signs = sorted(metadata["sign"].unique().to_list()) sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)} - labels = [sign_to_idx[s] for s in metadata["sign"].to_list()] paths = metadata["path"].to_list() + labels = [sign_to_idx[s] for s in metadata["sign"].to_list()] - # 2. Load to RAM (Parallelized) - print(f"Loading {len(paths)} files into RAM with 5-channel features...") - with ProcessPoolExecutor() as executor: - data_list = list(tqdm(executor.map(load_file_to_memory, paths), total=len(paths))) + # 2. Preprocess and cache (parallelized, only runs if cache doesn't exist) + preprocess_and_cache(paths) - X = torch.tensor(np.array(data_list), dtype=torch.float32) - y = torch.tensor(labels, dtype=torch.long) + # 3. Create index mapping for train/val split + all_indices = list(range(len(paths))) + train_indices, val_indices, train_labels, val_labels = train_test_split( + all_indices, labels, test_size=0.1, stratify=labels, random_state=42 + ) - # 3. Split - X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, stratify=y, random_state=42) + # 4. Create datasets from cached files + train_dataset = CachedASLDataset(train_indices, train_labels) + val_dataset = CachedASLDataset(val_indices, val_labels) - train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True) - val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=64) + # Increase batch size and workers since loading is now fast + train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True) + val_loader = DataLoader(val_dataset, batch_size=64, num_workers=4, pin_memory=True) - # 4. Train + print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}") + + # 5. Train model = ASLClassifier(len(unique_signs)).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # Helps prevent over-confidence + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5) + criterion = nn.CrossEntropyLoss(label_smoothing=0.1) + best_acc = 0.0 print(f"Starting training on {device}...") + for epoch in range(25): + # Training model.train() train_loss = 0 - for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}"): + train_correct = 0 + train_total = 0 + + pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/25 [Train]") + for batch_x, batch_y in pbar: batch_x, batch_y = batch_x.to(device), batch_y.to(device) optimizer.zero_grad() @@ -158,19 +218,47 @@ if __name__ == "__main__": loss = criterion(output, batch_y) loss.backward() optimizer.step() + train_loss += loss.item() + _, predicted = torch.max(output, 1) + train_total += batch_y.size(0) + train_correct += (predicted == batch_y).sum().item() + + pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100 * train_correct / train_total:.1f}%'}) # Validation model.eval() - correct, total = 0, 0 + val_correct, val_total = 0, 0 + val_loss = 0 + with torch.no_grad(): - for vx, vy in val_loader: + for vx, vy in tqdm(val_loader, desc=f"Epoch {epoch + 1}/25 [Val]"): vx, vy = vx.to(device), vy.to(device) - pred = model(vx).argmax(1) - correct += (pred == vy).sum().item() - total += vy.size(0) + output = model(vx) + val_loss += criterion(output, vy).item() + pred = output.argmax(1) + val_correct += (pred == vy).sum().item() + val_total += vy.size(0) - print(f"Epoch {epoch + 1} | Loss: {train_loss / len(train_loader):.4f} | Val Acc: {100 * correct / total:.2f}%") + avg_train_loss = train_loss / len(train_loader) + avg_val_loss = val_loss / len(val_loader) + train_acc = 100 * train_correct / train_total + val_acc = 100 * val_correct / val_total + scheduler.step(avg_val_loss) + + print(f"\nEpoch {epoch + 1}/25:") + print(f" Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%") + print(f" Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%") + + # Save best model + if val_acc > best_acc: + best_acc = val_acc + torch.save(model.state_dict(), "best_asl_model.pth") + print(f" ✓ Best model saved! (Val Acc: {val_acc:.2f}%)\n") + + # Checkpoint every 5 epochs if (epoch + 1) % 5 == 0: - torch.save(model.state_dict(), f"asl_model_v2_e{epoch + 1}.pth") \ No newline at end of file + torch.save(model.state_dict(), f"asl_model_e{epoch + 1}.pth") + + print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%") \ No newline at end of file