diff --git a/rewrite_training.py b/rewrite_training.py
index 4e78d91..cefafb1 100644
--- a/rewrite_training.py
+++ b/rewrite_training.py
@@ -1,237 +1,759 @@
 import os
-import polars as pl
+import math
 import numpy as np
+import polars as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader, random_split
-from concurrent.futures import ProcessPoolExecutor
+import torch.optim as optim
+
+from torch.utils.data import Dataset, DataLoader
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import LabelEncoder
+from multiprocessing import Pool, cpu_count
+from functools import partial
 from tqdm import tqdm
+from collections import Counter

-# --- CONFIGURATION ---
-BASE_PATH = "asl_kaggle"
-TARGET_FRAMES = 22
-LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
-HANDS = list(range(468, 543))
-SELECTED_INDICES = LIPS + HANDS
-NUM_FEATS = len(SELECTED_INDICES) * 3
+# ===============================
+# GPU CONFIGURATION
+# ===============================
+print("=" * 60)
+print("GPU CONFIGURATION")
+print("=" * 60)

-# Training hyperparameters
-BATCH_SIZE = 32
-EPOCHS = 50
-LEARNING_RATE = 0.001
-TRAIN_SPLIT = 0.8
-CHECKPOINT_DIR = "checkpoints"
+if torch.cuda.is_available():
+    print("✓ CUDA available!")
+    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
+    device = torch.device('cuda:0')
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.enabled = True
+else:
+    print("✗ CUDA not available, using CPU")
+    device = torch.device('cpu')

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
+print("=" * 60)
+
+# ===============================
+# SELECTED LANDMARK INDICES
+# ===============================
+IMPORTANT_FACE_INDICES = sorted(set([
+    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+    55, 65, 66, 105, 107, 336, 296, 334,
+    33, 133, 160, 159, 158, 144, 145, 153,
+    362, 263, 387, 386, 385, 373, 374, 380,
+    1, 2, 98, 327,
+    61, 185, 40, 39, 37, 0, 267, 269, 270, 409,
+    291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
+    78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
+    308, 324, 318, 402, 317, 14, 87, 178, 88, 95
+]))
+
+NUM_FACE_POINTS = len(IMPORTANT_FACE_INDICES)
+NUM_HAND_POINTS = 21 * 2
+TOTAL_POINTS_PER_FRAME = NUM_HAND_POINTS + NUM_FACE_POINTS

-# --- DATA PROCESSING ---
-def load_kaggle_metadata(base_path):
-    return pl.read_csv(os.path.join(base_path, "train.csv"))
+# ===============================
+# DATA AUGMENTATION
+# ===============================
+def augment_sequence(x, modality_mask):
+    """Apply random augmentations to training data"""
+    x = x.copy()
+
+    # Random temporal cropping (simulate different signing speeds)
+    if np.random.rand() < 0.3 and len(x) > 20:
+        start = np.random.randint(0, max(1, len(x) // 4))
+        x = x[start:]
+        modality_mask = modality_mask[start:]
+
+    # Random spatial scaling
+    if np.random.rand() < 0.5:
+        scale = np.random.uniform(0.85, 1.15)
+        x = x * scale
+
+    # Random rotation (around z-axis for x,y coordinates)
+    if np.random.rand() < 0.5:
+        angle = np.random.uniform(-0.3, 0.3)
+        cos_a, sin_a = np.cos(angle), np.sin(angle)
+
+        # Reshape to get xyz coordinates
+        x_reshaped = x.reshape(len(x), -1, 3)
+        x_rot = x_reshaped.copy()
+        x_rot[..., 0] = x_reshaped[..., 0] * cos_a - x_reshaped[..., 1] * sin_a
+        x_rot[..., 1] = x_reshaped[..., 0] * sin_a + x_reshaped[..., 1] * cos_a
+        x = x_rot.reshape(x.shape)
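+
+    # The rotation above treats the flat feature row as (..., 3) coordinate triples and
+    # rotates each (x, y) pair about the origin: x' = x*cos(a) - y*sin(a), y' = x*sin(a) + y*cos(a).
+    # It also hits the delta channels, which is consistent: velocities rotate like positions.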
+
+    # Random masking (simulate occlusion) - only for some frames
+    if np.random.rand() < 0.3:
+        n_mask = int(len(x) * 0.15)  # mask 15% of frames
+        mask_indices = np.random.choice(len(x), n_mask, replace=False)
+        x[mask_indices] *= 0.1  # dim but don't completely zero
+
+    # Random noise
+    if np.random.rand() < 0.4:
+        noise = np.random.normal(0, 0.02, x.shape)
+        x = x + noise
+
+    # Random time warping (resampling; the cap means sequences can shrink but never grow)
+    if np.random.rand() < 0.3 and len(x) > 20:
+        speed = np.random.uniform(0.8, 1.2)
+        new_len = int(len(x) * speed)
+        new_len = min(new_len, len(x))
+        indices = np.linspace(0, len(x) - 1, new_len).astype(int)
+        x = x[indices]
+        modality_mask = modality_mask[indices]
+
+    return x, modality_mask
+
+
-def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES):
-    parquet_path = os.path.join(base_path, path)
-    df = pl.read_parquet(parquet_path)
+# ===============================
+# ENHANCED DATA EXTRACTION (POLARS)
+# ===============================
+def extract_multi_landmarks(path, min_valid_frames=3):
+    """
+    Extract both hands + selected face landmarks with modality flags
+    Returns: dict with 'landmarks', 'left_hand_valid', 'right_hand_valid', 'face_valid'
+    """
+    try:
+        df = pl.read_parquet(path)
+        seq = []
+        left_valid_frames = []
+        right_valid_frames = []
+        face_valid_frames = []

-    anchors = (
-        df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
-        .select([pl.col("frame"), pl.col("x").alias("nx"), pl.col("y").alias("ny"), pl.col("z").alias("nz")])
-    )
+        all_types = df.select("type").unique().to_series().to_list()
+        has_data = any(t in all_types for t in ["left_hand", "right_hand", "face"])

-    processed = (
-        df.join(anchors, on="frame", how="left")
-        .with_columns([
-            (pl.col("x") - pl.col("nx")).fill_null(0.0),
-            (pl.col("y") - pl.col("ny")).fill_null(0.0),
-            (pl.col("z") - pl.col("nz")).fill_null(0.0),
-        ])
-        .sort(["frame", "type", "landmark_index"])
-    )
+        if not has_data:
+            return None

-    raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3)
-    reduced_tensor = raw_tensor[:, SELECTED_INDICES, :]
+        frames = sorted(df.select("frame").unique().to_series().to_list())

-    curr_len = reduced_tensor.shape[0]
-    indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int)
-    return reduced_tensor[indices]
+        if len(frames) < min_valid_frames:
+            return None
+
+        for frame in frames:
+            frame_df = df.filter(pl.col("frame") == frame)
+            frame_points = np.full((TOTAL_POINTS_PER_FRAME, 3), np.nan, dtype=np.float32)
+
+            pos = 0
+            left_valid = False
+            right_valid = False
+            face_valid = False
+
+            # Left hand
+            left = frame_df.filter(pl.col("type") == "left_hand")
+            if left.height > 0:
+                valid_count = 0
+                for i in range(21):
+                    row = left.filter(pl.col("landmark_index") == i)
+                    if row.height > 0:
+                        coords = row.select(["x", "y", "z"]).row(0)
+                        if all(c is not None for c in coords):
+                            frame_points[pos] = coords
+                            valid_count += 1
+                    pos += 1
+                left_valid = (valid_count >= 10)
+            else:
+                pos += 21
+
+            # Right hand
+            right = frame_df.filter(pl.col("type") == "right_hand")
+            if right.height > 0:
+                valid_count = 0
+                for i in range(21):
+                    row = right.filter(pl.col("landmark_index") == i)
+                    if row.height > 0:
+                        coords = row.select(["x", "y", "z"]).row(0)
+                        if all(c is not None for c in coords):
+                            frame_points[pos] = coords
+                            valid_count += 1
+                    pos += 1
+                right_valid = (valid_count >= 10)
+            else:
+                pos += 21
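+
+            # Slot layout in frame_points: [0:21] left hand, [21:42] right hand, [42:]
+            # the selected face indices; pos advances even for missing landmarks so the
+            # slots stay aligned across frames.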
+
+            # Face
+            face = frame_df.filter(pl.col("type") == "face")
+            if face.height > 0:
+                valid_count = 0
+                for idx in IMPORTANT_FACE_INDICES:
+                    row = face.filter(pl.col("landmark_index") == idx)
+                    if row.height > 0:
+                        coords = row.select(["x", "y", "z"]).row(0)
+                        if all(c is not None for c in coords):
+                            frame_points[pos] = coords
+                            valid_count += 1
+                    pos += 1
+                face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.3)
+
+            valid_ratio = 1 - np.isnan(frame_points).mean()
+            if valid_ratio >= 0.20:
+                frame_points = np.nan_to_num(frame_points, nan=0.0)
+                seq.append(frame_points)
+                left_valid_frames.append(left_valid)
+                right_valid_frames.append(right_valid)
+                face_valid_frames.append(face_valid)
+
+        if len(seq) < min_valid_frames:
+            return None
+
+        return {
+            'landmarks': np.stack(seq),
+            'left_hand_valid': np.array(left_valid_frames),
+            'right_hand_valid': np.array(right_valid_frames),
+            'face_valid': np.array(face_valid_frames)
+        }
+
+    except Exception:
+        return None
+
+
-# --- DATASET CLASS ---
-class ASLDataset(Dataset):
-    def __init__(self, tensors, labels):
-        self.tensors = tensors
-        self.labels = labels
+def get_features_sequence(landmarks_data, max_frames=100):
+    """Enhanced feature extraction with separate modality processing"""
+    if landmarks_data is None:
+        return None, None, None

-    def __len__(self):
-        return len(self.tensors)
+    landmarks_3d = landmarks_data['landmarks']
+    if len(landmarks_3d) == 0:
+        return None, None, None

-    def __getitem__(self, idx):
-        return self.tensors[idx], self.labels[idx]
+    T, N, _ = landmarks_3d.shape
+
+    # Separate modalities for independent normalization
+    left_hand = landmarks_3d[:, :21, :]
+    right_hand = landmarks_3d[:, 21:42, :]
+    face = landmarks_3d[:, 42:, :]
+
+    features_list = []
+
+    for modality, valid_mask in [
+        (left_hand, landmarks_data['left_hand_valid']),
+        (right_hand, landmarks_data['right_hand_valid']),
+        (face, landmarks_data['face_valid'])
+    ]:
+        valid_frames = modality[valid_mask] if valid_mask.any() else modality
+        if len(valid_frames) > 0:
+            center = np.mean(valid_frames, axis=(0, 1), keepdims=True)
+            spread = np.std(valid_frames, axis=(0, 1), keepdims=True).max()
+        else:
+            center = 0
+            spread = 1
+
+        modality_norm = (modality - center) / max(spread, 1e-6)
+        flat = modality_norm.reshape(T, -1)
+
+        # Deltas
+        deltas = np.zeros_like(flat)
+        if T > 1:
+            deltas[1:] = flat[1:] - flat[:-1]
+
+        features_list.append(flat)
+        features_list.append(deltas)
+
+    features = np.concatenate(features_list, axis=1)
+
+    modality_mask = np.stack([
+        landmarks_data['left_hand_valid'],
+        landmarks_data['right_hand_valid'],
+        landmarks_data['face_valid']
+    ], axis=1).astype(np.float32)
+
+    # Pad/truncate
+    if T < max_frames:
+        pad = np.zeros((max_frames - T, features.shape[1]), dtype=np.float32)
+        features = np.concatenate([features, pad], axis=0)
+
+        mask_pad = np.zeros((max_frames - T, 3), dtype=np.float32)
+        modality_mask = np.concatenate([modality_mask, mask_pad], axis=0)
+
+        frame_mask = np.zeros(max_frames, dtype=bool)
+        frame_mask[:T] = True
+    else:
+        features = features[:max_frames]
+        modality_mask = modality_mask[:max_frames]
+        frame_mask = np.ones(max_frames, dtype=bool)
+
+    features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
+    features = np.clip(features, -30, 30)
+
+    return features.astype(np.float32), frame_mask, modality_mask
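+
+
+# Per frame, each modality contributes its flattened coords plus frame-to-frame
+# deltas: 2 * 3 * (21 + 21 + NUM_FACE_POINTS) features. With the 76 unique face
+# indices above, that is 2 * 3 * 118 = 708, which is the input_dim printed in main().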

-# --- MODEL ARCHITECTURE ---
-class ASLClassifier(nn.Module):
-    def __init__(self, num_classes, target_frames=TARGET_FRAMES, num_feats=NUM_FEATS):
+def process_row(row_data, base_path, max_frames=100):
+    """Process a single row"""
+    path_rel, sign = row_data
+    path = os.path.join(base_path, path_rel)
+    if not os.path.exists(path):
+        return None, None, None, None
+
+    try:
+        lm_data = extract_multi_landmarks(path)
+        if lm_data is None:
+            return None, None, None, None
+
+        feat, frame_mask, modality_mask = get_features_sequence(lm_data, max_frames)
+        if feat is None:
+            return None, None, None, None
+
+        return feat, frame_mask, modality_mask, sign
+
+    except Exception:
+        return None, None, None, None
+
+
+# ===============================
+# MIXUP AUGMENTATION
+# ===============================
+def mixup_data(x, frame_mask, modality_mask, y, alpha=0.2):
+    """Mixup augmentation"""
+    if alpha > 0:
+        lam = np.random.beta(alpha, alpha)
+    else:
+        lam = 1
+
+    batch_size = x.size(0)
+    index = torch.randperm(batch_size).to(x.device)
+
+    mixed_x = lam * x + (1 - lam) * x[index]
+    mixed_frame_mask = frame_mask | frame_mask[index]  # Union of valid frames
+    mixed_modality_mask = torch.max(modality_mask, modality_mask[index])
+
+    y_a, y_b = y, y[index]
+    return mixed_x, mixed_frame_mask, mixed_modality_mask, y_a, y_b, lam
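+
+# How mixup is consumed by the training loop below:
+#   x_mix = lam * x_i + (1 - lam) * x_j,  with lam ~ Beta(alpha, alpha)
+#   loss  = lam * CE(logits, y_i) + (1 - lam) * CE(logits, y_j)
+# Frame masks are OR-ed so a frame counts as valid if either source sample had data.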
+
+
+# ===============================
+# ENHANCED MODEL WITH ATTENTION POOLING
+# ===============================
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=128):
         super().__init__()
-        self.conv1 = nn.Conv1d(num_feats, 256, kernel_size=3, padding=1)
-        self.bn1 = nn.BatchNorm1d(256)
-        self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
-        self.bn2 = nn.BatchNorm1d(512)
-        self.pool = nn.MaxPool1d(2)
-        self.dropout = nn.Dropout(0.5)
-        self.fc = nn.Linear(512, num_classes)
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('pe', pe.unsqueeze(0))

     def forward(self, x):
-        x = x.view(x.shape[0], x.shape[1], -1)
-        x = x.transpose(1, 2)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = self.pool(x)
-        x = F.relu(self.bn2(self.conv2(x)))
-        x = self.pool(x)
-        x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
-        x = self.dropout(x)
-        return self.fc(x)
+        return x + self.pe[:, :x.size(1)]

-# --- TRAINING FUNCTIONS ---
-def train_epoch(model, dataloader, criterion, optimizer, device):
-    model.train()
-    running_loss = 0.0
-    correct = 0
-    total = 0
+class ModalityAwareTransformer(nn.Module):
+    def __init__(self, input_dim, num_classes, d_model=512, nhead=8, num_layers=6, dropout=0.15):
+        super().__init__()

-    progress_bar = tqdm(dataloader, desc="Training")
-    for inputs, labels in progress_bar:
-        inputs, labels = inputs.to(device), labels.to(device)
+        # Main projection
+        self.proj = nn.Linear(input_dim, d_model)

-        optimizer.zero_grad()
-        outputs = model(inputs)
-        loss = criterion(outputs, labels)
-        loss.backward()
-        optimizer.step()
+        # Modality embedding (3 modalities: left_hand, right_hand, face)
+        self.modality_embed = nn.Linear(3, d_model)

-        running_loss += loss.item()
-        _, predicted = torch.max(outputs.data, 1)
-        total += labels.size(0)
-        correct += (predicted == labels).sum().item()
+        self.norm_in = nn.LayerNorm(d_model)
+        self.pos = PositionalEncoding(d_model)

-        progress_bar.set_postfix({
-            'loss': running_loss / (progress_bar.n + 1),
-            'acc': 100 * correct / total
-        })
+        enc_layer = nn.TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=d_model * 4,
+            dropout=dropout,
+            activation='gelu',
+            batch_first=True,
+            norm_first=True
+        )
+        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

-    epoch_loss = running_loss / len(dataloader)
-    epoch_acc = 100 * correct / total
-    return epoch_loss, epoch_acc
+        # Attention pooling
+        self.attention_pool = nn.Linear(d_model, 1)
+
+        self.head = nn.Sequential(
+            nn.LayerNorm(d_model),
+            nn.Dropout(0.3),
+            nn.Linear(d_model, d_model // 2),
+            nn.GELU(),
+            nn.Dropout(0.2),
+            nn.Linear(d_model // 2, num_classes)
+        )
+
+    def forward(self, x, modality_mask=None, key_padding_mask=None):
+        # Project features
+        x = self.proj(x)
+
+        # Add modality information
+        if modality_mask is not None:
+            mod_embed = self.modality_embed(modality_mask)
+            x = x + mod_embed
+
+        x = self.norm_in(x)
+        x = self.pos(x)
+        x = self.encoder(x, src_key_padding_mask=key_padding_mask)
+
+        # Attention-based pooling
+        attn_weights = self.attention_pool(x)  # (B, T, 1)
+        if key_padding_mask is not None:
+            attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(-1), -1e9)
+        attn_weights = F.softmax(attn_weights, dim=1)
+        x = (x * attn_weights).sum(dim=1)
+
+        return self.head(x)

-def validate(model, dataloader, criterion, device):
-    model.eval()
-    running_loss = 0.0
-    correct = 0
-    total = 0
-
-    with torch.no_grad():
-        for inputs, labels in tqdm(dataloader, desc="Validation"):
-            inputs, labels = inputs.to(device), labels.to(device)
-            outputs = model(inputs)
-            loss = criterion(outputs, labels)
-
-            running_loss += loss.item()
-            _, predicted = torch.max(outputs.data, 1)
-            total += labels.size(0)
-            correct += (predicted == labels).sum().item()
-
-    val_loss = running_loss / len(dataloader)
-    val_acc = 100 * correct / total
-    return val_loss, val_acc
+def load_kaggle_asl_data(base_path):
+    """Load training metadata using Polars"""
+    train_path = os.path.join(base_path, "train.csv")
+    train_df = pl.read_csv(train_path)
+    return train_df, None

-def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, val_acc, checkpoint_dir):
-    os.makedirs(checkpoint_dir, exist_ok=True)
-    checkpoint = {
-        'epoch': epoch,
-        'model_state_dict': model.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
-        'train_loss': train_loss,
-        'val_loss': val_loss,
-        'val_acc': val_acc,
-    }
-    path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
-    torch.save(checkpoint, path)
-    print(f"Checkpoint saved: {path}")
+# ===============================
+# DATASET WITH AUGMENTATION
+# ===============================
+class ASLMultiDataset(Dataset):
+    def __init__(self, X, frame_masks, modality_masks, y, training=False, max_frames=100):
+        self.X = X
+        self.frame_masks = frame_masks
+        self.modality_masks = modality_masks
+        self.y = y
+        self.training = training
+        self.max_frames = max_frames
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, idx):
+        x = self.X[idx].copy()
+        frame_mask = self.frame_masks[idx].copy()
+        modality_mask = self.modality_masks[idx].copy()
+        y = self.y[idx]
+
+        if self.training:
+            # Apply augmentation
+            x, modality_mask = augment_sequence(x, modality_mask)
+
+            # Re-pad if needed after augmentation; capture the length BEFORE padding,
+            # otherwise the modality-mask pad is empty and frame_mask comes out all-True
+            seq_len = len(x)
+            if seq_len < self.max_frames:
+                pad = np.zeros((self.max_frames - seq_len, x.shape[1]), dtype=np.float32)
+                x = np.concatenate([x, pad], axis=0)
+
+                mask_pad = np.zeros((self.max_frames - seq_len, 3), dtype=np.float32)
+                modality_mask = np.concatenate([modality_mask, mask_pad], axis=0)
+
+                frame_mask = np.zeros(self.max_frames, dtype=bool)
+                frame_mask[:seq_len] = True
+            else:
+                x = x[:self.max_frames]
+                modality_mask = modality_mask[:self.max_frames]
+                frame_mask = np.ones(self.max_frames, dtype=bool)
+
+        return (
+            torch.from_numpy(x).float(),
+            torch.from_numpy(frame_mask).bool(),
+            torch.from_numpy(modality_mask).float(),
+            torch.tensor(y, dtype=torch.long)
+        )
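+
+
+# Mask convention used below: frame_mask is True for REAL frames, while PyTorch's
+# TransformerEncoder expects src_key_padding_mask=True for positions to IGNORE;
+# hence key_padding_mask = ~frame_mask at every call site.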
+
+
+# ===============================
+# TRAINING SINGLE MODEL
+# ===============================
+def train_model(X_tr, fm_tr, mm_tr, y_tr, X_te, fm_te, mm_te, y_te,
+                num_classes, input_dim, model_idx=0, epochs=80):
+    """Train a single model"""
+
+    # Set different seed for each model
+    torch.manual_seed(42 + model_idx)
+    np.random.seed(42 + model_idx)
+
+    batch_size = 64 if device.type == 'cuda' else 32
+
+    train_dataset = ASLMultiDataset(X_tr, fm_tr, mm_tr, y_tr, training=True, max_frames=100)
+    test_dataset = ASLMultiDataset(X_te, fm_te, mm_te, y_te, training=False, max_frames=100)
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size, shuffle=True,
+        num_workers=4, pin_memory=device.type == 'cuda'
+    )
+
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=batch_size * 2, shuffle=False,
+        num_workers=4, pin_memory=device.type == 'cuda'
+    )
+
+    # Enhanced model
+    model = ModalityAwareTransformer(
+        input_dim=input_dim,
+        num_classes=num_classes,
+        d_model=512,
+        nhead=8,
+        num_layers=6,
+        dropout=0.15
+    ).to(device)
+
+    print(f"\n[Model {model_idx + 1}] Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
+    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
+
+    # OneCycleLR scheduler
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=1e-3,
+        steps_per_epoch=len(train_loader),
+        epochs=epochs,
+        pct_start=0.1,
+        anneal_strategy='cos'
+    )
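+
+    # OneCycleLR is stepped once per BATCH (scheduler.step() in the inner loop below),
+    # matching steps_per_epoch=len(train_loader); pct_start=0.1 warms the LR up over
+    # the first 10% of total steps before the cosine anneal back down.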
+
+    best_acc = 0.0
+    save_path = f"best_asl_model_{model_idx}.pth"
+
+    print(f"\n{'=' * 60}")
+    print(f"TRAINING MODEL {model_idx + 1}")
+    print(f"{'=' * 60}")
+
+    for epoch in range(epochs):
+        model.train()
+        total_loss = correct = total = 0
+
+        for x, frame_mask, modality_mask, yb in tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False):
+            x = x.to(device)
+            frame_mask = frame_mask.to(device)
+            modality_mask = modality_mask.to(device)
+            yb = yb.to(device)
+
+            # Apply mixup
+            if np.random.rand() < 0.5:
+                x, frame_mask, modality_mask, y_a, y_b, lam = mixup_data(
+                    x, frame_mask, modality_mask, yb, alpha=0.2
+                )
+
+                key_padding_mask = ~frame_mask
+                optimizer.zero_grad(set_to_none=True)
+                logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
+                loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
+
+                # Use original labels for accuracy
+                correct += (logits.argmax(-1) == yb).sum().item()
+            else:
+                key_padding_mask = ~frame_mask
+                optimizer.zero_grad(set_to_none=True)
+                logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
+                loss = criterion(logits, yb)
+                correct += (logits.argmax(-1) == yb).sum().item()
+
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            scheduler.step()
+
+            total_loss += loss.item()
+            total += yb.size(0)
+
+        train_acc = correct / total * 100
+
+        # Eval
+        model.eval()
+        correct = total = 0
+        with torch.no_grad():
+            for x, frame_mask, modality_mask, yb in test_loader:
+                x = x.to(device)
+                frame_mask = frame_mask.to(device)
+                modality_mask = modality_mask.to(device)
+                yb = yb.to(device)
+
+                key_padding_mask = ~frame_mask
+                logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
+                correct += (logits.argmax(-1) == yb).sum().item()
+                total += yb.size(0)
+
+        test_acc = correct / total * 100
+
+        print(f"[{epoch + 1:2d}/{epochs}] Loss: {total_loss / len(train_loader):.4f} | "
+              f"Train: {train_acc:.2f}% | Test: {test_acc:.2f}%", end="")
+
+        if test_acc > best_acc:
+            best_acc = test_acc
+            torch.save(model.state_dict(), save_path)
+            print(" → saved")
+        else:
+            print()
+
+    print(f"\nModel {model_idx + 1} - Best test accuracy: {best_acc:.2f}%")
+    return save_path, best_acc
+
+
+# ===============================
+# ENSEMBLE PREDICTION
+# ===============================
+def ensemble_predict(model_paths, test_loader, num_classes, input_dim):
+    """Make predictions using ensemble of models"""
+    all_preds = []
+
+    for model_path in model_paths:
+        model = ModalityAwareTransformer(
+            input_dim=input_dim,
+            num_classes=num_classes,
+            d_model=512,
+            nhead=8,
+            num_layers=6
+        ).to(device)
+
+        model.load_state_dict(torch.load(model_path, map_location=device))
+        model.eval()
+
+        preds = []
+        with torch.no_grad():
+            for x, frame_mask, modality_mask, _ in test_loader:
+                x = x.to(device)
+                frame_mask = frame_mask.to(device)
+                modality_mask = modality_mask.to(device)
+
+                key_padding_mask = ~frame_mask
+                logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
+                preds.append(F.softmax(logits, dim=-1))
+
+        all_preds.append(torch.cat(preds, dim=0))
+
+    # Average predictions
+    ensemble_pred = torch.stack(all_preds).mean(0)
+    return ensemble_pred.argmax(-1).cpu().numpy()
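+
+
+# The ensemble averages softmax probabilities rather than logits or hard votes:
+#   p_bar = (1/M) * sum_m softmax(logits_m);  prediction = argmax(p_bar)
+# a reasonable default for homogeneous models that differ only in their random seed.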
+
+
+# ===============================
+# MAIN
+# ===============================
+def main():
+    base_path = "asl_kaggle"
+    max_frames = 100
+    MIN_SAMPLES_PER_CLASS = 3  # Relaxed from 5
+    NUM_ENSEMBLE_MODELS = 3
+    EPOCHS = 80
+
+    print("\nLoading metadata...")
+    train_df, _ = load_kaggle_asl_data(base_path)
+    print(f"Total samples in train.csv: {train_df.height}")
+
+    rows = [(row[0], row[1]) for row in train_df.select(["path", "sign"]).iter_rows()]
+
+    print("\nProcessing sequences with BOTH hands + FACE (enhanced)...")
+    print("This may take a few minutes...")
+
+    with Pool(cpu_count()) as pool:
+        results = list(tqdm(
+            pool.imap(
+                partial(process_row, base_path=base_path, max_frames=max_frames),
+                rows,
+                chunksize=80
+            ),
+            total=len(rows),
+            desc="Landmarks extraction"
+        ))
+
+    X_list, frame_masks_list, modality_masks_list, y_list = [], [], [], []
+    failed_count = 0
+    for feat, frame_mask, modality_mask, sign in results:
+        if feat is not None and frame_mask is not None:
+            X_list.append(feat)
+            frame_masks_list.append(frame_mask)
+            modality_masks_list.append(modality_mask)
+            y_list.append(sign)
+        else:
+            failed_count += 1
+
+    if not X_list:
+        print("\n❌ No valid sequences extracted!")
+        print(f"Failed to process: {failed_count}/{len(results)} files")
+        return
+
+    print(f"\n✓ Successfully processed: {len(X_list)}/{len(results)} files")
+    print(f"✗ Failed: {failed_count}/{len(results)} files")
+
+    X = np.stack(X_list)
+    frame_masks = np.stack(frame_masks_list)
+    modality_masks = np.stack(modality_masks_list)
+
+    print(f"\nExtracted {len(X):,} sequences")
+    print(f"Feature shape: {X.shape[1:]} (input_dim = {X.shape[2]})")
+
+    # Global normalization
+    X = np.clip(X, -30, 30)
+    mean = X.mean(axis=(0, 1), keepdims=True)
+    std = X.std(axis=(0, 1), keepdims=True) + 1e-8
+    X = (X - mean) / std
+
+    # Labels
+    le = LabelEncoder()
+    y = le.fit_transform(y_list)
+
+    # Filter rare classes
+    counts = Counter(y)
+    valid = [k for k, v in counts.items() if v >= MIN_SAMPLES_PER_CLASS]
+    mask = np.isin(y, valid)
+
+    X = X[mask]
+    frame_masks = frame_masks[mask]
+    modality_masks = modality_masks[mask]
+    y = y[mask]
+
+    le = LabelEncoder()
+    y = le.fit_transform(y)
+
+    print(f"After filtering: {len(X):,} samples | {len(le.classes_)} classes")
+
+    # Split
+    X_tr, X_te, fm_tr, fm_te, mm_tr, mm_te, y_tr, y_te = train_test_split(
+        X, frame_masks, modality_masks, y, test_size=0.15, stratify=y, random_state=42
+    )
+
+    # Train ensemble of models
+    model_paths = []
+    best_accs = []
+
+    for i in range(NUM_ENSEMBLE_MODELS):
+        model_path, best_acc = train_model(
+            X_tr, fm_tr, mm_tr, y_tr,
+            X_te, fm_te, mm_te, y_te,
+            num_classes=len(le.classes_),
+            input_dim=X.shape[2],
+            model_idx=i,
+            epochs=EPOCHS
+        )
+        model_paths.append(model_path)
+        best_accs.append(best_acc)
+
+    # Ensemble evaluation
+    print("\n" + "=" * 60)
+    print("ENSEMBLE EVALUATION")
+    print("=" * 60)
+
+    test_dataset = ASLMultiDataset(X_te, fm_te, mm_te, y_te, training=False)
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=128,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=device.type == 'cuda'
+    )
+
+    ensemble_preds = ensemble_predict(model_paths, test_loader, len(le.classes_), X.shape[2])
+    ensemble_acc = (ensemble_preds == y_te).mean() * 100
+
+    print("\nIndividual model accuracies:")
+    for i, acc in enumerate(best_accs):
+        print(f"  Model {i + 1}: {acc:.2f}%")
+
+    print(f"\n🎯 Ensemble accuracy: {ensemble_acc:.2f}%")
+    print(f"   Improvement: +{ensemble_acc - max(best_accs):.2f}% over best single model")
+
+    print("\n" + "=" * 60)
+    print("TRAINING COMPLETE")
+    print("=" * 60)
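+
+
+# Entry point: assumes the Kaggle ASL-signs layout, i.e. asl_kaggle/train.csv with
+# "path" and "sign" columns, where each path points to one per-sequence parquet file.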

-# --- EXECUTION ---
 if __name__ == "__main__":
-    # Load metadata
-    asl_data = load_kaggle_metadata(BASE_PATH)
-
-    # Create label mapping
-    unique_signs = sorted(asl_data["sign"].unique().to_list())
-    label_to_idx = {sign: idx for idx, sign in enumerate(unique_signs)}
-    labels = torch.tensor([label_to_idx[sign] for sign in asl_data["sign"].to_list()])
-
-    print(f"Number of classes: {len(unique_signs)}")
-
-    # Process data in parallel
-    paths = asl_data["path"].to_list()
-    print(f"Processing {len(paths)} files in parallel...")
-
-    with ProcessPoolExecutor() as executor:
-        results = list(tqdm(executor.map(load_and_preprocess, paths), total=len(paths)))
-
-    dataset_tensor = torch.tensor(np.array(results), dtype=torch.float32)
-    print(f"Final Tensor Shape: {dataset_tensor.shape}")
-
-    # Create dataset and split
-    full_dataset = ASLDataset(dataset_tensor, labels)
-    train_size = int(TRAIN_SPLIT * len(full_dataset))
-    val_size = len(full_dataset) - train_size
-    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
-
-    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
-    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
-
-    print(f"Train samples: {train_size}, Validation samples: {val_size}")
-
-    # Initialize model, loss, optimizer
-    model = ASLClassifier(num_classes=len(unique_signs)).to(device)
-    criterion = nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
-
-    # Training loop
-    best_val_acc = 0.0
-
-    print("\n" + "=" * 50)
-    print("Starting Training")
-    print("=" * 50 + "\n")
-
-    for epoch in range(EPOCHS):
-        print(f"\nEpoch [{epoch + 1}/{EPOCHS}]")
-        print("-" * 50)
-
-        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
-        val_loss, val_acc = validate(model, val_loader, criterion, device)
-
-        scheduler.step(val_loss)
-
-        print(f"\nEpoch {epoch + 1} Summary:")
-        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
-        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
-        print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
-
-        # Save checkpoint if validation accuracy improved
-        if val_acc > best_val_acc:
-            best_val_acc = val_acc
-            save_checkpoint(model, optimizer, epoch + 1, train_loss, val_loss, val_acc, CHECKPOINT_DIR)
-            print(f"  ✓ New best validation accuracy: {best_val_acc:.2f}%")
-
-    print("\n" + "=" * 50)
-    print("Training Complete!")
-    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
-    print("=" * 50)
\ No newline at end of file
+    main()
\ No newline at end of file