grok lock in

This commit is contained in:
2026-01-10 23:04:48 -06:00
parent c209e036cb
commit f66049f6ea

View File

@@ -4,23 +4,21 @@
import os import os
import json import json
import math import math
import time
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import StandardScaler
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
from functools import partial from functools import partial
from tqdm import tqdm
# =============================== # ===============================
# GPU SETUP # DEVICE
# =============================== # ===============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}") print(f"Using device: {device}")
@@ -29,7 +27,7 @@ if device.type == "cuda":
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
# =============================== # ===============================
# DATA LOADING & FEATURE EXTRACTION # DATA LOADING
# =============================== # ===============================
def load_kaggle_asl_data(base_path): def load_kaggle_asl_data(base_path):
train_df = pd.read_csv(os.path.join(base_path, "train.csv")) train_df = pd.read_csv(os.path.join(base_path, "train.csv"))
@@ -38,19 +36,12 @@ def load_kaggle_asl_data(base_path):
return train_df, sign_to_idx return train_df, sign_to_idx
def extract_hand_landmarks_from_parquet(path): def extract_hand_landmarks_from_parquet(path):
try:
df = pd.read_parquet(path) df = pd.read_parquet(path)
left = df[df["type"] == "left_hand"] hand = df[df["type"].isin(["left_hand", "right_hand"])]
right = df[df["type"] == "right_hand"] if len(hand) == 0:
hand = None
if len(left) > 0:
hand = left
elif len(right) > 0:
hand = right
else:
return None return None
# Keep all frames
frames = sorted(hand['frame'].unique()) frames = sorted(hand['frame'].unique())
landmarks_seq = [] landmarks_seq = []
@@ -62,87 +53,100 @@ def extract_hand_landmarks_from_parquet(path):
if len(lm) == 0: if len(lm) == 0:
lm_list.append([0.0, 0.0, 0.0]) lm_list.append([0.0, 0.0, 0.0])
else: else:
lm_list.append([ lm_list.append([lm['x'].values[0], lm['y'].values[0], lm['z'].values[0]])
lm['x'].mean(),
lm['y'].mean(),
lm['z'].mean()
])
landmarks_seq.append(lm_list) landmarks_seq.append(lm_list)
return np.array(landmarks_seq, dtype=np.float32) # (T, 21, 3) return np.array(landmarks_seq, dtype=np.float32) # (T, 21, 3)
except:
def get_features_sequence(landmarks_seq, max_frames=100):
if landmarks_seq is None:
return None return None
# Center on wrist
points = landmarks_seq - landmarks_seq[:, 0:1, :]
scale = np.linalg.norm(points[:, 9, :], axis=1, keepdims=True)
scale[scale < 1e-6] = 1.0
points /= scale[:, np.newaxis, :]
# Flatten per frame
frames = points.reshape(points.shape[0], -1)
# Pad or truncate
if frames.shape[0] < max_frames:
pad = np.zeros((max_frames - frames.shape[0], frames.shape[1]), dtype=np.float32)
frames = np.vstack([frames, pad])
else:
frames = frames[:max_frames]
return frames # (max_frames, 63)
def process_row(row, base_path, max_frames=100): def get_features_sequence(landmarks_seq, max_frames=96):
if landmarks_seq is None or len(landmarks_seq) == 0:
return None
# Center on wrist (landmark 0)
landmarks_seq = landmarks_seq - landmarks_seq[:, 0:1, :]
# Rough scale normalization (using index finger length as reference)
scale = np.linalg.norm(landmarks_seq[:, 8] - landmarks_seq[:, 5], axis=1, keepdims=True)
scale = np.maximum(scale, 1e-6)
landmarks_seq /= scale
# Flatten → (T, 63)
seq = landmarks_seq.reshape(landmarks_seq.shape[0], -1)
# Pad / truncate
if len(seq) < max_frames:
pad = np.zeros((max_frames - len(seq), seq.shape[1]), dtype=np.float32)
seq = np.concatenate([seq, pad], axis=0)
else:
seq = seq[:max_frames]
return seq.astype(np.float32)
def process_row(row, base_path, max_frames=96):
path = os.path.join(base_path, row['path']) path = os.path.join(base_path, row['path'])
if not os.path.exists(path): if not os.path.exists(path):
return None, None return None, None
try: lm = extract_hand_landmarks_from_parquet(path)
lm_seq = extract_hand_landmarks_from_parquet(path) feat = get_features_sequence(lm, max_frames)
feat_seq = get_features_sequence(lm_seq, max_frames) if feat is None:
return feat_seq, row['sign']
except:
return None, None return None, None
return feat, row['sign']
# =============================== # ===============================
# LOAD + PROCESS DATA # LOAD & PROCESS (with progress)
# =============================== # ===============================
base_path = "asl_kaggle" base_path = "asl_kaggle" # ← change if needed
train_df, sign_to_idx = load_kaggle_asl_data(base_path) train_df, sign_to_idx = load_kaggle_asl_data(base_path)
print("Processing videos...")
rows = [row for _, row in train_df.iterrows()] rows = [row for _, row in train_df.iterrows()]
X, y = [], []
func = partial(process_row, base_path=base_path, max_frames=100)
with Pool(cpu_count()) as pool: with Pool(cpu_count()) as pool:
for feat_seq, sign in pool.map(func, rows): results = list(tqdm(pool.imap(
if feat_seq is not None: partial(process_row, base_path=base_path, max_frames=96),
X.append(feat_seq) rows
), total=len(rows)))
X, y = [], []
for feat, sign in results:
if feat is not None:
X.append(feat)
y.append(sign) y.append(sign)
X = np.stack(X) # (N, T, 63) X = np.stack(X) # (N, T, 63)
y = np.array(y) print(f"Loaded {len(X)} valid samples | shape: {X.shape}")
print("Samples:", len(X))
print("Sequence shape:", X.shape[1:]) # Global normalization (very important!)
scaler = StandardScaler()
X_reshaped = X.reshape(-1, X.shape[-1])
X_reshaped = scaler.fit_transform(X_reshaped)
X = X_reshaped.reshape(X.shape)
# =============================== # ===============================
# LABEL ENCODING # LABELS
# =============================== # ===============================
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder() le = LabelEncoder()
y = le.fit_transform(y) y = le.fit_transform(y)
num_classes = len(le.classes_) num_classes = len(le.classes_)
print("Num classes:", num_classes) print(f"Classes: {num_classes}")
# =============================== # ===============================
# SPLIT # SPLIT
# =============================== # ===============================
X_train, X_test, y_train, y_test = train_test_split( X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=42 X, y, test_size=0.15, stratify=y, random_state=42
) )
# =============================== # ===============================
# DATASET # DATASET + DATALOADER
# =============================== # ===============================
class ASLSequenceDataset(Dataset): class ASLSequenceDataset(Dataset):
def __init__(self, X, y): def __init__(self, X, y):
self.X = torch.tensor(X, dtype=torch.float32) self.X = torch.from_numpy(X).float()
self.y = torch.tensor(y, dtype=torch.long) self.y = torch.from_numpy(y).long()
def __len__(self): def __len__(self):
return len(self.X) return len(self.X)
@@ -150,14 +154,16 @@ class ASLSequenceDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
return self.X[idx], self.y[idx] return self.X[idx], self.y[idx]
train_loader = DataLoader(ASLSequenceDataset(X_train, y_train), batch_size=64, shuffle=True, pin_memory=True) train_loader = DataLoader(ASLSequenceDataset(X_train, y_train),
test_loader = DataLoader(ASLSequenceDataset(X_test, y_test), batch_size=64, shuffle=False, pin_memory=True) batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(ASLSequenceDataset(X_test, y_test),
batch_size=96, shuffle=False, num_workers=4, pin_memory=True)
# =============================== # ===============================
# TRANSFORMER MODEL # MODEL
# =============================== # ===============================
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=100): def __init__(self, d_model, max_len=128):
super().__init__() super().__init__()
pe = torch.zeros(max_len, d_model) pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
@@ -167,98 +173,139 @@ class PositionalEncoding(nn.Module):
self.register_buffer('pe', pe.unsqueeze(0)) self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x): def forward(self, x):
return x + self.pe[:, :x.size(1), :] return x + self.pe[:, :x.size(1)]
class TransformerASL(nn.Module): class TransformerASL(nn.Module):
def __init__(self, input_dim, num_classes, d_model=256, nhead=8, num_layers=4): def __init__(self, input_dim=63, num_classes=250, d_model=192, nhead=6, num_layers=4):
super().__init__() super().__init__()
self.proj = nn.Linear(input_dim, d_model) self.proj = nn.Linear(input_dim, d_model)
self.norm = nn.LayerNorm(d_model) self.norm_in = nn.LayerNorm(d_model)
self.pos = PositionalEncoding(d_model) self.pos = PositionalEncoding(d_model)
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=1024, encoder_layer = nn.TransformerEncoderLayer(
dropout=0.1, activation='gelu', batch_first=True, norm_first=True) d_model=d_model,
nhead=nhead,
dim_feedforward=d_model*4,
dropout=0.15,
activation='gelu',
batch_first=True,
norm_first=True
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.fc = nn.Sequential( self.head = nn.Sequential(
nn.Linear(d_model, 512), nn.LayerNorm(d_model),
nn.BatchNorm1d(512), nn.Dropout(0.25),
nn.GELU(), nn.Linear(d_model, num_classes)
nn.Dropout(0.3),
nn.Linear(512, num_classes)
) )
def forward(self, x): def forward(self, x, key_padding_mask=None):
x = self.proj(x) x = self.proj(x)
x = self.norm(x) x = self.norm_in(x)
x = self.pos(x) x = self.pos(x)
x = self.encoder(x) # (B, T, d_model)
x = x.mean(dim=1) # temporal average x = self.encoder(x, src_key_padding_mask=key_padding_mask)
x = self.fc(x) x = x.mean(dim=1) # global average pooling
x = self.head(x)
return x return x
model = TransformerASL(input_dim=X.shape[2], num_classes=num_classes).to(device) model = TransformerASL(input_dim=63, num_classes=num_classes).to(device)
print("Parameters:", sum(p.numel() for p in model.parameters())) print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# =============================== # ===============================
# TRAIN SETUP # TRAINING SETUP
# =============================== # ===============================
criterion = nn.CrossEntropyLoss(label_smoothing=0.1) criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4) optimizer = optim.AdamW(model.parameters(), lr=8e-4, weight_decay=1e-4, betas=(0.9, 0.98))
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=2)
# =============================== # ===============================
# TRAIN / EVAL FUNCTIONS # TRAIN / EVAL
# =============================== # ===============================
def create_padding_mask(seq_len, max_len):
# True = ignore this position
return torch.arange(max_len, device=device)[None, :] >= seq_len[:, None]
def train_epoch(): def train_epoch():
model.train() model.train()
total, correct, loss_sum = 0, 0, 0 total_loss = 0
for x, y in train_loader: correct = 0
total = 0
for x, y in tqdm(train_loader, desc="Train"):
x, y = x.to(device), y.to(device) x, y = x.to(device), y.to(device)
# Very simple length heuristic (can be improved later)
real_lengths = (x.abs().sum(dim=2) > 1e-6).sum(dim=1)
mask = create_padding_mask(real_lengths, x.size(1))
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
logits = model(x) logits = model(x, key_padding_mask=mask)
loss = criterion(logits, y) loss = criterion(logits, y)
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# STRONG clipping — very important for landmarks
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.8)
optimizer.step() optimizer.step()
loss_sum += loss.item()
correct += (logits.argmax(1) == y).sum().item() total_loss += loss.item()
correct += (logits.argmax(dim=-1) == y).sum().item()
total += y.size(0) total += y.size(0)
return loss_sum/len(train_loader), 100*correct/total
# Debug exploding gradients
if torch.isnan(loss) or grad_norm > 50:
print(f"WARNING - NaN or huge grad! norm={grad_norm:.2f}")
return total_loss / len(train_loader), correct / total * 100
@torch.no_grad() @torch.no_grad()
def evaluate(): def evaluate():
model.eval() model.eval()
total, correct = 0, 0 correct = 0
total = 0
for x, y in test_loader: for x, y in test_loader:
x, y = x.to(device), y.to(device) x, y = x.to(device), y.to(device)
logits = model(x) real_lengths = (x.abs().sum(dim=2) > 1e-6).sum(dim=1)
correct += (logits.argmax(1) == y).sum().item() mask = create_padding_mask(real_lengths, x.size(1))
logits = model(x, key_padding_mask=mask)
correct += (logits.argmax(dim=-1) == y).sum().item()
total += y.size(0) total += y.size(0)
return 100*correct/total return correct / total * 100
# =============================== # ===============================
# TRAIN LOOP # TRAINING LOOP
# =============================== # ===============================
best_acc = 0 best_acc = 0
patience = 15 patience = 18
wait = 0 wait = 0
epochs = 50 epochs = 80
for epoch in range(epochs): for epoch in range(epochs):
loss, train_acc = train_epoch() loss, train_acc = train_epoch()
test_acc = evaluate() test_acc = evaluate()
print(f"[{epoch+1:2d}/{epochs}] loss: {loss:.4f} | train: {train_acc:.2f}% | test: {test_acc:.2f}%")
scheduler.step() scheduler.step()
print(f"Epoch {epoch+1}/{epochs} | Loss {loss:.4f} | Train {train_acc:.2f}% | Test {test_acc:.2f}%")
if test_acc > best_acc: if test_acc > best_acc:
best_acc = test_acc best_acc = test_acc
wait = 0 wait = 0
torch.save({"model": model.state_dict(), "label_encoder": le}, "asl_transformer_full.pth") torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scaler': scaler,
'label_encoder_classes': le.classes_
}, "best_asl_transformer.pth")
print("→ Saved new best model")
else: else:
wait += 1 wait += 1
if wait >= patience: if wait >= patience:
print("Early stopping") print("Early stopping triggered")
break break
print("Best accuracy:", best_acc) print(f"\nBest test accuracy achieved: {best_acc:.2f}%")