hand and face again

2026-01-18 14:43:14 -06:00
parent 716428ec0b
commit 9256050292
3 changed files with 551 additions and 647 deletions

.gitignore (vendored): 1 changed line

@@ -2,3 +2,4 @@ asl_kaggle/
hand_landmarker.task
asl-dataset.zip
asl-signs.zip
best_asl_transformer.pth

test.py: 622 changed lines

@@ -1,170 +1,208 @@
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import deque, Counter
import pandas as pd  # ← added for rebuilding labels

# Modern MediaPipe Tasks API (no legacy solutions module)
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# PyTorch ≥ 2.6 checkpoint loading fix
import numpy.core.multiarray
import numpy.dtypes

torch.serialization.add_safe_globals([
    np.ndarray,
    np.dtype,
    np.dtypes.Int64DType,
    np.core.multiarray._reconstruct
])
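# The allow-list above works around PyTorch >= 2.6 switching torch.load to
# weights_only=True by default, which refuses to unpickle the NumPy objects
# stored in older checkpoints. A minimal alternative sketch, assuming the
# checkpoint file is trusted (not part of this commit):
#
#   checkpoint = torch.load("best_asl_transformer.pth", map_location="cpu",
#                           weights_only=False)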
# ===============================
# MODEL DEFINITION
# ===============================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=128):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
class TransformerASL(nn.Module):
    def __init__(self, input_dim, num_classes, d_model=256, nhead=8, num_layers=4):
        super().__init__()
        self.proj = nn.Linear(input_dim, d_model)
        self.norm_in = nn.LayerNorm(d_model)
        self.pos = PositionalEncoding(d_model, max_len=128)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.15,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(0.25),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, x, key_padding_mask=None):
        x = self.proj(x)
        x = self.norm_in(x)
        x = self.pos(x)
        x = self.encoder(x, src_key_padding_mask=key_padding_mask)
        x = x.mean(dim=1)
        return self.head(x)
# ===============================
# FEATURE EXTRACTION
# ===============================
def get_features_sequence(landmarks_seq, max_frames=100):
if landmarks_seq is None or len(landmarks_seq) == 0:
return None, None
wrist = landmarks_seq[:, 0:1, :]
landmarks_seq = landmarks_seq - wrist
scale = np.linalg.norm(landmarks_seq[:, 9], axis=1, keepdims=True)
scale = np.maximum(scale, 1e-6)
landmarks_seq = landmarks_seq / scale[:, :, np.newaxis]
landmarks_seq = np.nan_to_num(landmarks_seq, nan=0.0, posinf=0.0, neginf=0.0)
landmarks_seq = np.clip(landmarks_seq, -10, 10)
tips = [4, 8, 12, 16, 20]
bases = [1, 5, 9, 13, 17]
curls = [np.linalg.norm(landmarks_seq[:, t] - landmarks_seq[:, b], axis=1)
for b, t in zip(bases, tips)]
curl_features = np.stack(curls, axis=1)
deltas = np.zeros_like(landmarks_seq)
if len(landmarks_seq) > 1:
deltas[1:] = landmarks_seq[1:] - landmarks_seq[:-1]
pos_flat = landmarks_seq.reshape(len(landmarks_seq), -1)
delta_flat = deltas.reshape(len(landmarks_seq), -1)
seq = np.concatenate([pos_flat, delta_flat, curl_features], axis=1)
T, F = seq.shape
if T < max_frames:
pad = np.zeros((max_frames - T, F), dtype=np.float32)
seq_padded = np.concatenate([seq, pad], axis=0)
else:
seq_padded = seq[:max_frames]
mask = np.zeros(max_frames, dtype=bool)
mask[:min(T, max_frames)] = True
return seq_padded.astype(np.float32), mask
# ===============================
# MANUAL DRAWING FUNCTION
# ===============================
HAND_CONNECTIONS = [
(0, 1), (1, 2), (2, 3), (3, 4),
(0, 5), (5, 6), (6, 7), (7, 8),
(0, 9), (9, 10), (10, 11), (11, 12),
(0, 13), (13, 14), (14, 15), (15, 16),
(0, 17), (17, 18), (18, 19), (19, 20),
(5, 9), (9, 13), (13, 17)
]
def draw_hand_landmarks(image, landmarks_list):
h, w = image.shape[:2]
# Draw connections (blue lines)
for start_idx, end_idx in HAND_CONNECTIONS:
start = landmarks_list[start_idx]
end = landmarks_list[end_idx]
start_pt = (int(start.x * w), int(start.y * h))
end_pt = (int(end.x * w), int(end.y * h))
cv2.line(image, start_pt, end_pt, (255, 0, 0), 2)
# Draw landmarks (green circles)
for lm in landmarks_list:
x = int(lm.x * w)
y = int(lm.y * h)
cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
# ===============================
# MAIN PROGRAM
# ===============================
print("Loading trained model...")
checkpoint = torch.load("best_asl_transformer.pth", map_location="cpu")
model = TransformerASL(
input_dim=checkpoint['input_dim'],
num_classes=checkpoint['num_classes'],
d_model=checkpoint['d_model'],
nhead=checkpoint['nhead'],
num_layers=checkpoint['num_layers']
)
model.load_state_dict(checkpoint['model'])
model.eval()

# ─── FIX: Rebuild real sign names from train.csv ─────────────────────
print("\n" + "=" * 70)
print("Rebuilding sign name mapping from train.csv...")

try:
    # CHANGE THIS PATH to where your train.csv actually is
    train_df = pd.read_csv("asl_kaggle/train.csv")  # ← most important line!

    # Get unique signs, sorted (same order LabelEncoder usually uses)
    real_signs = sorted(train_df['sign'].unique())

    # Use real sign names instead of numbers
    label_encoder_classes = real_signs
    print("SUCCESS! Loaded real sign names")
    print("Number of classes:", len(real_signs))
    print("First 15 signs:", real_signs[:15])
    print("=" * 70 + "\n")
except Exception as e:
    print("ERROR loading train.csv:", e)
    print("Falling back to numeric labels (you'll see numbers instead of words)")
    label_encoder_classes = checkpoint['label_encoder_classes']
    print("First 15 (still numbers):", label_encoder_classes[:15])
    print("=" * 70 + "\n")
# MediaPipe Tasks setup
BaseOptions = python.BaseOptions
HandLandmarker = vision.HandLandmarker
HandLandmarkerOptions = vision.HandLandmarkerOptions
VisionRunningMode = vision.RunningMode

MODEL_PATH = "hand_landmarker.task"  # Make sure this file is in the folder

options = HandLandmarkerOptions(
    base_options=BaseOptions(model_asset_path=MODEL_PATH),
    running_mode=VisionRunningMode.VIDEO,
    num_hands=1,
    min_hand_detection_confidence=0.5,
@@ -174,270 +212,118 @@ options = HandLandmarkerOptions(
landmarker = HandLandmarker.create_from_options(options)

# Buffers
MAX_FRAMES = 100
sequence_buffer = []
prediction_buffer = deque(maxlen=15)

cap = cv2.VideoCapture(0)
if not cap.isOpened():
    print("Cannot open webcam")
    exit()

print("\nASL Recognition running - Press ESC to quit")
print("Controls: ESC = quit | SPACE = clear | H = toggle landmarks\n")

show_landmarks = True
frame_timestamp_ms = 0
while cap.isOpened():
    success, image = cap.read()
    if not success:
        break

    image = cv2.flip(image, 1)
    h, w = image.shape[:2]

    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
    frame_timestamp_ms += 33

    results = landmarker.detect_for_video(mp_image, frame_timestamp_ms)

    overlay = image.copy()
    cv2.rectangle(overlay, (10, 10), (520, 340), (0, 0, 0), -1)
    cv2.addWeighted(overlay, 0.65, image, 0.35, 0, image)

    if results.hand_landmarks:
        hand_landmarks_list = results.hand_landmarks[0]

        if show_landmarks:
            draw_hand_landmarks(image, hand_landmarks_list)

        current_frame = np.array(
            [[lm.x, lm.y, lm.z] for lm in hand_landmarks_list],
            dtype=np.float32
        )
        sequence_buffer.append(current_frame)
        if len(sequence_buffer) > MAX_FRAMES:
            sequence_buffer = sequence_buffer[-MAX_FRAMES:]

        if len(sequence_buffer) >= 10:
            seq_np = np.array(sequence_buffer)
            feats, mask = get_features_sequence(seq_np, MAX_FRAMES)

            if feats is not None:
                x = torch.from_numpy(feats).float().unsqueeze(0)
                key_padding_mask = torch.from_numpy(~mask).unsqueeze(0)

                with torch.no_grad():
                    logits = model(x, key_padding_mask=key_padding_mask)
                    probs = F.softmax(logits, dim=-1)[0]

                pred_idx = torch.argmax(probs).item()
                conf = probs[pred_idx].item()

                # Now using real sign names!
                sign = label_encoder_classes[pred_idx]

                if conf > 0.40:
                    prediction_buffer.append(sign)

                final_sign = sign
                final_conf = conf
                if len(prediction_buffer) >= 6:
                    final_sign = Counter(prediction_buffer).most_common(1)[0][0]
                    try:
                        final_conf = probs[label_encoder_classes.index(final_sign)].item()
                    except:
                        pass

                color = (0, 255, 100) if final_conf > 0.75 else (0, 220, 220)
                cv2.putText(image, f"Sign: {final_sign}", (25, 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.8, color, 4)
                cv2.putText(image, f"Conf: {final_conf:.1%}", (25, 110),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (220, 220, 220), 2)

                top3_p, top3_i = torch.topk(probs, 3)
                for i, (p, idx) in enumerate(zip(top3_p, top3_i)):
                    s = label_encoder_classes[idx.item()]
                    cv2.putText(image, f"{i + 1}. {s:<18} {p:.1%}",
                                (25, 155 + i * 40), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (200, 200, 200), 2)
    else:
        if len(sequence_buffer) < 25:
            sequence_buffer.clear()
        cv2.putText(image, "No hand detected", (25, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3)

    cv2.putText(image, "ESC:quit SPACE:clear H:landmarks",
                (w - 480, h - 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (180, 180, 180), 1)

    cv2.imshow("ASL Recognition", image)

    key = cv2.waitKey(1) & 0xFF
    if key == 27:
        break
    elif key == 32:
        sequence_buffer.clear()
        prediction_buffer.clear()
        print("Buffers cleared")
    elif key in (ord('h'), ord('H')):
        show_landmarks = not show_landmarks
        print(f"Landmarks display: {'ON' if show_landmarks else 'OFF'}")

cap.release()
cv2.destroyAllWindows()
landmarker.close()
print("Recognition stopped.")

(third changed file: the training script; filename not shown in this view)

@@ -35,188 +35,222 @@ else:
print("=" * 60) print("=" * 60)
# ===============================
# SELECTED LANDMARK INDICES
# ===============================
IMPORTANT_FACE_INDICES = sorted(list(set([
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
55, 65, 66, 105, 107, 336, 296, 334,
33, 133, 160, 159, 158, 144, 145, 153,
362, 263, 387, 386, 385, 373, 374, 380,
1, 2, 98, 327,
61, 185, 40, 39, 37, 0, 267, 269, 270, 409,
291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
308, 324, 318, 402, 317, 14, 87, 178, 88, 95
])))
NUM_FACE_POINTS = len(IMPORTANT_FACE_INDICES)
NUM_HAND_POINTS = 21 * 2
TOTAL_POINTS_PER_FRAME = NUM_HAND_POINTS + NUM_FACE_POINTS
# ===============================
# ENHANCED DATA EXTRACTION
# ===============================
def extract_multi_landmarks(path, min_valid_frames=5):
    """
    Extract both hands + selected face landmarks with modality flags
    Returns: dict with 'landmarks', 'left_hand_valid', 'right_hand_valid', 'face_valid'
    """
    try:
        df = pd.read_parquet(path)

        seq = []
        left_valid_frames = []
        right_valid_frames = []
        face_valid_frames = []

        all_types = df["type"].unique()
        if "left_hand" in all_types or "right_hand" in all_types or "face" in all_types:
            frames = sorted(df["frame"].unique())
        else:
            return None

        if frames is None or len(frames) < min_valid_frames:
            return None

        for frame in frames:
            frame_df = df[df["frame"] == frame]
            frame_points = np.full((TOTAL_POINTS_PER_FRAME, 3), np.nan, dtype=np.float32)

            pos = 0
            left_valid = False
            right_valid = False
            face_valid = False

            # Left hand
            left = frame_df[frame_df["type"] == "left_hand"]
            if len(left) >= 15:
                valid_count = 0
                for i in range(21):
                    row = left[left["landmark_index"] == i]
                    if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all():
                        frame_points[pos] = row[['x', 'y', 'z']].values[0]
                        valid_count += 1
                    pos += 1
                left_valid = (valid_count >= 15)
            else:
                pos += 21

            # Right hand
            right = frame_df[frame_df["type"] == "right_hand"]
            if len(right) >= 15:
                valid_count = 0
                for i in range(21):
                    row = right[right["landmark_index"] == i]
                    if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all():
                        frame_points[pos] = row[['x', 'y', 'z']].values[0]
                        valid_count += 1
                    pos += 1
                right_valid = (valid_count >= 15)
            else:
                pos += 21

            # Face
            face = frame_df[frame_df["type"] == "face"]
            if len(face) > 0:
                valid_count = 0
                for idx in IMPORTANT_FACE_INDICES:
                    row = face[face["landmark_index"] == idx]
                    if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all():
                        frame_points[pos] = row[['x', 'y', 'z']].values[0]
                        valid_count += 1
                    pos += 1
                face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.5)

            valid_ratio = 1 - np.isnan(frame_points).mean()
            if valid_ratio >= 0.40:
                frame_points = np.nan_to_num(frame_points, nan=0.0)
                seq.append(frame_points)
                left_valid_frames.append(left_valid)
                right_valid_frames.append(right_valid)
                face_valid_frames.append(face_valid)

        if len(seq) < min_valid_frames:
            return None

        return {
            'landmarks': np.stack(seq),
            'left_hand_valid': np.array(left_valid_frames),
            'right_hand_valid': np.array(right_valid_frames),
            'face_valid': np.array(face_valid_frames)
        }

    except Exception:
        return None
def get_features_sequence(landmarks_data, max_frames=100):
    """
    Enhanced feature extraction with separate modality processing
    """
    if landmarks_data is None:
        return None, None, None
landmarks_3d = landmarks_data['landmarks']
if len(landmarks_3d) == 0:
return None, None, None
T, N, _ = landmarks_3d.shape
# Separate modalities for independent normalization
left_hand = landmarks_3d[:, :21, :]
right_hand = landmarks_3d[:, 21:42, :]
face = landmarks_3d[:, 42:, :]
# Independent centering per modality
features_list = []
for modality, valid_mask in [
(left_hand, landmarks_data['left_hand_valid']),
(right_hand, landmarks_data['right_hand_valid']),
(face, landmarks_data['face_valid'])
]:
# Center on modality-specific mean
valid_frames = modality[valid_mask] if valid_mask.any() else modality
if len(valid_frames) > 0:
center = np.mean(valid_frames, axis=(0, 1), keepdims=True)
spread = np.std(valid_frames, axis=(0, 1), keepdims=True).max()
else:
center = 0
spread = 1
modality_norm = (modality - center) / max(spread, 1e-6)
flat = modality_norm.reshape(T, -1)
# Deltas
deltas = np.zeros_like(flat)
if T > 1:
deltas[1:] = flat[1:] - flat[:-1]
features_list.append(flat)
features_list.append(deltas)
# Combine all modalities
features = np.concatenate(features_list, axis=1)
# Create modality availability mask (which frames have which modalities)
modality_mask = np.stack([
landmarks_data['left_hand_valid'],
landmarks_data['right_hand_valid'],
landmarks_data['face_valid']
], axis=1).astype(np.float32) # (T, 3)
# Pad/truncate
if T < max_frames:
pad = np.zeros((max_frames - T, features.shape[1]), dtype=np.float32)
features = np.concatenate([features, pad], axis=0)
mask_pad = np.zeros((max_frames - T, 3), dtype=np.float32)
modality_mask = np.concatenate([modality_mask, mask_pad], axis=0)
frame_mask = np.zeros(max_frames, dtype=bool)
frame_mask[:T] = True
else:
features = features[:max_frames]
modality_mask = modality_mask[:max_frames]
frame_mask = np.ones(max_frames, dtype=bool)
features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
features = np.clip(features, -30, 30)
return features.astype(np.float32), frame_mask, modality_mask
def process_row(row_data, base_path, max_frames=100):
    """Process a single row - expects tuple of (path, sign)"""
    path_rel, sign = row_data
    path = os.path.join(base_path, path_rel)
    if not os.path.exists(path):
        return None, None, None, None

    try:
        lm_data = extract_multi_landmarks(path)
        if lm_data is None:
            return None, None, None, None

        feat, frame_mask, modality_mask = get_features_sequence(lm_data, max_frames)
        if feat is None:
            return None, None, None, None

        return feat, frame_mask, modality_mask, sign

    except Exception:
        return None, None, None, None
# ===============================
# ENHANCED MODEL WITH MODALITY AWARENESS
# ===============================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=128):
@@ -232,16 +266,19 @@ class PositionalEncoding(nn.Module):
        return x + self.pe[:, :x.size(1)]
class ModalityAwareTransformer(nn.Module):
    def __init__(self, input_dim, num_classes, d_model=384, nhead=8, num_layers=5):
        super().__init__()

        # Main projection
        self.proj = nn.Linear(input_dim, d_model)

        # Modality embedding (3 modalities: left_hand, right_hand, face)
        self.modality_embed = nn.Linear(3, d_model)

        self.norm_in = nn.LayerNorm(d_model)
        self.pos = PositionalEncoding(d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
@@ -253,30 +290,45 @@ class TransformerASL(nn.Module):
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(0.25),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, x, modality_mask=None, key_padding_mask=None):
        # Project features
        x = self.proj(x)

        # Add modality information if available
        if modality_mask is not None:
            mod_embed = self.modality_embed(modality_mask)
            x = x + mod_embed

        x = self.norm_in(x)
        x = self.pos(x)
        x = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # Weighted average (giving more weight to frames with more valid modalities)
        if modality_mask is not None:
            weights = modality_mask.sum(dim=-1, keepdim=True) + 1e-6  # (B, T, 1)
            weights = weights / weights.sum(dim=1, keepdim=True)
            x = (x * weights).sum(dim=1)
        else:
            x = x.mean(dim=1)

        return self.head(x)


def load_kaggle_asl_data(base_path):
    """Load training metadata"""
    train_path = os.path.join(base_path, "train.csv")
    train_df = pd.read_csv(train_path)
    return train_df, None
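# To make the modality-weighted pooling in ModalityAwareTransformer.forward
# concrete, a small illustration with toy values (not dataset output): frames
# with more valid modalities receive proportionally more weight, and fully
# empty frames contribute almost nothing.
#
#   modality_mask = torch.tensor([[[1., 1., 1.],    # frame 0: both hands + face
#                                  [1., 0., 0.],    # frame 1: left hand only
#                                  [0., 0., 0.]]])  # frame 2: nothing valid
#   weights = modality_mask.sum(dim=-1, keepdim=True) + 1e-6   # -> 3, 1, ~0
#   weights = weights / weights.sum(dim=1, keepdim=True)       # -> 0.75, 0.25, ~0
#   pooled = (encoder_output * weights).sum(dim=1)             # frame 0 dominates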
# ===============================
# MAIN
# ===============================
def main():
    base_path = "asl_kaggle"
@@ -284,163 +336,142 @@ def main():
    MIN_SAMPLES_PER_CLASS = 5

    print("\nLoading metadata...")
    train_df, _ = load_kaggle_asl_data(base_path)

    # Convert to simple tuples for multiprocessing compatibility
    rows = [(row['path'], row['sign']) for _, row in train_df.iterrows()]

    print("\nProcessing sequences with BOTH hands + FACE (enhanced)...")

    with Pool(cpu_count()) as pool:
        results = list(tqdm(
            pool.imap(
                partial(process_row, base_path=base_path, max_frames=max_frames),
                rows,
                chunksize=80
            ),
            total=len(rows),
            desc="Landmarks extraction"
        ))

    X_list, frame_masks_list, modality_masks_list, y_list = [], [], [], []
    for feat, frame_mask, modality_mask, sign in results:
        if feat is not None and frame_mask is not None:
            X_list.append(feat)
            frame_masks_list.append(frame_mask)
            modality_masks_list.append(modality_mask)
            y_list.append(sign)

    if not X_list:
        print("No valid sequences extracted!")
        return

    X = np.stack(X_list)
    frame_masks = np.stack(frame_masks_list)
    modality_masks = np.stack(modality_masks_list)

    print(f"\nExtracted {len(X):,} sequences")
    print(f"Feature shape: {X.shape[1:]} (input_dim = {X.shape[2]})")
    print(f"Modality mask shape: {modality_masks.shape}")

    # Global normalization
    X = np.clip(X, -30, 30)
    mean = X.mean(axis=(0, 1), keepdims=True)
    std = X.std(axis=(0, 1), keepdims=True) + 1e-8
    X = (X - mean) / std

    # Labels
    le = LabelEncoder()
    y = le.fit_transform(y_list)

    # Filter rare classes
    counts = Counter(y)
    valid = [k for k, v in counts.items() if v >= MIN_SAMPLES_PER_CLASS]
    mask = np.isin(y, valid)
    X = X[mask]
    frame_masks = frame_masks[mask]
    modality_masks = modality_masks[mask]
    y = y[mask]

    le = LabelEncoder()
    y = le.fit_transform(y)

    print(f"After filtering: {len(X):,} samples | {len(le.classes_)} classes")

    # Analyze modality usage
    print("\nModality statistics:")
    print(f"  Sequences with left hand:  {(modality_masks[:, :, 0].sum(axis=1) > 0).mean() * 100:.1f}%")
    print(f"  Sequences with right hand: {(modality_masks[:, :, 1].sum(axis=1) > 0).mean() * 100:.1f}%")
    print(f"  Sequences with face:       {(modality_masks[:, :, 2].sum(axis=1) > 0).mean() * 100:.1f}%")

    # Split
    X_tr, X_te, fm_tr, fm_te, mm_tr, mm_te, y_tr, y_te = train_test_split(
        X, frame_masks, modality_masks, y, test_size=0.15, stratify=y, random_state=42
    )

    # Dataset
    class ASLMultiDataset(Dataset):
        def __init__(self, X, frame_masks, modality_masks, y):
            self.X = torch.from_numpy(X).float()
            self.frame_masks = torch.from_numpy(frame_masks).bool()
            self.modality_masks = torch.from_numpy(modality_masks).float()
            self.y = torch.from_numpy(y).long()

        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return self.X[idx], self.frame_masks[idx], self.modality_masks[idx], self.y[idx]

    batch_size = 64 if device.type == 'cuda' else 32

    train_loader = DataLoader(
        ASLMultiDataset(X_tr, fm_tr, mm_tr, y_tr),
        batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=device.type == 'cuda'
    )
    test_loader = DataLoader(
        ASLMultiDataset(X_te, fm_te, mm_te, y_te),
        batch_size=batch_size * 2, shuffle=False,
        num_workers=4, pin_memory=device.type == 'cuda'
    )

    # Enhanced model
    model = ModalityAwareTransformer(
        input_dim=X.shape[2],
        num_classes=len(le.classes_),
        d_model=384,
        nhead=8,
        num_layers=5
    ).to(device)
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    optimizer = optim.AdamW(model.parameters(), lr=4e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

    best_acc = 0.0
    epochs = 70

    for epoch in range(epochs):
        model.train()
        total_loss = correct = total = 0

        for x, frame_mask, modality_mask, yb in tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False):
            x = x.to(device)
            frame_mask = frame_mask.to(device)
            modality_mask = modality_mask.to(device)
            yb = yb.to(device)

            key_padding_mask = ~frame_mask

            optimizer.zero_grad(set_to_none=True)
            logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
            loss = criterion(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
@@ -449,49 +480,35 @@ def main():
        train_acc = correct / total * 100

        # Eval
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, frame_mask, modality_mask, yb in test_loader:
                x = x.to(device)
                frame_mask = frame_mask.to(device)
                modality_mask = modality_mask.to(device)
                yb = yb.to(device)

                key_padding_mask = ~frame_mask
                logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)

                correct += (logits.argmax(-1) == yb).sum().item()
                total += yb.size(0)

        test_acc = correct / total * 100
        scheduler.step()

        print(f"[{epoch + 1:2d}/{epochs}] Loss: {total_loss / len(train_loader):.4f} | "
              f"Train: {train_acc:.2f}% | Test: {test_acc:.2f}%", end="")

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), "best_asl_modality_aware.pth")
            print(" → saved")
        else:
            print()

    print(f"\nBest test accuracy: {best_acc:.2f}%")


if __name__ == "__main__":
    main()