# ASLTranslator/rewrite_training.py
import os
import polars as pl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from concurrent.futures import ProcessPoolExecutor
# --- CONFIG ---
BASE_PATH = "asl_kaggle"
CACHE_DIR = "asl_cache"
TARGET_FRAMES = 22
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Positional ranges for each body part within a 543-row frame (MediaPipe
# Holistic), valid after the sort on (type, landmark_index) below. Note these
# are row positions within a frame, not the raw per-type landmark_index values.
LANDMARK_RANGES = {
    'left_hand': list(range(468, 489)),   # 21 landmarks
    'right_hand': list(range(522, 543)),  # 21 landmarks
    'pose': list(range(489, 522)),        # 33 landmarks
    'face': list(range(0, 468))           # 468 landmarks
}
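
# Illustrative sanity check: since types sort alphabetically (face, left_hand,
# pose, right_hand), the four ranges above must tile 0..542 exactly.
assert sorted(i for idxs in LANDMARK_RANGES.values() for i in idxs) == list(range(543))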
# --- PREPROCESSING (RUN ONCE) ---
def process_single_file(args):
    """Process a single file - designed for multiprocessing."""
    i, path, base_path, cache_dir = args
    cache_path = os.path.join(cache_dir, f"sample_{i}.npz")
    if os.path.exists(cache_path):
        return  # Skip if already cached
    try:
        parquet_path = os.path.join(base_path, path)
        df = pl.read_parquet(parquet_path)

        # Global anchor: the nose (face landmark 0)
        anchors = (
            df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
            .select([
                pl.col("frame"),
                pl.col("x").alias("nx"),
                pl.col("y").alias("ny"),
                pl.col("z").alias("nz")
            ])
        )

        # Local anchors: the wrists. landmark_index is per-type in the raw
        # parquet (each hand's wrist is landmark 0 of its own type), so key
        # the join on (frame, type) to anchor every hand landmark to its wrist.
        wrists = (
            df.filter(
                pl.col("type").is_in(["left_hand", "right_hand"])
                & (pl.col("landmark_index") == 0)
            )
            .select([
                pl.col("frame"),
                pl.col("type"),
                pl.col("x").alias("wx"),
                pl.col("y").alias("wy")
            ])
        )

        processed = df.join(anchors, on="frame", how="left")
        processed = (
            processed.join(wrists, on=["frame", "type"], how="left")
            .with_columns([
                (pl.col("x") - pl.col("nx")).alias("x_g"),
                (pl.col("y") - pl.col("ny")).alias("y_g"),
                (pl.col("z") - pl.col("nz")).alias("z_g"),
                # Non-hand rows have no wrist anchor; fall back to the nose
                (pl.col("x") - pl.col("wx")).fill_null(pl.col("x") - pl.col("nx")).alias("x_l"),
                (pl.col("y") - pl.col("wy")).fill_null(pl.col("y") - pl.col("ny")).alias("y_l"),
            ])
            .sort(["frame", "type", "landmark_index"])
        )

        n_frames = processed["frame"].n_unique()
        # One row per landmark per frame -> (n_frames, 543, 5) feature tensor
        full_data = (
            processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"])
            .to_numpy()
            .reshape(n_frames, 543, 5)
        )

        # Temporal resampling to a fixed number of frames
        indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int)
        resampled = full_data[indices]  # (22, 543, 5)
        # Missing landmarks are NaN in the raw data; zero them so they
        # don't poison training
        resampled = np.nan_to_num(resampled)

        # Split by body part
        left_hand = resampled[:, LANDMARK_RANGES['left_hand'], :]    # (22, 21, 5)
        right_hand = resampled[:, LANDMARK_RANGES['right_hand'], :]  # (22, 21, 5)
        pose = resampled[:, LANDMARK_RANGES['pose'], :]              # (22, 33, 5)
        face = resampled[:, LANDMARK_RANGES['face'], :]              # (22, 468, 5)

        # Save as compressed npz
        np.savez_compressed(cache_path,
                            left_hand=left_hand.astype(np.float32),
                            right_hand=right_hand.astype(np.float32),
                            pose=pose.astype(np.float32),
                            face=face.astype(np.float32))
    except Exception as e:
        # Save zero tensors for failed files so sample indexing stays dense
        print(f"Failed to process {path}: {e}")
        np.savez_compressed(cache_path,
                            left_hand=np.zeros((TARGET_FRAMES, 21, 5), dtype=np.float32),
                            right_hand=np.zeros((TARGET_FRAMES, 21, 5), dtype=np.float32),
                            pose=np.zeros((TARGET_FRAMES, 33, 5), dtype=np.float32),
                            face=np.zeros((TARGET_FRAMES, 468, 5), dtype=np.float32))
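
# Inspecting one cached sample after preprocessing (hypothetical cache path;
# the shapes follow from TARGET_FRAMES and the ranges above):
#   data = np.load("asl_cache/sample_0.npz")
#   data["left_hand"].shape   # (22, 21, 5)
#   data["pose"].shape        # (22, 33, 5)
#   data["face"].shape        # (22, 468, 5)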
def preprocess_and_cache(paths, base_path=BASE_PATH, cache_dir=CACHE_DIR):
    """Preprocess all files in parallel and save as numpy arrays."""
    os.makedirs(cache_dir, exist_ok=True)
    # Skip the whole pass if every sample is already cached
    all_cached = all(
        os.path.exists(os.path.join(cache_dir, f"sample_{i}.npz"))
        for i in range(len(paths))
    )
    if all_cached:
        print("All files already cached, skipping preprocessing...")
        return

    print(f"Preprocessing {len(paths)} files in parallel...")
    args_list = [(i, path, base_path, cache_dir) for i, path in enumerate(paths)]
    with ProcessPoolExecutor() as executor:
        list(tqdm(executor.map(process_single_file, args_list), total=len(args_list)))
    print("Preprocessing complete!")
# --- FAST DATASET (LOADS FROM CACHE) ---
class CachedASLDataset(Dataset):
    """Fast dataset that loads from preprocessed numpy files."""

    def __init__(self, indices, labels, cache_dir=CACHE_DIR):
        self.indices = indices
        self.labels = labels
        self.cache_dir = cache_dir

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample_idx = self.indices[idx]
        cache_path = os.path.join(self.cache_dir, f"sample_{sample_idx}.npz")
        # Fast numpy load of the four precomputed body-part arrays
        data = np.load(cache_path)
        left_hand = torch.tensor(data['left_hand'], dtype=torch.float32)
        right_hand = torch.tensor(data['right_hand'], dtype=torch.float32)
        pose = torch.tensor(data['pose'], dtype=torch.float32)
        face = torch.tensor(data['face'], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return (left_hand, right_hand, pose, face), label
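
# Example usage with hypothetical indices/labels; each item is a tuple of the
# four body-part tensors plus a label:
#   ds = CachedASLDataset([0, 1], [3, 7])
#   (lh, rh, pose, face), label = ds[0]
#   lh.shape    # torch.Size([22, 21, 5])
#   face.shape  # torch.Size([22, 468, 5])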
# --- IMPROVED MODEL WITH SEPARATE STREAMS ---
class ASLClassifierSeparateStreams(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.3):
        super().__init__()
        # Flattened feature dimension for each body part: landmarks * features
        self.left_hand_dim = 21 * 5
        self.right_hand_dim = 21 * 5
        self.pose_dim = 33 * 5
        self.face_dim = 468 * 5

        # Separate convolutional streams for each body part
        self.left_hand_stream = self._make_conv_stream(self.left_hand_dim, 128, dropout_rate)
        self.right_hand_stream = self._make_conv_stream(self.right_hand_dim, 128, dropout_rate)
        self.pose_stream = self._make_conv_stream(self.pose_dim, 128, dropout_rate)
        self.face_stream = self._make_conv_stream(self.face_dim, 256, dropout_rate)

        # Combined features dimension
        combined_dim = 128 + 128 + 128 + 256  # 640

        # Bidirectional LSTM for temporal modeling
        self.lstm = nn.LSTM(
            input_size=combined_dim,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate
        )

        # Attention mechanism over time steps
        self.attention = nn.Sequential(
            nn.Linear(512, 128),  # 512 from bidirectional LSTM (256 * 2)
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )

    def _make_conv_stream(self, input_dim, output_dim, dropout_rate):
        """Create a convolutional stream for processing a body part."""
        return nn.Sequential(
            nn.Conv1d(input_dim, output_dim * 2, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Conv1d(output_dim * 2, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5)
        )

    def forward(self, x):
        # x is a tuple: (left_hand, right_hand, pose, face)
        left_hand, right_hand, pose, face = x
        b = left_hand.shape[0]  # batch size

        # Flatten landmarks and features for each body part:
        # (batch, time, landmarks, features) -> (batch, landmarks * features, time)
        left_hand = left_hand.view(b, TARGET_FRAMES, -1).transpose(1, 2)
        right_hand = right_hand.view(b, TARGET_FRAMES, -1).transpose(1, 2)
        pose = pose.view(b, TARGET_FRAMES, -1).transpose(1, 2)
        face = face.view(b, TARGET_FRAMES, -1).transpose(1, 2)

        # Process each body part through its own stream
        left_hand_feat = self.left_hand_stream(left_hand)     # (batch, 128, time)
        right_hand_feat = self.right_hand_stream(right_hand)  # (batch, 128, time)
        pose_feat = self.pose_stream(pose)                    # (batch, 128, time)
        face_feat = self.face_stream(face)                    # (batch, 256, time)

        # Transpose back: (batch, features, time) -> (batch, time, features)
        left_hand_feat = left_hand_feat.transpose(1, 2)
        right_hand_feat = right_hand_feat.transpose(1, 2)
        pose_feat = pose_feat.transpose(1, 2)
        face_feat = face_feat.transpose(1, 2)

        # Concatenate all streams: (batch, time, 640)
        combined = torch.cat([left_hand_feat, right_hand_feat, pose_feat, face_feat], dim=2)

        # LSTM processing: (batch, time, 512)
        lstm_out, _ = self.lstm(combined)

        # Attention-weighted pooling over time
        attention_weights = F.softmax(self.attention(lstm_out), dim=1)  # (batch, time, 1)
        attended = torch.sum(attention_weights * lstm_out, dim=1)       # (batch, 512)

        return self.classifier(attended)
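
def _smoke_test(num_classes=10):
    """Minimal shape check on random inputs (illustrative; the class count and
    batch size here are placeholders, not values from the real dataset)."""
    model = ASLClassifierSeparateStreams(num_classes)
    model.eval()  # BatchNorm/Dropout in eval mode for a deterministic pass
    b = 2
    dummy = (
        torch.randn(b, TARGET_FRAMES, 21, 5),   # left_hand
        torch.randn(b, TARGET_FRAMES, 21, 5),   # right_hand
        torch.randn(b, TARGET_FRAMES, 33, 5),   # pose
        torch.randn(b, TARGET_FRAMES, 468, 5),  # face
    )
    with torch.no_grad():
        out = model(dummy)
    assert out.shape == (b, num_classes)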
# --- EXECUTION ---
if __name__ == "__main__":
    # 1. Setup metadata
    metadata = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
    unique_signs = sorted(metadata["sign"].unique().to_list())
    sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)}
    paths = metadata["path"].to_list()
    labels = [sign_to_idx[s] for s in metadata["sign"].to_list()]

    # 2. Preprocess and cache (parallelized; only runs if the cache is incomplete)
    preprocess_and_cache(paths)

    # 3. Stratified train/val split over cached sample indices
    all_indices = list(range(len(paths)))
    train_indices, val_indices, train_labels, val_labels = train_test_split(
        all_indices, labels, test_size=0.1, stratify=labels, random_state=42
    )

    # 4. Create datasets from cached files
    train_dataset = CachedASLDataset(train_indices, train_labels)
    val_dataset = CachedASLDataset(val_indices, val_labels)
    # Adjust batch size based on GPU memory (separate streams use more memory)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=32, num_workers=4, pin_memory=True)
    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # 5. Train
    model = ASLClassifierSeparateStreams(len(unique_signs)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    best_acc = 0.0

    print(f"Starting training on {device}...")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    for epoch in range(25):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/25 [Train]")
        for batch_data, batch_y in pbar:
            # Move all body parts to the device
            batch_data = tuple(d.to(device) for d in batch_data)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            output = model(batch_data)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            train_total += batch_y.size(0)
            train_correct += (predicted == batch_y).sum().item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100 * train_correct / train_total:.1f}%'})

        # Validation
        model.eval()
        val_correct, val_total = 0, 0
        val_loss = 0
        with torch.no_grad():
            for batch_data, vy in tqdm(val_loader, desc=f"Epoch {epoch + 1}/25 [Val]"):
                batch_data = tuple(d.to(device) for d in batch_data)
                vy = vy.to(device)
                output = model(batch_data)
                val_loss += criterion(output, vy).item()
                pred = output.argmax(1)
                val_correct += (pred == vy).sum().item()
                val_total += vy.size(0)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_acc = 100 * train_correct / train_total
        val_acc = 100 * val_correct / val_total
        scheduler.step(avg_val_loss)

        print(f"\nEpoch {epoch + 1}/25:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}%")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_asl_model_separate.pth")
            print(f"  ✓ Best model saved! (Val Acc: {val_acc:.2f}%)\n")

        # Checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"asl_model_separate_e{epoch + 1}.pth")

    print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")