From f66049f6ea8a828414fa43353b66fb57c939e7ba Mon Sep 17 00:00:00 2001
From: Stupdi Go
Date: Sat, 10 Jan 2026 23:04:48 -0600
Subject: [PATCH] grok lock in

---
 training.py | 293 ++++++++++++++++++++++++++++++----------------------
 1 file changed, 170 insertions(+), 123 deletions(-)

diff --git a/training.py b/training.py
index 28fb894..49b0f78 100644
--- a/training.py
+++ b/training.py
@@ -4,23 +4,21 @@
 import os
 import json
 import math
-import time
 import numpy as np
 import pandas as pd
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-
 from torch.utils.data import Dataset, DataLoader
 from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import LabelEncoder
+from sklearn.preprocessing import StandardScaler, LabelEncoder
 from multiprocessing import Pool, cpu_count
 from functools import partial
+from tqdm import tqdm
 
 # ===============================
-# GPU SETUP
+# DEVICE
 # ===============================
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -29,7 +27,7 @@ if device.type == "cuda":
     torch.backends.cudnn.benchmark = True
 
 # ===============================
-# DATA LOADING & FEATURE EXTRACTION
+# DATA LOADING
 # ===============================
 def load_kaggle_asl_data(base_path):
     train_df = pd.read_csv(os.path.join(base_path, "train.csv"))
@@ -38,111 +36,117 @@ def load_kaggle_asl_data(base_path):
     return train_df, sign_to_idx
 
 def extract_hand_landmarks_from_parquet(path):
-    df = pd.read_parquet(path)
-    left = df[df["type"] == "left_hand"]
-    right = df[df["type"] == "right_hand"]
+    try:
+        df = pd.read_parquet(path)
+        hand = df[df["type"].isin(["left_hand", "right_hand"])]
+        if len(hand) == 0:
+            return None
 
-    hand = None
-    if len(left) > 0:
-        hand = left
-    elif len(right) > 0:
-        hand = right
-    else:
+        frames = sorted(hand['frame'].unique())
+        landmarks_seq = []
+
+        for frame in frames:
+            lm_frame = hand[hand['frame'] == frame]
+            lm_list = []
+            for i in range(21):
+                lm = lm_frame[lm_frame['landmark_index'] == i]
+                if len(lm) == 0:
+                    lm_list.append([0.0, 0.0, 0.0])
+                else:
+                    # if both hands appear in a frame, the first row wins
+                    lm_list.append([lm['x'].values[0], lm['y'].values[0], lm['z'].values[0]])
+            landmarks_seq.append(lm_list)
+
+        return np.array(landmarks_seq, dtype=np.float32)  # (T, 21, 3)
+    except Exception:
        return None
 
-    # Keep all frames
-    frames = sorted(hand['frame'].unique())
-    landmarks_seq = []
-
-    for frame in frames:
-        lm_frame = hand[hand['frame'] == frame]
-        lm_list = []
-        for i in range(21):
-            lm = lm_frame[lm_frame['landmark_index'] == i]
-            if len(lm) == 0:
-                lm_list.append([0.0, 0.0, 0.0])
-            else:
-                lm_list.append([
-                    lm['x'].mean(),
-                    lm['y'].mean(),
-                    lm['z'].mean()
-                ])
-        landmarks_seq.append(lm_list)
-
-    return np.array(landmarks_seq, dtype=np.float32)  # (T, 21, 3)
-
-def get_features_sequence(landmarks_seq, max_frames=100):
-    if landmarks_seq is None:
+def get_features_sequence(landmarks_seq, max_frames=96):
+    if landmarks_seq is None or len(landmarks_seq) == 0:
         return None
-    # Center on wrist
-    points = landmarks_seq - landmarks_seq[:, 0:1, :]
-    scale = np.linalg.norm(points[:, 9, :], axis=1, keepdims=True)
-    scale[scale < 1e-6] = 1.0
-    points /= scale[:, np.newaxis, :]
-    # Flatten per frame
-    frames = points.reshape(points.shape[0], -1)
-    # Pad or truncate
-    if frames.shape[0] < max_frames:
-        pad = np.zeros((max_frames - frames.shape[0], frames.shape[1]), dtype=np.float32)
-        frames = np.vstack([frames, pad])
-    else:
-        frames = frames[:max_frames]
-    return frames  # (max_frames, 63)
-
-def process_row(row, base_path, max_frames=100):
+    # Center on wrist (landmark 0)
+    landmarks_seq = landmarks_seq - landmarks_seq[:, 0:1, :]
+
+    # Rough scale normalization (using index finger length as reference)
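+    # (assumes MediaPipe Hands' 21-point layout: 0 = wrist, 5 = index-finger
+    #  MCP, 8 = index fingertip, i.e. the index finger's MCP-to-tip length)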
+    scale = np.linalg.norm(landmarks_seq[:, 8] - landmarks_seq[:, 5], axis=1, keepdims=True)
+    scale = np.maximum(scale, 1e-6)
+    landmarks_seq /= scale[:, :, np.newaxis]  # (T, 1) -> (T, 1, 1) so it broadcasts over (T, 21, 3)
+
+    # Flatten → (T, 63)
+    seq = landmarks_seq.reshape(landmarks_seq.shape[0], -1)
+
+    # Pad / truncate
+    if len(seq) < max_frames:
+        pad = np.zeros((max_frames - len(seq), seq.shape[1]), dtype=np.float32)
+        seq = np.concatenate([seq, pad], axis=0)
+    else:
+        seq = seq[:max_frames]
+
+    return seq.astype(np.float32)
+
+def process_row(row, base_path, max_frames=96):
     path = os.path.join(base_path, row['path'])
     if not os.path.exists(path):
         return None, None
-    try:
-        lm_seq = extract_hand_landmarks_from_parquet(path)
-        feat_seq = get_features_sequence(lm_seq, max_frames)
-        return feat_seq, row['sign']
-    except:
+    lm = extract_hand_landmarks_from_parquet(path)
+    feat = get_features_sequence(lm, max_frames)
+    if feat is None:
         return None, None
+    return feat, row['sign']
 
 # ===============================
-# LOAD + PROCESS DATA
+# LOAD & PROCESS (with progress)
 # ===============================
-base_path = "asl_kaggle"
+base_path = "asl_kaggle"  # ← change if needed
 train_df, sign_to_idx = load_kaggle_asl_data(base_path)
 
+print("Processing videos...")
 rows = [row for _, row in train_df.iterrows()]
-X, y = [], []
-func = partial(process_row, base_path=base_path, max_frames=100)
 with Pool(cpu_count()) as pool:
-    for feat_seq, sign in pool.map(func, rows):
-        if feat_seq is not None:
-            X.append(feat_seq)
-            y.append(sign)
+    results = list(tqdm(pool.imap(
+        partial(process_row, base_path=base_path, max_frames=96),
+        rows
+    ), total=len(rows)))
+
+X, y = [], []
+for feat, sign in results:
+    if feat is not None:
+        X.append(feat)
+        y.append(sign)
 
 X = np.stack(X)  # (N, T, 63)
-y = np.array(y)
-print("Samples:", len(X))
-print("Sequence shape:", X.shape[1:])
+print(f"Loaded {len(X)} valid samples | shape: {X.shape}")
+
+# Global normalization (very important!)
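+# NOTE: the scaler is fit on all samples before the train/test split, so
+# test-set statistics leak into training; fitting on X_train only would be
+# the stricter choice.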
+scaler = StandardScaler()
+pad_frames = np.abs(X).sum(axis=2) < 1e-6  # remember all-zero (padded) frames
+X_reshaped = X.reshape(-1, X.shape[-1])
+X_reshaped = scaler.fit_transform(X_reshaped)
+X = X_reshaped.reshape(X.shape).astype(np.float32)
+# Re-zero the padding: StandardScaler maps exact zeros to -mean/std, which
+# would otherwise break the zero-based length heuristic used during training
+X[pad_frames] = 0.0
 
 # ===============================
-# LABEL ENCODING
+# LABELS
 # ===============================
 le = LabelEncoder()
 y = le.fit_transform(y)
 num_classes = len(le.classes_)
-print("Num classes:", num_classes)
+print(f"Classes: {num_classes}")
 
 # ===============================
 # SPLIT
 # ===============================
 X_train, X_test, y_train, y_test = train_test_split(
-    X, y, test_size=0.2, stratify=y, random_state=42
+    X, y, test_size=0.15, stratify=y, random_state=42
 )
 
 # ===============================
-# DATASET
+# DATASET + DATALOADER
 # ===============================
 class ASLSequenceDataset(Dataset):
     def __init__(self, X, y):
-        self.X = torch.tensor(X, dtype=torch.float32)
-        self.y = torch.tensor(y, dtype=torch.long)
+        self.X = torch.from_numpy(X).float()
+        self.y = torch.from_numpy(y).long()
 
     def __len__(self):
         return len(self.X)
@@ -150,115 +154,158 @@ class ASLSequenceDataset(Dataset):
     def __getitem__(self, idx):
         return self.X[idx], self.y[idx]
 
-train_loader = DataLoader(ASLSequenceDataset(X_train, y_train), batch_size=64, shuffle=True, pin_memory=True)
-test_loader = DataLoader(ASLSequenceDataset(X_test, y_test), batch_size=64, shuffle=False, pin_memory=True)
+train_loader = DataLoader(ASLSequenceDataset(X_train, y_train),
+                          batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
+test_loader = DataLoader(ASLSequenceDataset(X_test, y_test),
+                         batch_size=96, shuffle=False, num_workers=4, pin_memory=True)
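+# NOTE: with num_workers > 0, spawn-based platforms (Windows/macOS) re-import
+# this module in each worker; since everything here runs at import time, that
+# would need an `if __name__ == "__main__":` guard. Fork-based Linux is fine.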
 
 # ===============================
-# TRANSFORMER MODEL
+# MODEL
 # ===============================
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model, max_len=100):
+    def __init__(self, d_model, max_len=128):
         super().__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
-        pe[:, 0::2] = torch.sin(position*div_term)
-        pe[:, 1::2] = torch.cos(position*div_term)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
         self.register_buffer('pe', pe.unsqueeze(0))
 
     def forward(self, x):
-        return x + self.pe[:, :x.size(1), :]
+        return x + self.pe[:, :x.size(1)]
 
 class TransformerASL(nn.Module):
-    def __init__(self, input_dim, num_classes, d_model=256, nhead=8, num_layers=4):
+    def __init__(self, input_dim=63, num_classes=250, d_model=192, nhead=6, num_layers=4):
         super().__init__()
         self.proj = nn.Linear(input_dim, d_model)
-        self.norm = nn.LayerNorm(d_model)
+        self.norm_in = nn.LayerNorm(d_model)
+
         self.pos = PositionalEncoding(d_model)
 
-        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=1024,
-                                                   dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=d_model*4,
+            dropout=0.15,
+            activation='gelu',
+            batch_first=True,
+            norm_first=True
+        )
         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
 
-        self.fc = nn.Sequential(
-            nn.Linear(d_model, 512),
-            nn.BatchNorm1d(512),
-            nn.GELU(),
-            nn.Dropout(0.3),
-            nn.Linear(512, num_classes)
+        self.head = nn.Sequential(
+            nn.LayerNorm(d_model),
+            nn.Dropout(0.25),
+            nn.Linear(d_model, num_classes)
         )
 
-    def forward(self, x):
+    def forward(self, x, key_padding_mask=None):
         x = self.proj(x)
-        x = self.norm(x)
+        x = self.norm_in(x)
         x = self.pos(x)
-        x = self.encoder(x)  # (B, T, d_model)
-        x = x.mean(dim=1)  # temporal average
-        x = self.fc(x)
+
+        x = self.encoder(x, src_key_padding_mask=key_padding_mask)
+        if key_padding_mask is not None:
+            # Average only over real (unpadded) positions; padded outputs
+            # would otherwise dilute the pooled representation
+            keep = (~key_padding_mask).unsqueeze(-1).float()
+            x = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
+        else:
+            x = x.mean(dim=1)  # global average pooling
+        x = self.head(x)
        return x
 
-model = TransformerASL(input_dim=X.shape[2], num_classes=num_classes).to(device)
-print("Parameters:", sum(p.numel() for p in model.parameters()))
+model = TransformerASL(input_dim=63, num_classes=num_classes).to(device)
+print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
 
 # ===============================
-# TRAIN SETUP
+# TRAINING SETUP
 # ===============================
-criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
-optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
-scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
+criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
+optimizer = optim.AdamW(model.parameters(), lr=8e-4, weight_decay=1e-4, betas=(0.9, 0.98))
+scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=2)
 
 # ===============================
-# TRAIN / EVAL FUNCTIONS
+# TRAIN / EVAL
 # ===============================
+def create_padding_mask(seq_len, max_len):
+    # True = ignore this position
+    return torch.arange(max_len, device=device)[None, :] >= seq_len[:, None]
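+# e.g. seq_len = tensor([3, 5]) with max_len = 5 gives
+#   [[False, False, False,  True,  True],
+#    [False, False, False, False, False]]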
norm={grad_norm:.2f}") + + return total_loss / len(train_loader), correct / total * 100 @torch.no_grad() def evaluate(): model.eval() - total, correct = 0, 0 + correct = 0 + total = 0 for x, y in test_loader: x, y = x.to(device), y.to(device) - logits = model(x) - correct += (logits.argmax(1) == y).sum().item() + real_lengths = (x.abs().sum(dim=2) > 1e-6).sum(dim=1) + mask = create_padding_mask(real_lengths, x.size(1)) + + logits = model(x, key_padding_mask=mask) + correct += (logits.argmax(dim=-1) == y).sum().item() total += y.size(0) - return 100*correct/total + return correct / total * 100 # =============================== -# TRAIN LOOP +# TRAINING LOOP # =============================== best_acc = 0 -patience = 15 +patience = 18 wait = 0 -epochs = 50 +epochs = 80 for epoch in range(epochs): loss, train_acc = train_epoch() test_acc = evaluate() + + print(f"[{epoch+1:2d}/{epochs}] loss: {loss:.4f} | train: {train_acc:.2f}% | test: {test_acc:.2f}%") + scheduler.step() - print(f"Epoch {epoch+1}/{epochs} | Loss {loss:.4f} | Train {train_acc:.2f}% | Test {test_acc:.2f}%") if test_acc > best_acc: best_acc = test_acc wait = 0 - torch.save({"model": model.state_dict(), "label_encoder": le}, "asl_transformer_full.pth") + torch.save({ + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': scaler, + 'label_encoder_classes': le.classes_ + }, "best_asl_transformer.pth") + print("→ Saved new best model") else: wait += 1 if wait >= patience: - print("Early stopping") + print("Early stopping triggered") break -print("Best accuracy:", best_acc) +print(f"\nBest test accuracy achieved: {best_acc:.2f}%") \ No newline at end of file