diff --git a/training.py b/training.py
index aa5c323..d1dfe04 100644
--- a/training.py
+++ b/training.py
@@ -14,6 +14,7 @@ from sklearn.preprocessing import LabelEncoder, StandardScaler
 from multiprocessing import Pool, cpu_count
 from functools import partial
 from tqdm import tqdm
+from collections import Counter
 
 
 def load_kaggle_asl_data(base_path):
@@ -26,7 +27,6 @@ def load_kaggle_asl_data(base_path):
 def extract_hand_landmarks_from_parquet(path):
     try:
         df = pd.read_parquet(path)
-        # Take either left or right hand - prefer the one with more landmarks
         left = df[df["type"] == "left_hand"]
         right = df[df["type"] == "right_hand"]
 
@@ -53,8 +53,8 @@ def extract_hand_landmarks_from_parquet(path):
                 ])
             landmarks_seq.append(lm_list)
 
-        return np.array(landmarks_seq, dtype=np.float32)  # (T, 21, 3)
+        return np.array(landmarks_seq, dtype=np.float32)
     except Exception:
         return None
@@ -63,20 +63,20 @@ def get_features_sequence(landmarks_seq, max_frames=100):
         return None
 
     # Center on wrist
-    landmarks_seq = landmarks_seq - landmarks_seq[:, 0:1, :]
+    landmarks_seq -= landmarks_seq[:, 0:1, :]
 
-    # Better scale: distance between index finger tip and middle finger tip
-    scale = np.linalg.norm(landmarks_seq[:, 8] - landmarks_seq[:, 12], axis=1, keepdims=True)
+    # Scale by the index-tip → middle-tip distance (more stable than a single point);
+    # shape (T, 1, 1) so the division broadcasts over all 21 landmarks
+    scale = np.linalg.norm(landmarks_seq[:, 8] - landmarks_seq[:, 12], axis=1)[:, None, None]
     scale = np.maximum(scale, 1e-6)
-    landmarks_seq = landmarks_seq / scale
+    landmarks_seq /= scale
 
-    # Flatten to (T, 63)
+    # Flatten
     seq = landmarks_seq.reshape(landmarks_seq.shape[0], -1)
 
-    # Pad or truncate
+    # Pad / truncate
     if len(seq) < max_frames:
         pad = np.zeros((max_frames - len(seq), seq.shape[1]), dtype=np.float32)
-        seq = np.concatenate([seq, pad], axis=0)
+        seq = np.concatenate([seq, pad])
     else:
         seq = seq[:max_frames]
 
@@ -84,21 +84,18 @@
 def process_row(row, base_path, max_frames=100):
-    path = os.path.join(base_path, row['path'])
+    path = os.path.join(base_path, row["path"])
     if not os.path.exists(path):
         return None, None
-    try:
-        lm_seq = extract_hand_landmarks_from_parquet(path)
-        if lm_seq is None:
-            return None, None
-
-        feat_seq = get_features_sequence(lm_seq, max_frames)
-        if feat_seq is None:
-            return None, None
-
-        return feat_seq, row['sign']
-    except Exception:
-        return None, None
+    lm = extract_hand_landmarks_from_parquet(path)
+    if lm is None:
+        return None, None
+    feat = get_features_sequence(lm, max_frames)
+    if feat is None:
+        return None, None
+    return feat, row["sign"]
@@ -123,7 +120,7 @@ class TransformerASL(nn.Module):
         self.norm_in = nn.LayerNorm(d_model)
         self.pos = PositionalEncoding(d_model)
 
-        encoder_layer = nn.TransformerEncoderLayer(
+        enc_layer = nn.TransformerEncoderLayer(
             d_model=d_model,
             nhead=nhead,
             dim_feedforward=d_model * 4,
@@ -132,7 +129,7 @@
             batch_first=True,
             norm_first=True
         )
-        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
 
         self.head = nn.Sequential(
             nn.LayerNorm(d_model),
@@ -155,7 +152,7 @@ def create_padding_mask(lengths, max_len):
 def main():
     # ===============================
-    # DEVICE SETUP
+    # DEVICE
     # ===============================
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
@@ -163,13 +160,14 @@
         print("GPU:", torch.cuda.get_device_name(0))
 
     # ===============================
-    # PATHS & PARAMETERS
+    # CONFIG
    # ===============================
-    base_path = "asl_kaggle"  # ← CHANGE THIS TO YOUR ACTUAL FOLDER
+    base_path = "asl_kaggle"  # ← CHANGE THIS TO YOUR ACTUAL PATH
     max_frames = 100
+    MIN_SAMPLES_PER_CLASS = 6  # ← important: prevents the stratified split from crashing on rare classes
 
     # ===============================
-    # DATA PROCESSING
+    # DATA LOADING & PROCESSING
     # ===============================
     print("Loading metadata...")
     train_df, sign_to_idx = load_kaggle_asl_data(base_path)
@@ -184,25 +182,25 @@
             rows
         ),
         total=len(rows),
-        desc="Processing"
+        desc="Extracting landmarks"
     ))
 
-    X, y = [], []
+    X_list, y_list = [], []
     for feat, sign in results:
         if feat is not None:
-            X.append(feat)
-            y.append(sign)
+            X_list.append(feat)
+            y_list.append(sign)
 
-    if not X:
-        print("No valid sequences found!")
+    if not X_list:
+        print("No valid sequences found. Check parquet files / paths.")
         return
 
-    X = np.stack(X)
-    print(f"Loaded {len(X)} valid samples | shape: {X.shape}")
+    X = np.stack(X_list)
+    print(f"Loaded {len(X)} valid sequences | shape: {X.shape}")
 
-    # Global normalization - very important!
+    # Global normalization (very important for stability)
     print("Before global norm → mean:", X.mean(), "std:", X.std())
-    X = np.clip(X, -5.0, 5.0)  # prevent crazy outliers
+    X = np.clip(X, -5.0, 5.0)  # clamp extreme outliers before computing the statistics
     mean = X.mean(axis=(0, 1), keepdims=True)
     std = X.std(axis=(0, 1), keepdims=True) + 1e-8
     X = (X - mean) / std
@@ -212,15 +210,30 @@
     # ===============================
     # LABELS
     # ===============================
     le = LabelEncoder()
-    y = le.fit_transform(y)
-    num_classes = len(le.classes_)
-    print(f"Number of classes: {num_classes}")
+    y = le.fit_transform(y_list)
+
+    # Drop classes with too few samples (otherwise the stratified split below can fail)
+    counts = Counter(y)
+    valid_classes = [cls for cls, cnt in counts.items() if cnt >= MIN_SAMPLES_PER_CLASS]
+
+    mask = np.isin(y, valid_classes)
+    X = X[mask]
+    kept_signs = [s for s, keep in zip(y_list, mask) if keep]
+
+    # Re-encode the surviving signs consecutively (0, 1, 2, ... with no gaps) so that
+    # le.classes_ still maps class indices back to sign names for the checkpoint.
+    le = LabelEncoder()
+    y = le.fit_transform(kept_signs)
+
+    print(f"After filtering: {len(X)} samples remain | {len(le.classes_)} classes")
 
     # ===============================
     # SPLIT
     # ===============================
     X_train, X_test, y_train, y_test = train_test_split(
-        X, y, test_size=0.15, stratify=y, random_state=42
+        X, y,
+        test_size=0.15,
+        stratify=y,  # safe now that rare classes are dropped
+        random_state=42
     )
@@ -258,7 +271,7 @@
     # ===============================
     model = TransformerASL(
         input_dim=63,
-        num_classes=num_classes,
+        num_classes=len(le.classes_),
         d_model=192,
         nhead=6,
         num_layers=4
@@ -274,18 +287,15 @@
     scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
 
     # ===============================
-    # TRAIN / EVAL FUNCTIONS
+    # TRAIN / EVAL
     # ===============================
     def train_epoch():
         model.train()
         total_loss = 0
-        correct = 0
-        total = 0
+        correct = total = 0
 
-        for x, y in tqdm(train_loader, desc="Training"):
+        for x, y in tqdm(train_loader, desc="Train"):
             x, y = x.to(device), y.to(device)
-
-            # Rough length estimation
             lengths = (x.abs().sum(dim=2) > 1e-5).sum(dim=1)
             mask = create_padding_mask(lengths, x.size(1))
@@ -297,13 +307,10 @@
             grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.8)
 
-            if torch.isnan(loss) or grad_norm > 20:
-                print(f"Warning - large grad or NaN! norm = {grad_norm:.2f}")
-
             optimizer.step()
 
             total_loss += loss.item()
-            correct += (logits.argmax(dim=-1) == y).sum().item()
+            correct += (logits.argmax(-1) == y).sum().item()
             total += y.size(0)
 
         return total_loss / len(train_loader), correct / total * 100
 
@@ -311,31 +318,30 @@
     @torch.no_grad()
     def evaluate():
         model.eval()
-        correct = 0
-        total = 0
+        correct = total = 0
         for x, y in test_loader:
             x, y = x.to(device), y.to(device)
             lengths = (x.abs().sum(dim=2) > 1e-5).sum(dim=1)
             mask = create_padding_mask(lengths, x.size(1))
             logits = model(x, key_padding_mask=mask)
-            correct += (logits.argmax(dim=-1) == y).sum().item()
+            correct += (logits.argmax(-1) == y).sum().item()
             total += y.size(0)
 
-        return correct / total * 100 if total > 0 else 0
+        return correct / total * 100 if total > 0 else 0.0
 
     # ===============================
     # TRAINING LOOP
     # ===============================
-    best_acc = 0
+    best_acc = 0.0
     patience = 15
     wait = 0
-    epochs = 60
+    epochs = 70
 
     for epoch in range(epochs):
         loss, train_acc = train_epoch()
         test_acc = evaluate()
-        print(f"Epoch {epoch + 1:2d}/{epochs} | Loss: {loss:.4f} | Train: {train_acc:.2f}% | Test: {test_acc:.2f}%")
+        print(f"[{epoch + 1:2d}/{epochs}] loss: {loss:.4f} | train: {train_acc:.2f}% | test: {test_acc:.2f}%")
 
         scheduler.step()
 
@@ -345,18 +351,18 @@
                 torch.save({
                     'model': model.state_dict(),
                     'optimizer': optimizer.state_dict(),
-                    'label_encoder': le.classes_,
-                    'epoch': epoch,
-                    'acc': best_acc
+                    'label_encoder_classes': le.classes_,
+                    'acc': best_acc,
+                    'epoch': epoch
                 }, "best_asl_transformer.pth")
-                print(" → New best model saved")
+                print(" → New best saved")
             else:
                 wait += 1
                 if wait >= patience:
-                    print("Early stopping triggered")
+                    print("Early stopping")
                     break
 
-    print(f"\nTraining finished. Best test accuracy: {best_acc:.2f}%")
+    print(f"\nBest test accuracy reached: {best_acc:.2f}%")
 
 
 if __name__ == '__main__':
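
Reviewer note, not part of the patch: a minimal inference sketch for the checkpoint saved above. It assumes training.py is importable as `training`, that the forward pass matches the `model(x, key_padding_mask=mask)` call used in the training loop, and that "example.parquet" is a placeholder input file; the global mean/std used for normalization are not stored in the checkpoint, so exact parity would also require saving and re-applying those statistics.

# inference_sketch.py — a rough sketch under the assumptions stated above.
import torch

from training import (
    TransformerASL,
    create_padding_mask,
    extract_hand_landmarks_from_parquet,
    get_features_sequence,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# weights_only=False (PyTorch >= 1.13) because the checkpoint stores a NumPy array of class names.
ckpt = torch.load("best_asl_transformer.pth", map_location=device, weights_only=False)
classes = ckpt["label_encoder_classes"]

model = TransformerASL(
    input_dim=63,
    num_classes=len(classes),
    d_model=192,
    nhead=6,
    num_layers=4,
).to(device)
model.load_state_dict(ckpt["model"])
model.eval()

# "example.parquet" is a placeholder; the training-time global mean/std are not re-applied here.
lm = extract_hand_landmarks_from_parquet("example.parquet")
feat = get_features_sequence(lm, max_frames=100) if lm is not None else None
if feat is None:
    raise SystemExit("Could not extract a usable landmark sequence from the input file.")

x = torch.from_numpy(feat).unsqueeze(0).to(device)   # (1, 100, 63)
lengths = (x.abs().sum(dim=2) > 1e-5).sum(dim=1)     # same length heuristic as in training
mask = create_padding_mask(lengths, x.size(1))

with torch.no_grad():
    logits = model(x, key_padding_mask=mask)
print("Predicted sign:", classes[logits.argmax(-1).item()])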