diff --git a/.gitignore b/.gitignore index ddebd95..f5e269e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ asl_kaggle/ hand_landmarker.task asl-dataset.zip -asl-signs.zip \ No newline at end of file +asl-signs.zip +best_asl_transformer.pth \ No newline at end of file diff --git a/test.py b/test.py index b759c50..b1067ab 100644 --- a/test.py +++ b/test.py @@ -1,170 +1,208 @@ -import mediapipe as mp import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import math +from collections import deque, Counter +import pandas as pd # ← added for rebuilding labels + +# Modern MediaPipe Tasks API (no legacy solutions module) +import mediapipe as mp +from mediapipe.tasks import python +from mediapipe.tasks.python import vision + +# PyTorch ≥ 2.6 checkpoint loading fix +import numpy as np +import numpy.core.multiarray +import numpy.dtypes + +torch.serialization.add_safe_globals([ + np.ndarray, + np.dtype, + np.dtypes.Int64DType, + np.core.multiarray._reconstruct +]) -# Positional Encoding +# =============================== +# MODEL DEFINITION +# =============================== class PositionalEncoding(nn.Module): - def __init__(self, d_model, max_len=100): - super(PositionalEncoding, self).__init__() - + def __init__(self, d_model, max_len=128): + super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) - - pe = pe.unsqueeze(0) - self.register_buffer('pe', pe) + self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x): - return x + self.pe[:, :x.size(1), :] + return x + self.pe[:, :x.size(1)] -# Model architecture -class TransformerCNN_ASL(nn.Module): - def __init__(self, input_dim=77, num_classes=250, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048): - super(TransformerCNN_ASL, self).__init__() +class TransformerASL(nn.Module): + def __init__(self, input_dim, num_classes, d_model=256, nhead=8, num_layers=4): + super().__init__() + self.proj = nn.Linear(input_dim, d_model) + self.norm_in = nn.LayerNorm(d_model) + self.pos = PositionalEncoding(d_model, max_len=128) - self.input_dim = input_dim - self.d_model = d_model - - # Input projection - self.input_projection = nn.Linear(input_dim, d_model) - self.input_norm = nn.LayerNorm(d_model) - - # Positional encoding - self.pos_encoder = PositionalEncoding(d_model, max_len=100) - - # Transformer Encoder with Self-Attention - encoder_layer = nn.TransformerEncoderLayer( + enc_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=0.1, + dim_feedforward=d_model * 4, + dropout=0.15, activation='gelu', batch_first=True, norm_first=True ) - self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers) - # CNN Blocks for pattern detection - self.conv1 = nn.Conv1d(d_model, 1024, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm1d(1024) - self.pool1 = nn.MaxPool1d(2) - self.dropout1 = nn.Dropout(0.3) + self.head = nn.Sequential( + nn.LayerNorm(d_model), + nn.Dropout(0.25), + nn.Linear(d_model, num_classes) + ) - self.conv2 = nn.Conv1d(1024, 2048, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm1d(2048) - self.pool2 = nn.MaxPool1d(2) - self.dropout2 = nn.Dropout(0.3) - - self.conv3 = nn.Conv1d(2048, 
4096, kernel_size=3, padding=1) - self.bn3 = nn.BatchNorm1d(4096) - self.pool3 = nn.AdaptiveMaxPool1d(1) # Global pooling - self.dropout3 = nn.Dropout(0.4) - - # Fully connected layers - self.fc1 = nn.Linear(4096, 4096) - self.bn_fc1 = nn.BatchNorm1d(4096) - self.dropout_fc1 = nn.Dropout(0.5) - - self.fc2 = nn.Linear(4096, 2048) - self.bn_fc2 = nn.BatchNorm1d(2048) - self.dropout_fc2 = nn.Dropout(0.4) - - self.fc3 = nn.Linear(2048, 1024) - self.bn_fc3 = nn.BatchNorm1d(1024) - self.dropout_fc3 = nn.Dropout(0.3) - - self.fc4 = nn.Linear(1024, num_classes) - - def forward(self, x): - batch_size = x.size(0) - - # Project to d_model - x = self.input_projection(x) - x = self.input_norm(x) - x = x.unsqueeze(1) - - # Add positional encoding - x = self.pos_encoder(x) - - # Transformer encoder with self-attention - x = self.transformer_encoder(x) - - # Reshape for CNN - x = x.permute(0, 2, 1) - - # CNN pattern detection - x = F.gelu(self.bn1(self.conv1(x))) - x = self.pool1(x) - x = self.dropout1(x) - - x = F.gelu(self.bn2(self.conv2(x))) - x = self.pool2(x) - x = self.dropout2(x) - - x = F.gelu(self.bn3(self.conv3(x))) - x = self.pool3(x) - x = self.dropout3(x) - - # Flatten - x = x.view(batch_size, -1) - - # Fully connected layers - x = F.gelu(self.bn_fc1(self.fc1(x))) - x = self.dropout_fc1(x) - - x = F.gelu(self.bn_fc2(self.fc2(x))) - x = self.dropout_fc2(x) - - x = F.gelu(self.bn_fc3(self.fc3(x))) - x = self.dropout_fc3(x) - - x = self.fc4(x) - - return x + def forward(self, x, key_padding_mask=None): + x = self.proj(x) + x = self.norm_in(x) + x = self.pos(x) + x = self.encoder(x, src_key_padding_mask=key_padding_mask) + x = x.mean(dim=1) + return self.head(x) -# Load the trained model -print("Loading model...") -checkpoint = torch.load('asl_kaggle_transformer.pth', map_location='cpu') -label_encoder = checkpoint['label_encoder'] -num_classes = checkpoint['num_classes'] -input_dim = checkpoint['input_dim'] -config = checkpoint['model_config'] +# =============================== +# FEATURE EXTRACTION +# =============================== +def get_features_sequence(landmarks_seq, max_frames=100): + if landmarks_seq is None or len(landmarks_seq) == 0: + return None, None -model = TransformerCNN_ASL( - input_dim=input_dim, - num_classes=num_classes, - d_model=config['d_model'], - nhead=config['nhead'], - num_layers=config['num_layers'], - dim_feedforward=config['dim_feedforward'] + wrist = landmarks_seq[:, 0:1, :] + landmarks_seq = landmarks_seq - wrist + + scale = np.linalg.norm(landmarks_seq[:, 9], axis=1, keepdims=True) + scale = np.maximum(scale, 1e-6) + landmarks_seq = landmarks_seq / scale[:, :, np.newaxis] + + landmarks_seq = np.nan_to_num(landmarks_seq, nan=0.0, posinf=0.0, neginf=0.0) + landmarks_seq = np.clip(landmarks_seq, -10, 10) + + tips = [4, 8, 12, 16, 20] + bases = [1, 5, 9, 13, 17] + curls = [np.linalg.norm(landmarks_seq[:, t] - landmarks_seq[:, b], axis=1) + for b, t in zip(bases, tips)] + curl_features = np.stack(curls, axis=1) + + deltas = np.zeros_like(landmarks_seq) + if len(landmarks_seq) > 1: + deltas[1:] = landmarks_seq[1:] - landmarks_seq[:-1] + + pos_flat = landmarks_seq.reshape(len(landmarks_seq), -1) + delta_flat = deltas.reshape(len(landmarks_seq), -1) + seq = np.concatenate([pos_flat, delta_flat, curl_features], axis=1) + + T, F = seq.shape + if T < max_frames: + pad = np.zeros((max_frames - T, F), dtype=np.float32) + seq_padded = np.concatenate([seq, pad], axis=0) + else: + seq_padded = seq[:max_frames] + + mask = np.zeros(max_frames, dtype=bool) + mask[:min(T, 
max_frames)] = True + + return seq_padded.astype(np.float32), mask + + +# =============================== +# MANUAL DRAWING FUNCTION +# =============================== +HAND_CONNECTIONS = [ + (0, 1), (1, 2), (2, 3), (3, 4), + (0, 5), (5, 6), (6, 7), (7, 8), + (0, 9), (9, 10), (10, 11), (11, 12), + (0, 13), (13, 14), (14, 15), (15, 16), + (0, 17), (17, 18), (18, 19), (19, 20), + (5, 9), (9, 13), (13, 17) +] + + +def draw_hand_landmarks(image, landmarks_list): + h, w = image.shape[:2] + + # Draw connections (blue lines) + for start_idx, end_idx in HAND_CONNECTIONS: + start = landmarks_list[start_idx] + end = landmarks_list[end_idx] + start_pt = (int(start.x * w), int(start.y * h)) + end_pt = (int(end.x * w), int(end.y * h)) + cv2.line(image, start_pt, end_pt, (255, 0, 0), 2) + + # Draw landmarks (green circles) + for lm in landmarks_list: + x = int(lm.x * w) + y = int(lm.y * h) + cv2.circle(image, (x, y), 5, (0, 255, 0), -1) + + +# =============================== +# MAIN PROGRAM +# =============================== +print("Loading trained model...") +checkpoint = torch.load("best_asl_transformer.pth", map_location="cpu") + +model = TransformerASL( + input_dim=checkpoint['input_dim'], + num_classes=checkpoint['num_classes'], + d_model=checkpoint['d_model'], + nhead=checkpoint['nhead'], + num_layers=checkpoint['num_layers'] ) -model.load_state_dict(checkpoint['model_state_dict']) + +model.load_state_dict(checkpoint['model']) model.eval() -total_params = sum(p.numel() for p in model.parameters()) -print(f"Loaded Transformer+CNN model") -print(f"Total parameters: {total_params:,}") -print(f"Number of ASL signs: {num_classes}") -print(f"Sample signs: {label_encoder.classes_[:10]}") +# ─── FIX: Rebuild real sign names from train.csv ───────────────────── +print("\n" + "=" * 70) +print("Rebuilding sign name mapping from train.csv...") -# Setup MediaPipe -BaseOptions = mp.tasks.BaseOptions -HandLandmarker = mp.tasks.vision.HandLandmarker -HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions -VisionRunningMode = mp.tasks.vision.RunningMode +try: + # CHANGE THIS PATH to where your train.csv actually is + train_df = pd.read_csv("asl_kaggle/train.csv") # ← most important line! + + # Get unique signs, sorted (same order LabelEncoder usually uses) + real_signs = sorted(train_df['sign'].unique()) + + # Use real sign names instead of numbers + label_encoder_classes = real_signs + + print("SUCCESS! 
Loaded real sign names") + print("Number of classes:", len(real_signs)) + print("First 15 signs:", real_signs[:15]) + print("=" * 70 + "\n") + +except Exception as e: + print("ERROR loading train.csv:", e) + print("Falling back to numeric labels (you'll see numbers instead of words)") + label_encoder_classes = checkpoint['label_encoder_classes'] + print("First 15 (still numbers):", label_encoder_classes[:15]) + print("=" * 70 + "\n") + +# MediaPipe Tasks setup +BaseOptions = python.BaseOptions +HandLandmarker = vision.HandLandmarker +HandLandmarkerOptions = vision.HandLandmarkerOptions +VisionRunningMode = vision.RunningMode + +MODEL_PATH = "hand_landmarker.task" # Make sure this file is in the folder options = HandLandmarkerOptions( - base_options=BaseOptions(model_asset_path='hand_landmarker.task'), + base_options=BaseOptions(model_asset_path=MODEL_PATH), running_mode=VisionRunningMode.VIDEO, num_hands=1, min_hand_detection_confidence=0.5, @@ -174,270 +212,118 @@ options = HandLandmarkerOptions( landmarker = HandLandmarker.create_from_options(options) +# Buffers +MAX_FRAMES = 100 +sequence_buffer = [] +prediction_buffer = deque(maxlen=15) -def get_optimized_features(hand_landmarks): - """ - Extract optimally normalized relative coordinates from MediaPipe landmarks - Returns 77 features - """ - # Extract raw coordinates - points = np.array([[lm.x, lm.y, lm.z] for lm in hand_landmarks], dtype=np.float32) - - # Step 1: Translation invariance - center on wrist - wrist = points[0].copy() - points_centered = points - wrist - - # Step 2: Scale invariance - normalize by palm size - palm_size = np.linalg.norm(points[9] - points[0]) # wrist to middle finger base - if palm_size < 1e-6: - palm_size = 1.0 - points_normalized = points_centered / palm_size - - # Step 3: Standardization - mean = np.mean(points_normalized, axis=0) - std = np.std(points_normalized, axis=0) + 1e-8 - points_standardized = (points_normalized - mean) / std - - # Flatten base features (63 features) - features = points_standardized.flatten() - - # Step 4: Derived features - finger_tips = [4, 8, 12, 16, 20] # Thumb, Index, Middle, Ring, Pinky - - # Distances between consecutive fingertips (4 distances) - tip_distances = [] - for i in range(len(finger_tips) - 1): - dist = np.linalg.norm(points_normalized[finger_tips[i]] - points_normalized[finger_tips[i + 1]]) - tip_distances.append(dist) - - # Distance of each fingertip from palm center (5 distances) - palm_center = np.mean(points_normalized[[0, 5, 9, 13, 17]], axis=0) - tip_to_palm = [] - for tip in finger_tips: - dist = np.linalg.norm(points_normalized[tip] - palm_center) - tip_to_palm.append(dist) - - # Finger curl indicators (5 curls) - finger_curls = [] - finger_bases = [1, 5, 9, 13, 17] - for base, tip in zip(finger_bases, finger_tips): - curl = np.linalg.norm(points_normalized[tip] - points_normalized[base]) - finger_curls.append(curl) - - # Combine all features: 63 + 4 + 5 + 5 = 77 - all_features = np.concatenate([ - features, - tip_distances, - tip_to_palm, - finger_curls - ]) - - return all_features.astype(np.float32) - - -# Initialize webcam cap = cv2.VideoCapture(0) - if not cap.isOpened(): - print("Error: Cannot open webcam") + print("Cannot open webcam") exit() -# Set camera resolution for better performance -cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) -cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) - -frame_count = 0 -fps_counter = 0 -fps_start_time = cv2.getTickCount() -current_fps = 0 - -# Prediction smoothing buffer -from collections import deque - -prediction_buffer = 
deque(maxlen=10)
-print("\n" + "=" * 60)
-print("ASL Recognition - Transformer+CNN Model")
-print("=" * 60)
-print("Controls:")
-print("  ESC - Exit")
-print("  SPACE - Clear prediction buffer")
-print("  'h' - Toggle hand landmarks visibility")
-print("=" * 60 + "\n")
+print("\nASL Recognition running - Press ESC to quit")
+print("Controls: ESC = quit | SPACE = clear | H = toggle landmarks\n")
 
 show_landmarks = True
+frame_timestamp_ms = 0
 
-with torch.no_grad():
-    while True:
-        success, image = cap.read()
-        if not success:
-            print("Failed to read frame from webcam")
-            break
+while cap.isOpened():
+    success, image = cap.read()
+    if not success:
+        break
 
-        # Flip image horizontally for mirror view
-        image = cv2.flip(image, 1)
+    image = cv2.flip(image, 1)
+    h, w = image.shape[:2]
 
-        # Convert to MediaPipe format
-        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
+    # mp.ImageFormat.SRGB expects RGB data, but OpenCV frames are BGR
+    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
+    frame_timestamp_ms += 33
 
-        # Detect hands
-        results = landmarker.detect_for_video(mp_image, frame_count)
-        frame_count += 1
+    results = landmarker.detect_for_video(mp_image, frame_timestamp_ms)
 
-        # Calculate FPS
-        fps_counter += 1
-        if fps_counter >= 30:
-            fps_end_time = cv2.getTickCount()
-            time_diff = (fps_end_time - fps_start_time) / cv2.getTickFrequency()
-            current_fps = fps_counter / time_diff
-            fps_counter = 0
-            fps_start_time = cv2.getTickCount()
+    overlay = image.copy()
+    cv2.rectangle(overlay, (10, 10), (520, 340), (0, 0, 0), -1)
+    cv2.addWeighted(overlay, 0.65, image, 0.35, 0, image)
 
-        # Process hand landmarks if detected
-        if results.hand_landmarks and len(results.hand_landmarks) > 0:
-            hand_landmarks = results.hand_landmarks[0]
+    if results.hand_landmarks:
+        hand_landmarks_list = results.hand_landmarks[0]
 
-            # Draw hand landmarks if enabled
-            if show_landmarks:
-                # Draw connections
-                connections = [
-                    (0, 1), (1, 2), (2, 3), (3, 4),  # Thumb
-                    (0, 5), (5, 6), (6, 7), (7, 8),  # Index
-                    (0, 9), (9, 10), (10, 11), (11, 12),  # Middle
-                    (0, 13), (13, 14), (14, 15), (15, 16),  # Ring
-                    (0, 17), (17, 18), (18, 19), (19, 20),  # Pinky
-                    (5, 9), (9, 13), (13, 17)  # Palm
-                ]
+        if show_landmarks:
+            draw_hand_landmarks(image, hand_landmarks_list)
 
-                # Get image dimensions
-                h, w = image.shape[:2]
+        current_frame = np.array(
+            [[lm.x, lm.y, lm.z] for lm in hand_landmarks_list],
+            dtype=np.float32
+        )
 
-                # Draw connections
-                for connection in connections:
-                    start_idx, end_idx = connection
-                    start = hand_landmarks[start_idx]
-                    end = hand_landmarks[end_idx]
+        sequence_buffer.append(current_frame)
+        if len(sequence_buffer) > MAX_FRAMES:
+            sequence_buffer = sequence_buffer[-MAX_FRAMES:]
 
-                    start_point = (int(start.x * w), int(start.y * h))
-                    end_point = (int(end.x * w), int(end.y * h))
+        if len(sequence_buffer) >= 10:
+            seq_np = np.array(sequence_buffer)
+            feats, mask = get_features_sequence(seq_np, MAX_FRAMES)
 
-                    cv2.line(image, start_point, end_point, (0, 255, 0), 2)
+            if feats is not None:
+                x = torch.from_numpy(feats).float().unsqueeze(0)
+                key_padding_mask = torch.from_numpy(~mask).unsqueeze(0)
 
-                # Draw landmarks
-                for i, landmark in enumerate(hand_landmarks):
-                    x = int(landmark.x * w)
-                    y = int(landmark.y * h)
+                with torch.no_grad():
+                    logits = model(x, key_padding_mask=key_padding_mask)
+                    probs = F.softmax(logits, dim=-1)[0]
+                    pred_idx = torch.argmax(probs).item()
+                    conf = probs[pred_idx].item()
 
-                    # Different colors for different parts
-                    if i == 0:  # Wrist
-                        color = (255, 0, 0)
-                        radius = 8
-                    elif i in [4, 8, 12, 16, 20]:  # Fingertips
-                        color = (0, 0, 255)
-
radius = 6 - else: - color = (0, 255, 0) - radius = 4 + # Now using real sign names! + sign = label_encoder_classes[pred_idx] - cv2.circle(image, (x, y), radius, color, -1) - cv2.circle(image, (x, y), radius + 2, (255, 255, 255), 1) + if conf > 0.40: + prediction_buffer.append(sign) - # Extract features - features = get_optimized_features(hand_landmarks) + final_sign = sign + final_conf = conf + if len(prediction_buffer) >= 6: + final_sign = Counter(prediction_buffer).most_common(1)[0][0] + try: + final_conf = probs[label_encoder_classes.index(final_sign)].item() + except: + pass - # Make prediction - input_tensor = torch.FloatTensor(features).unsqueeze(0) - output = model(input_tensor) - probabilities = torch.softmax(output, dim=1)[0] + color = (0, 255, 100) if final_conf > 0.75 else (0, 220, 220) + cv2.putText(image, f"Sign: {final_sign}", (25, 60), + cv2.FONT_HERSHEY_SIMPLEX, 1.8, color, 4) + cv2.putText(image, f"Conf: {final_conf:.1%}", (25, 110), + cv2.FONT_HERSHEY_SIMPLEX, 0.9, (220, 220, 220), 2) - # Get top prediction - predicted_idx = torch.argmax(probabilities).item() - confidence = probabilities[predicted_idx].item() - predicted_sign = label_encoder.inverse_transform([predicted_idx])[0] + top3_p, top3_i = torch.topk(probs, 3) + for i, (p, idx) in enumerate(zip(top3_p, top3_i)): + s = label_encoder_classes[idx.item()] + cv2.putText(image, f"{i + 1}. {s:<18} {p:.1%}", + (25, 155 + i * 40), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (200, 200, 200), 2) - # Add to buffer for smoothing - if confidence > 0.3: # Only add if confident enough - prediction_buffer.append(predicted_sign) + else: + if len(sequence_buffer) < 25: + sequence_buffer.clear() + cv2.putText(image, "No hand detected", (25, 60), + cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3) - # Get smoothed prediction (most common in buffer) - if len(prediction_buffer) >= 5: - from collections import Counter + cv2.putText(image, "ESC:quit SPACE:clear H:landmarks", + (w - 480, h - 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (180, 180, 180), 1) - smoothed_sign = Counter(prediction_buffer).most_common(1)[0][0] - else: - smoothed_sign = predicted_sign + cv2.imshow("ASL Recognition", image) - # Get top 5 predictions - top5_prob, top5_idx = torch.topk(probabilities, min(5, num_classes)) + key = cv2.waitKey(1) & 0xFF + if key == 27: + break + elif key == 32: + sequence_buffer.clear() + prediction_buffer.clear() + print("Buffers cleared") + elif key in (ord('h'), ord('H')): + show_landmarks = not show_landmarks + print(f"Landmarks display: {'ON' if show_landmarks else 'OFF'}") - # Display prediction area (dark semi-transparent overlay) - overlay = image.copy() - cv2.rectangle(overlay, (10, 10), (500, 280), (0, 0, 0), -1) - cv2.addWeighted(overlay, 0.7, image, 0.3, 0, image) - - # Display main prediction - cv2.putText(image, f"Sign: {smoothed_sign}", - (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 3) - cv2.putText(image, f"Confidence: {confidence:.1%}", - (20, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2) - - # Display top 5 predictions - cv2.putText(image, "Top 5:", - (20, 130), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - - y_offset = 160 - for i, (prob, idx) in enumerate(zip(top5_prob, top5_idx)): - sign = label_encoder.inverse_transform([idx.item()])[0] - prob_val = prob.item() - - # Color code by confidence - if i == 0: - color = (0, 255, 0) # Green for top - elif prob_val > 0.1: - color = (0, 255, 255) # Yellow for decent confidence - else: - color = (128, 128, 128) # Gray for low confidence - - cv2.putText(image, f"{i + 1}. 
{sign}: {prob_val:.1%}", - (30, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2) - y_offset += 30 - else: - # No hand detected - cv2.putText(image, "No hand detected", - (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2) - prediction_buffer.clear() - - # Display FPS and info - info_y = image.shape[0] - 60 - cv2.putText(image, f"FPS: {current_fps:.1f}", - (20, info_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) - cv2.putText(image, f"Frame: {frame_count}", - (20, info_y + 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) - - # Display controls at bottom right - controls_text = "ESC: Exit | SPACE: Clear | H: Landmarks" - text_size = cv2.getTextSize(controls_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] - cv2.putText(image, controls_text, - (image.shape[1] - text_size[0] - 10, image.shape[0] - 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1) - - # Show the image - cv2.imshow('ASL Recognition - Transformer+CNN', image) - - # Handle key presses - key = cv2.waitKey(1) & 0xFF - - if key == 27: # ESC - print("Exiting...") - break - elif key == 32: # SPACE - prediction_buffer.clear() - print("Prediction buffer cleared") - elif key == ord('h') or key == ord('H'): - show_landmarks = not show_landmarks - print(f"Hand landmarks: {'ON' if show_landmarks else 'OFF'}") - -# Cleanup cap.release() cv2.destroyAllWindows() +landmarker.close() print("Recognition stopped.") \ No newline at end of file diff --git a/training.py b/training.py index 6e2607c..4d156dd 100644 --- a/training.py +++ b/training.py @@ -35,188 +35,222 @@ else: print("=" * 60) +# =============================== +# SELECTED LANDMARK INDICES +# =============================== +IMPORTANT_FACE_INDICES = sorted(list(set([ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 55, 65, 66, 105, 107, 336, 296, 334, + 33, 133, 160, 159, 158, 144, 145, 153, + 362, 263, 387, 386, 385, 373, 374, 380, + 1, 2, 98, 327, + 61, 185, 40, 39, 37, 0, 267, 269, 270, 409, + 291, 146, 91, 181, 84, 17, 314, 405, 321, 375, + 78, 191, 80, 81, 82, 13, 312, 311, 310, 415, + 308, 324, 318, 402, 317, 14, 87, 178, 88, 95 +]))) + +NUM_FACE_POINTS = len(IMPORTANT_FACE_INDICES) +NUM_HAND_POINTS = 21 * 2 +TOTAL_POINTS_PER_FRAME = NUM_HAND_POINTS + NUM_FACE_POINTS + # =============================== -# DATA LOADING - HANDLES PARTIAL NaN +# ENHANCED DATA EXTRACTION # =============================== -def load_kaggle_asl_data(base_path): - train_df = pd.read_csv(os.path.join(base_path, "train.csv")) - with open(os.path.join(base_path, "sign_to_prediction_index_map.json")) as f: - sign_to_idx = json.load(f) - return train_df, sign_to_idx - - -def extract_hand_landmarks_from_parquet(path): - """Extract hand landmarks - ONLY uses frames with valid (non-NaN) data""" +def extract_multi_landmarks(path, min_valid_frames=5): + """ + Extract both hands + selected face landmarks with modality flags + Returns: dict with 'landmarks', 'left_hand_valid', 'right_hand_valid', 'face_valid' + """ try: df = pd.read_parquet(path) + seq = [] + left_valid_frames = [] + right_valid_frames = [] + face_valid_frames = [] - # Get hand data - left = df[df["type"] == "left_hand"] - right = df[df["type"] == "right_hand"] - - if len(left) == 0 and len(right) == 0: + all_types = df["type"].unique() + if "left_hand" in all_types or "right_hand" in all_types or "face" in all_types: + frames = sorted(df["frame"].unique()) + else: return None - # Count valid (non-NaN) rows for each hand - left_valid = 0 - right_valid = 0 - - if len(left) > 0: - left_valid = left[['x', 'y', 
'z']].notna().all(axis=1).sum() - if len(right) > 0: - right_valid = right[['x', 'y', 'z']].notna().all(axis=1).sum() - - # No valid data at all - if left_valid == 0 and right_valid == 0: + if frames is None or len(frames) < min_valid_frames: return None - # Choose hand with more valid data - hand = left if left_valid >= right_valid else right - - # Get unique frames - frames = sorted(hand['frame'].unique()) - landmarks_seq = [] - for frame in frames: - lm_frame = hand[hand['frame'] == frame] + frame_df = df[df["frame"] == frame] + frame_points = np.full((TOTAL_POINTS_PER_FRAME, 3), np.nan, dtype=np.float32) - # Count how many valid landmarks this frame has - valid_count = lm_frame[['x', 'y', 'z']].notna().all(axis=1).sum() + pos = 0 + left_valid = False + right_valid = False + face_valid = False - # Skip frames with too few valid landmarks - if valid_count < 10: - continue + # Left hand + left = frame_df[frame_df["type"] == "left_hand"] + if len(left) >= 15: + valid_count = 0 + for i in range(21): + row = left[left["landmark_index"] == i] + if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all(): + frame_points[pos] = row[['x', 'y', 'z']].values[0] + valid_count += 1 + pos += 1 + left_valid = (valid_count >= 15) + else: + pos += 21 - # Extract landmarks for this frame - frame_landmarks = [] - valid_landmarks_in_frame = 0 + # Right hand + right = frame_df[frame_df["type"] == "right_hand"] + if len(right) >= 15: + valid_count = 0 + for i in range(21): + row = right[right["landmark_index"] == i] + if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all(): + frame_points[pos] = row[['x', 'y', 'z']].values[0] + valid_count += 1 + pos += 1 + right_valid = (valid_count >= 15) + else: + pos += 21 - for i in range(21): - lm = lm_frame[lm_frame['landmark_index'] == i] + # Face + face = frame_df[frame_df["type"] == "face"] + if len(face) > 0: + valid_count = 0 + for idx in IMPORTANT_FACE_INDICES: + row = face[face["landmark_index"] == idx] + if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all(): + frame_points[pos] = row[['x', 'y', 'z']].values[0] + valid_count += 1 + pos += 1 + face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.5) - if len(lm) == 0: - frame_landmarks.append([0.0, 0.0, 0.0]) - else: - x = float(lm['x'].iloc[0]) - y = float(lm['y'].iloc[0]) - z = float(lm['z'].iloc[0]) + valid_ratio = 1 - np.isnan(frame_points).mean() + if valid_ratio >= 0.40: + frame_points = np.nan_to_num(frame_points, nan=0.0) + seq.append(frame_points) + left_valid_frames.append(left_valid) + right_valid_frames.append(right_valid) + face_valid_frames.append(face_valid) - # Check if valid - if pd.notna(x) and pd.notna(y) and pd.notna(z): - frame_landmarks.append([x, y, z]) - valid_landmarks_in_frame += 1 - else: - frame_landmarks.append([0.0, 0.0, 0.0]) - - # Only add frame if it has enough valid landmarks - if valid_landmarks_in_frame >= 10: - landmarks_seq.append(frame_landmarks) - - # Need at least 3 valid frames - if len(landmarks_seq) < 3: + if len(seq) < min_valid_frames: return None - return np.array(landmarks_seq, dtype=np.float32) + return { + 'landmarks': np.stack(seq), + 'left_hand_valid': np.array(left_valid_frames), + 'right_hand_valid': np.array(right_valid_frames), + 'face_valid': np.array(face_valid_frames) + } - except Exception as e: + except Exception: return None -def get_features_sequence(landmarks_seq, max_frames=100): - """Extract features from landmark sequence""" - if landmarks_seq is None or len(landmarks_seq) == 0: - return None, None - - # Center on wrist 
(landmark 0) - wrist = landmarks_seq[:, 0:1, :].copy() - landmarks_seq = landmarks_seq - wrist - - # Scale normalization using wrist to middle finger MCP (landmark 9) - scale = np.linalg.norm(landmarks_seq[:, 9, :] - np.zeros(3), axis=1, keepdims=True) - scale = np.maximum(scale, 1e-6) # Avoid division by zero - landmarks_seq = landmarks_seq / scale[:, np.newaxis, :] - - # Clean up any remaining NaN/Inf - landmarks_seq = np.nan_to_num(landmarks_seq, nan=0.0, posinf=0.0, neginf=0.0) - - # Clip extreme values - landmarks_seq = np.clip(landmarks_seq, -10, 10) - - # Calculate finger curl features - tips = [4, 8, 12, 16, 20] # Thumb, index, middle, ring, pinky tips - bases = [1, 5, 9, 13, 17] # Corresponding base joints - - curl_features = [] - for b, t in zip(bases, tips): - curl = np.linalg.norm(landmarks_seq[:, t] - landmarks_seq[:, b], axis=1) - curl_features.append(curl) - curl_features = np.stack(curl_features, axis=1) # (T, 5) - - # Temporal deltas (motion) - deltas = np.zeros_like(landmarks_seq) - if len(landmarks_seq) > 1: - deltas[1:] = landmarks_seq[1:] - landmarks_seq[:-1] - - # Flatten each component separately, then concatenate - landmarks_flat = landmarks_seq.reshape(landmarks_seq.shape[0], -1) # (T, 63) - deltas_flat = deltas.reshape(deltas.shape[0], -1) # (T, 63) - # curl_features is already (T, 5) - - # Combine: 63 + 63 + 5 = 131 features per frame - seq = np.concatenate([ - landmarks_flat, - deltas_flat, - curl_features - ], axis=1) - - # Pad or truncate to max_frames - T, F = seq.shape - if T < max_frames: - # Pad with zeros - pad = np.zeros((max_frames - T, F), dtype=np.float32) - seq = np.concatenate([seq, pad], axis=0) - elif T > max_frames: - # Truncate - seq = seq[:max_frames, :] - - # Create attention mask (True for valid positions) - valid_mask = np.zeros(max_frames, dtype=bool) - valid_mask[:min(T, max_frames)] = True - - return seq.astype(np.float32), valid_mask - - -def process_row(row, base_path, max_frames=100): - """Process a single row - worker function for multiprocessing""" - path = os.path.join(base_path, row["path"]) - - if not os.path.exists(path): +def get_features_sequence(landmarks_data, max_frames=100): + """ + Enhanced feature extraction with separate modality processing + """ + if landmarks_data is None: return None, None, None + landmarks_3d = landmarks_data['landmarks'] + if len(landmarks_3d) == 0: + return None, None, None + + T, N, _ = landmarks_3d.shape + + # Separate modalities for independent normalization + left_hand = landmarks_3d[:, :21, :] + right_hand = landmarks_3d[:, 21:42, :] + face = landmarks_3d[:, 42:, :] + + # Independent centering per modality + features_list = [] + + for modality, valid_mask in [ + (left_hand, landmarks_data['left_hand_valid']), + (right_hand, landmarks_data['right_hand_valid']), + (face, landmarks_data['face_valid']) + ]: + # Center on modality-specific mean + valid_frames = modality[valid_mask] if valid_mask.any() else modality + if len(valid_frames) > 0: + center = np.mean(valid_frames, axis=(0, 1), keepdims=True) + spread = np.std(valid_frames, axis=(0, 1), keepdims=True).max() + else: + center = 0 + spread = 1 + + modality_norm = (modality - center) / max(spread, 1e-6) + flat = modality_norm.reshape(T, -1) + + # Deltas + deltas = np.zeros_like(flat) + if T > 1: + deltas[1:] = flat[1:] - flat[:-1] + + features_list.append(flat) + features_list.append(deltas) + + # Combine all modalities + features = np.concatenate(features_list, axis=1) + + # Create modality availability mask (which frames have which 
modalities) + modality_mask = np.stack([ + landmarks_data['left_hand_valid'], + landmarks_data['right_hand_valid'], + landmarks_data['face_valid'] + ], axis=1).astype(np.float32) # (T, 3) + + # Pad/truncate + if T < max_frames: + pad = np.zeros((max_frames - T, features.shape[1]), dtype=np.float32) + features = np.concatenate([features, pad], axis=0) + + mask_pad = np.zeros((max_frames - T, 3), dtype=np.float32) + modality_mask = np.concatenate([modality_mask, mask_pad], axis=0) + + frame_mask = np.zeros(max_frames, dtype=bool) + frame_mask[:T] = True + else: + features = features[:max_frames] + modality_mask = modality_mask[:max_frames] + frame_mask = np.ones(max_frames, dtype=bool) + + features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0) + features = np.clip(features, -30, 30) + + return features.astype(np.float32), frame_mask, modality_mask + + +def process_row(row_data, base_path, max_frames=100): + """Process a single row - expects tuple of (path, sign)""" + path_rel, sign = row_data + path = os.path.join(base_path, path_rel) + if not os.path.exists(path): + return None, None, None, None + try: - # Extract landmarks - lm = extract_hand_landmarks_from_parquet(path) - if lm is None: - return None, None, None + lm_data = extract_multi_landmarks(path) + if lm_data is None: + return None, None, None, None - # Get features - feat, mask = get_features_sequence(lm, max_frames) + feat, frame_mask, modality_mask = get_features_sequence(lm_data, max_frames) if feat is None: - return None, None, None + return None, None, None, None - # Final safety check - if np.isnan(feat).any() or np.isinf(feat).any(): - return None, None, None + return feat, frame_mask, modality_mask, sign - return feat, mask, row["sign"] - - except Exception as e: - return None, None, None + except Exception: + return None, None, None, None # =============================== -# TRANSFORMER MODEL +# ENHANCED MODEL WITH MODALITY AWARENESS # =============================== class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=128): @@ -232,16 +266,19 @@ class PositionalEncoding(nn.Module): return x + self.pe[:, :x.size(1)] -class TransformerASL(nn.Module): - def __init__(self, input_dim, num_classes, d_model=256, nhead=8, num_layers=4): +class ModalityAwareTransformer(nn.Module): + def __init__(self, input_dim, num_classes, d_model=384, nhead=8, num_layers=5): super().__init__() - # Input projection + # Main projection self.proj = nn.Linear(input_dim, d_model) - self.norm_in = nn.LayerNorm(d_model) - self.pos = PositionalEncoding(d_model, max_len=128) - # Transformer encoder + # Modality embedding (3 modalities: left_hand, right_hand, face) + self.modality_embed = nn.Linear(3, d_model) + + self.norm_in = nn.LayerNorm(d_model) + self.pos = PositionalEncoding(d_model) + enc_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, @@ -253,30 +290,45 @@ class TransformerASL(nn.Module): ) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers) - # Classification head self.head = nn.Sequential( nn.LayerNorm(d_model), nn.Dropout(0.25), nn.Linear(d_model, num_classes) ) - def forward(self, x, key_padding_mask=None): - # x: (batch, seq_len, input_dim) - # key_padding_mask: (batch, seq_len) - True for padding positions - + def forward(self, x, modality_mask=None, key_padding_mask=None): + # Project features x = self.proj(x) + + # Add modality information if available + if modality_mask is not None: + mod_embed = self.modality_embed(modality_mask) + x = x + mod_embed + x = 
self.norm_in(x) x = self.pos(x) x = self.encoder(x, src_key_padding_mask=key_padding_mask) - # Global average pooling over valid positions - x = x.mean(dim=1) + # Weighted average (giving more weight to frames with more valid modalities) + if modality_mask is not None: + weights = modality_mask.sum(dim=-1, keepdim=True) + 1e-6 # (B, T, 1) + weights = weights / weights.sum(dim=1, keepdim=True) + x = (x * weights).sum(dim=1) + else: + x = x.mean(dim=1) return self.head(x) +def load_kaggle_asl_data(base_path): + """Load training metadata""" + train_path = os.path.join(base_path, "train.csv") + train_df = pd.read_csv(train_path) + return train_df, None + + # =============================== -# MAIN TRAINING +# MAIN # =============================== def main(): base_path = "asl_kaggle" @@ -284,163 +336,142 @@ def main(): MIN_SAMPLES_PER_CLASS = 5 print("\nLoading metadata...") - train_df, sign_to_idx = load_kaggle_asl_data(base_path) - print(f"Total sequences: {len(train_df)}") + train_df, _ = load_kaggle_asl_data(base_path) - rows = [row for _, row in train_df.iterrows()] + # Convert to simple tuples for multiprocessing compatibility + rows = [(row['path'], row['sign']) for _, row in train_df.iterrows()] - print("\nProcessing sequences (this will take a few minutes)...") - print("Expected: ~36,000 valid sequences based on diagnostic") - - # Process with multiprocessing + print("\nProcessing sequences with BOTH hands + FACE (enhanced)...") with Pool(cpu_count()) as pool: results = list(tqdm( pool.imap( partial(process_row, base_path=base_path, max_frames=max_frames), rows, - chunksize=100 + chunksize=80 ), total=len(rows), - desc="Extracting landmarks" + desc="Landmarks extraction" )) - # Filter valid results - X_list, masks_list, y_list = [], [], [] - for feat, mask, sign in results: - if feat is not None and mask is not None and sign is not None: - if feat.shape[0] == max_frames: - X_list.append(feat) - masks_list.append(mask) - y_list.append(sign) + X_list, frame_masks_list, modality_masks_list, y_list = [], [], [], [] + for feat, frame_mask, modality_mask, sign in results: + if feat is not None and frame_mask is not None: + X_list.append(feat) + frame_masks_list.append(frame_mask) + modality_masks_list.append(modality_mask) + y_list.append(sign) - print(f"\n✓ Successfully extracted: {len(X_list)} valid sequences") - print(f" Success rate: {len(X_list) / len(train_df) * 100:.1f}%") - - if len(X_list) < 100: - print("❌ Too few valid sequences found!") - print(" This shouldn't happen - please share this output for debugging") + if not X_list: + print("No valid sequences extracted!") return - # Stack into arrays X = np.stack(X_list) - masks = np.stack(masks_list) + frame_masks = np.stack(frame_masks_list) + modality_masks = np.stack(modality_masks_list) - print(f"\nData shape: {X.shape}") - print(f"Feature dimension: {X.shape[2]}") + print(f"\nExtracted {len(X):,} sequences") + print(f"Feature shape: {X.shape[1:]} (input_dim = {X.shape[2]})") + print(f"Modality mask shape: {modality_masks.shape}") # Global normalization - print("Normalizing features...") - X = np.clip(X, -10.0, 10.0) + X = np.clip(X, -30, 30) mean = X.mean(axis=(0, 1), keepdims=True) std = X.std(axis=(0, 1), keepdims=True) + 1e-8 X = (X - mean) / std - # Encode labels + # Labels le = LabelEncoder() y = le.fit_transform(y_list) - # Filter classes with too few samples + # Filter rare classes counts = Counter(y) - valid_classes = [cls for cls, cnt in counts.items() if cnt >= MIN_SAMPLES_PER_CLASS] - mask_valid = np.isin(y, 
valid_classes)
+    valid = [k for k, v in counts.items() if v >= MIN_SAMPLES_PER_CLASS]
+    mask = np.isin(y, valid)
-
-    X = X[mask_valid]
-    masks = masks[mask_valid]
-    y = y[mask_valid]
+    X = X[mask]
+    frame_masks = frame_masks[mask]
+    modality_masks = modality_masks[mask]
+    y = y[mask]
 
-    # Re-encode after filtering
-    le = LabelEncoder()
-    y = le.fit_transform(y)
+    # Re-encode after filtering, fitting on the original sign names so that
+    # le.classes_ keeps the sign strings instead of integer codes
+    le = LabelEncoder()
+    y = le.fit_transform(np.array(y_list)[mask])
 
-    print(f"\nFinal dataset after filtering:")
-    print(f"  Samples: {len(X):,}")
-    print(f"  Classes: {len(le.classes_)}")
-    print(f"  Sign examples: {list(le.classes_[:10])}")
+    print(f"After filtering: {len(X):,} samples | {len(le.classes_)} classes")
 
-    # Train-test split
-    X_train, X_test, masks_train, masks_test, y_train, y_test = train_test_split(
-        X, masks, y, test_size=0.15, stratify=y, random_state=42
+    # Analyze modality usage
+    print("\nModality statistics:")
+    print(f"  Sequences with left hand: {(modality_masks[:, :, 0].sum(axis=1) > 0).mean() * 100:.1f}%")
+    print(f"  Sequences with right hand: {(modality_masks[:, :, 1].sum(axis=1) > 0).mean() * 100:.1f}%")
+    print(f"  Sequences with face: {(modality_masks[:, :, 2].sum(axis=1) > 0).mean() * 100:.1f}%")
+
+    # Split
+    X_tr, X_te, fm_tr, fm_te, mm_tr, mm_te, y_tr, y_te = train_test_split(
+        X, frame_masks, modality_masks, y, test_size=0.15, stratify=y, random_state=42
     )
 
-    print(f"\nTrain set: {len(X_train):,} samples")
-    print(f"Test set: {len(X_test):,} samples")
-
-    # Dataset wrapper
-    class ASLSequenceDataset(Dataset):
-        def __init__(self, X, masks, y):
+    # Dataset
+    class ASLMultiDataset(Dataset):
+        def __init__(self, X, frame_masks, modality_masks, y):
             self.X = torch.from_numpy(X).float()
-            self.masks = torch.from_numpy(masks)
+            self.frame_masks = torch.from_numpy(frame_masks).bool()
+            self.modality_masks = torch.from_numpy(modality_masks).float()
             self.y = torch.from_numpy(y).long()
 
         def __len__(self):
             return len(self.X)
 
         def __getitem__(self, idx):
-            return self.X[idx], self.masks[idx], self.y[idx]
+            return self.X[idx], self.frame_masks[idx], self.modality_masks[idx], self.y[idx]
 
-    # DataLoaders
-    batch_size = 128 if device.type == 'cuda' else 64
+    batch_size = 64 if device.type == 'cuda' else 32
     train_loader = DataLoader(
-        ASLSequenceDataset(X_train, masks_train, y_train),
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=4,
-        pin_memory=True if device.type == 'cuda' else False
+        ASLMultiDataset(X_tr, fm_tr, mm_tr, y_tr),
+        batch_size=batch_size, shuffle=True,
+        num_workers=4, pin_memory=device.type == 'cuda'
     )
     test_loader = DataLoader(
-        ASLSequenceDataset(X_test, masks_test, y_test),
-        batch_size=batch_size * 2,
-        shuffle=False,
-        num_workers=4,
-        pin_memory=True if device.type == 'cuda' else False
+        ASLMultiDataset(X_te, fm_te, mm_te, y_te),
+        batch_size=batch_size * 2, shuffle=False,
+        num_workers=4, pin_memory=device.type == 'cuda'
     )
 
-    # Initialize model
-    print("\nInitializing model...")
-    model = TransformerASL(
+    # Enhanced model
+    model = ModalityAwareTransformer(
         input_dim=X.shape[2],
         num_classes=len(le.classes_),
-        d_model=256,
+        d_model=384,
         nhead=8,
-        num_layers=4
+        num_layers=5
     ).to(device)
 
-    total_params = sum(p.numel() for p in model.parameters())
-    print(f"Model parameters: {total_params:,}")
+    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")
 
-    # Training setup
     criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
-    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
-    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
+    optimizer = optim.AdamW(model.parameters(), lr=4e-4, weight_decay=1e-4)
+    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
 
-    # Training loop
     best_acc = 0.0
-    epochs = 60
-
-    print("\n" + "=" * 60)
-    print("STARTING TRAINING")
-    print("=" * 60)
+    epochs = 70
 
     for epoch in range(epochs):
-        # Train
         model.train()
-        total_loss = 0
-        correct = total = 0
+        total_loss = correct = total = 0
 
-        for x, mask, yb in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False):
-            x, mask, yb = x.to(device), mask.to(device), yb.to(device)
+        for x, frame_mask, modality_mask, yb in tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False):
+            x = x.to(device)
+            frame_mask = frame_mask.to(device)
+            modality_mask = modality_mask.to(device)
+            yb = yb.to(device)
 
-            # Invert mask: True for padding positions
-            key_mask = ~mask
+            key_padding_mask = ~frame_mask
 
             optimizer.zero_grad(set_to_none=True)
-            logits = model(x, key_padding_mask=key_mask)
+            logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
             loss = criterion(logits, yb)
             loss.backward()
-
-            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
 
             total_loss += loss.item()
@@ -449,49 +480,35 @@ def main():
 
         train_acc = correct / total * 100
 
-        # Evaluate
+        # Eval
         model.eval()
         correct = total = 0
         with torch.no_grad():
-            for x, mask, yb in test_loader:
-                x, mask, yb = x.to(device), mask.to(device), yb.to(device)
-                key_mask = ~mask
-                logits = model(x, key_padding_mask=key_mask)
+            for x, frame_mask, modality_mask, yb in test_loader:
+                x = x.to(device)
+                frame_mask = frame_mask.to(device)
+                modality_mask = modality_mask.to(device)
+                yb = yb.to(device)
+
+                key_padding_mask = ~frame_mask
+                logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
                 correct += (logits.argmax(-1) == yb).sum().item()
                 total += yb.size(0)
 
         test_acc = correct / total * 100
+        scheduler.step()
 
-        # Print progress
         print(f"[{epoch + 1:2d}/{epochs}] Loss: {total_loss / len(train_loader):.4f} | "
              f"Train: {train_acc:.2f}% | Test: {test_acc:.2f}%", end="")
 
-        scheduler.step()
-
-        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
-            torch.save({
-                'model': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'label_encoder_classes': le.classes_,
-                'acc': best_acc,
-                'epoch': epoch,
-                'input_dim': X.shape[2],
-                'num_classes': len(le.classes_),
-                'd_model': 256,
-                'nhead': 8,
-                'num_layers': 4
-            }, "best_asl_transformer.pth")
-            print(f" → New best: {best_acc:.2f}% ✓")
+            # Save a full checkpoint (weights, sign-name classes, normalization
+            # stats, model config) so inference code can rebuild the model
+            torch.save({
+                'model': model.state_dict(),
+                'label_encoder_classes': le.classes_,
+                'feature_mean': mean,
+                'feature_std': std,
+                'acc': best_acc,
+                'epoch': epoch,
+                'input_dim': X.shape[2],
+                'num_classes': len(le.classes_),
+                'd_model': 384,
+                'nhead': 8,
+                'num_layers': 5
+            }, "best_asl_modality_aware.pth")
+            print(" → saved")
        else:
            print()
 
-    print("\n" + "=" * 60)
-    print(f"✓ Training complete!")
-    print(f"✓ Best test accuracy: {best_acc:.2f}%")
-    print(f"✓ Model saved: best_asl_transformer.pth")
-    print("=" * 60)
+    print(f"\nBest test accuracy: {best_acc:.2f}%")
 
 
 if __name__ == "__main__":