Underdog
This commit is contained in:
@@ -1,237 +1,759 @@
|
|||||||
import os
|
import os
|
||||||
import polars as pl
|
import json
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset, DataLoader, random_split
|
import torch.optim as optim
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
from multiprocessing import Pool, cpu_count
|
||||||
|
from functools import partial
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
# --- CONFIGURATION ---
|
# ===============================
|
||||||
BASE_PATH = "asl_kaggle"
|
# GPU CONFIGURATION
|
||||||
TARGET_FRAMES = 22
|
# ===============================
|
||||||
LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
|
print("=" * 60)
|
||||||
HANDS = list(range(468, 543))
|
print("GPU CONFIGURATION")
|
||||||
SELECTED_INDICES = LIPS + HANDS
|
print("=" * 60)
|
||||||
NUM_FEATS = len(SELECTED_INDICES) * 3
|
|
||||||
|
|
||||||
# Training hyperparameters
|
if torch.cuda.is_available():
|
||||||
BATCH_SIZE = 32
|
print(f"✓ CUDA available!")
|
||||||
EPOCHS = 50
|
print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
|
||||||
LEARNING_RATE = 0.001
|
device = torch.device('cuda:0')
|
||||||
TRAIN_SPLIT = 0.8
|
torch.backends.cudnn.benchmark = True
|
||||||
CHECKPOINT_DIR = "checkpoints"
|
torch.backends.cudnn.enabled = True
|
||||||
|
else:
|
||||||
|
print("✗ CUDA not available, using CPU")
|
||||||
|
device = torch.device('cpu')
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
print("=" * 60)
|
||||||
print(f"Using device: {device}")
|
|
||||||
|
# ===============================
|
||||||
|
# SELECTED LANDMARK INDICES
|
||||||
|
# ===============================
|
||||||
|
IMPORTANT_FACE_INDICES = sorted(list(set([
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||||
|
55, 65, 66, 105, 107, 336, 296, 334,
|
||||||
|
33, 133, 160, 159, 158, 144, 145, 153,
|
||||||
|
362, 263, 387, 386, 385, 373, 374, 380,
|
||||||
|
1, 2, 98, 327,
|
||||||
|
61, 185, 40, 39, 37, 0, 267, 269, 270, 409,
|
||||||
|
291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
|
||||||
|
78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
|
||||||
|
308, 324, 318, 402, 317, 14, 87, 178, 88, 95
|
||||||
|
])))
|
||||||
|
|
||||||
|
NUM_FACE_POINTS = len(IMPORTANT_FACE_INDICES)
|
||||||
|
NUM_HAND_POINTS = 21 * 2
|
||||||
|
TOTAL_POINTS_PER_FRAME = NUM_HAND_POINTS + NUM_FACE_POINTS
|
||||||
|
|
||||||
|
|
||||||
# --- DATA PROCESSING ---
|
# ===============================
|
||||||
def load_kaggle_metadata(base_path):
|
# DATA AUGMENTATION
|
||||||
return pl.read_csv(os.path.join(base_path, "train.csv"))
|
# ===============================
|
||||||
|
def augment_sequence(x, modality_mask):
|
||||||
|
"""Apply random augmentations to training data"""
|
||||||
|
x = x.copy()
|
||||||
|
|
||||||
|
# Random temporal cropping (simulate different signing speeds)
|
||||||
|
if np.random.rand() < 0.3 and len(x) > 20:
|
||||||
|
start = np.random.randint(0, max(1, len(x) // 4))
|
||||||
|
x = x[start:]
|
||||||
|
modality_mask = modality_mask[start:]
|
||||||
|
|
||||||
|
# Random spatial scaling
|
||||||
|
if np.random.rand() < 0.5:
|
||||||
|
scale = np.random.uniform(0.85, 1.15)
|
||||||
|
x = x * scale
|
||||||
|
|
||||||
|
# Random rotation (around z-axis for x,y coordinates)
|
||||||
|
if np.random.rand() < 0.5:
|
||||||
|
angle = np.random.uniform(-0.3, 0.3)
|
||||||
|
cos_a, sin_a = np.cos(angle), np.sin(angle)
|
||||||
|
|
||||||
|
# Reshape to get xyz coordinates
|
||||||
|
x_reshaped = x.reshape(len(x), -1, 3)
|
||||||
|
x_rot = x_reshaped.copy()
|
||||||
|
x_rot[..., 0] = x_reshaped[..., 0] * cos_a - x_reshaped[..., 1] * sin_a
|
||||||
|
x_rot[..., 1] = x_reshaped[..., 0] * sin_a + x_reshaped[..., 1] * cos_a
|
||||||
|
x = x_rot.reshape(x.shape)
|
||||||
|
|
||||||
|
# Random masking (simulate occlusion) - only for some frames
|
||||||
|
if np.random.rand() < 0.3:
|
||||||
|
n_mask = int(len(x) * 0.15) # mask 15% of frames
|
||||||
|
mask_indices = np.random.choice(len(x), n_mask, replace=False)
|
||||||
|
x[mask_indices] *= 0.1 # dim but don't completely zero
|
||||||
|
|
||||||
|
# Random noise
|
||||||
|
if np.random.rand() < 0.4:
|
||||||
|
noise = np.random.normal(0, 0.02, x.shape)
|
||||||
|
x = x + noise
|
||||||
|
|
||||||
|
# Random time warping (speed up or slow down)
|
||||||
|
if np.random.rand() < 0.3 and len(x) > 20:
|
||||||
|
speed = np.random.uniform(0.8, 1.2)
|
||||||
|
new_len = int(len(x) * speed)
|
||||||
|
new_len = min(new_len, len(x))
|
||||||
|
indices = np.linspace(0, len(x) - 1, new_len).astype(int)
|
||||||
|
x = x[indices]
|
||||||
|
modality_mask = modality_mask[indices]
|
||||||
|
|
||||||
|
return x, modality_mask
|
||||||
|
|
||||||
|
|
||||||
def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES):
|
# ===============================
|
||||||
parquet_path = os.path.join(base_path, path)
|
# ENHANCED DATA EXTRACTION (POLARS)
|
||||||
df = pl.read_parquet(parquet_path)
|
# ===============================
|
||||||
|
def extract_multi_landmarks(path, min_valid_frames=3):
|
||||||
|
"""
|
||||||
|
Extract both hands + selected face landmarks with modality flags
|
||||||
|
Returns: dict with 'landmarks', 'left_hand_valid', 'right_hand_valid', 'face_valid'
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
df = pl.read_parquet(path)
|
||||||
|
seq = []
|
||||||
|
left_valid_frames = []
|
||||||
|
right_valid_frames = []
|
||||||
|
face_valid_frames = []
|
||||||
|
|
||||||
anchors = (
|
all_types = df.select("type").unique().to_series().to_list()
|
||||||
df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
|
has_data = any(t in all_types for t in ["left_hand", "right_hand", "face"])
|
||||||
.select([pl.col("frame"), pl.col("x").alias("nx"), pl.col("y").alias("ny"), pl.col("z").alias("nz")])
|
|
||||||
)
|
|
||||||
|
|
||||||
processed = (
|
if not has_data:
|
||||||
df.join(anchors, on="frame", how="left")
|
return None
|
||||||
.with_columns([
|
|
||||||
(pl.col("x") - pl.col("nx")).fill_null(0.0),
|
|
||||||
(pl.col("y") - pl.col("ny")).fill_null(0.0),
|
|
||||||
(pl.col("z") - pl.col("nz")).fill_null(0.0),
|
|
||||||
])
|
|
||||||
.sort(["frame", "type", "landmark_index"])
|
|
||||||
)
|
|
||||||
|
|
||||||
raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3)
|
frames = sorted(df.select("frame").unique().to_series().to_list())
|
||||||
reduced_tensor = raw_tensor[:, SELECTED_INDICES, :]
|
|
||||||
|
|
||||||
curr_len = reduced_tensor.shape[0]
|
if len(frames) < min_valid_frames:
|
||||||
indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int)
|
return None
|
||||||
return reduced_tensor[indices]
|
|
||||||
|
for frame in frames:
|
||||||
|
frame_df = df.filter(pl.col("frame") == frame)
|
||||||
|
frame_points = np.full((TOTAL_POINTS_PER_FRAME, 3), np.nan, dtype=np.float32)
|
||||||
|
|
||||||
|
pos = 0
|
||||||
|
left_valid = False
|
||||||
|
right_valid = False
|
||||||
|
face_valid = False
|
||||||
|
|
||||||
|
# Left hand
|
||||||
|
left = frame_df.filter(pl.col("type") == "left_hand")
|
||||||
|
if left.height > 0:
|
||||||
|
valid_count = 0
|
||||||
|
for i in range(21):
|
||||||
|
row = left.filter(pl.col("landmark_index") == i)
|
||||||
|
if row.height > 0:
|
||||||
|
coords = row.select(["x", "y", "z"]).row(0)
|
||||||
|
if all(c is not None for c in coords):
|
||||||
|
frame_points[pos] = coords
|
||||||
|
valid_count += 1
|
||||||
|
pos += 1
|
||||||
|
left_valid = (valid_count >= 10)
|
||||||
|
else:
|
||||||
|
pos += 21
|
||||||
|
|
||||||
|
# Right hand
|
||||||
|
right = frame_df.filter(pl.col("type") == "right_hand")
|
||||||
|
if right.height > 0:
|
||||||
|
valid_count = 0
|
||||||
|
for i in range(21):
|
||||||
|
row = right.filter(pl.col("landmark_index") == i)
|
||||||
|
if row.height > 0:
|
||||||
|
coords = row.select(["x", "y", "z"]).row(0)
|
||||||
|
if all(c is not None for c in coords):
|
||||||
|
frame_points[pos] = coords
|
||||||
|
valid_count += 1
|
||||||
|
pos += 1
|
||||||
|
right_valid = (valid_count >= 10)
|
||||||
|
else:
|
||||||
|
pos += 21
|
||||||
|
|
||||||
|
# Face
|
||||||
|
face = frame_df.filter(pl.col("type") == "face")
|
||||||
|
if face.height > 0:
|
||||||
|
valid_count = 0
|
||||||
|
for idx in IMPORTANT_FACE_INDICES:
|
||||||
|
row = face.filter(pl.col("landmark_index") == idx)
|
||||||
|
if row.height > 0:
|
||||||
|
coords = row.select(["x", "y", "z"]).row(0)
|
||||||
|
if all(c is not None for c in coords):
|
||||||
|
frame_points[pos] = coords
|
||||||
|
valid_count += 1
|
||||||
|
pos += 1
|
||||||
|
face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.3)
|
||||||
|
|
||||||
|
valid_ratio = 1 - np.isnan(frame_points).mean()
|
||||||
|
if valid_ratio >= 0.20:
|
||||||
|
frame_points = np.nan_to_num(frame_points, nan=0.0)
|
||||||
|
seq.append(frame_points)
|
||||||
|
left_valid_frames.append(left_valid)
|
||||||
|
right_valid_frames.append(right_valid)
|
||||||
|
face_valid_frames.append(face_valid)
|
||||||
|
|
||||||
|
if len(seq) < min_valid_frames:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
'landmarks': np.stack(seq),
|
||||||
|
'left_hand_valid': np.array(left_valid_frames),
|
||||||
|
'right_hand_valid': np.array(right_valid_frames),
|
||||||
|
'face_valid': np.array(face_valid_frames)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# --- DATASET CLASS ---
|
def get_features_sequence(landmarks_data, max_frames=100):
|
||||||
class ASLDataset(Dataset):
|
"""Enhanced feature extraction with separate modality processing"""
|
||||||
def __init__(self, tensors, labels):
|
if landmarks_data is None:
|
||||||
self.tensors = tensors
|
return None, None, None
|
||||||
self.labels = labels
|
|
||||||
|
|
||||||
def __len__(self):
|
landmarks_3d = landmarks_data['landmarks']
|
||||||
return len(self.tensors)
|
if len(landmarks_3d) == 0:
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
T, N, _ = landmarks_3d.shape
|
||||||
return self.tensors[idx], self.labels[idx]
|
|
||||||
|
# Separate modalities for independent normalization
|
||||||
|
left_hand = landmarks_3d[:, :21, :]
|
||||||
|
right_hand = landmarks_3d[:, 21:42, :]
|
||||||
|
face = landmarks_3d[:, 42:, :]
|
||||||
|
|
||||||
|
features_list = []
|
||||||
|
|
||||||
|
for modality, valid_mask in [
|
||||||
|
(left_hand, landmarks_data['left_hand_valid']),
|
||||||
|
(right_hand, landmarks_data['right_hand_valid']),
|
||||||
|
(face, landmarks_data['face_valid'])
|
||||||
|
]:
|
||||||
|
valid_frames = modality[valid_mask] if valid_mask.any() else modality
|
||||||
|
if len(valid_frames) > 0:
|
||||||
|
center = np.mean(valid_frames, axis=(0, 1), keepdims=True)
|
||||||
|
spread = np.std(valid_frames, axis=(0, 1), keepdims=True).max()
|
||||||
|
else:
|
||||||
|
center = 0
|
||||||
|
spread = 1
|
||||||
|
|
||||||
|
modality_norm = (modality - center) / max(spread, 1e-6)
|
||||||
|
flat = modality_norm.reshape(T, -1)
|
||||||
|
|
||||||
|
# Deltas
|
||||||
|
deltas = np.zeros_like(flat)
|
||||||
|
if T > 1:
|
||||||
|
deltas[1:] = flat[1:] - flat[:-1]
|
||||||
|
|
||||||
|
features_list.append(flat)
|
||||||
|
features_list.append(deltas)
|
||||||
|
|
||||||
|
features = np.concatenate(features_list, axis=1)
|
||||||
|
|
||||||
|
modality_mask = np.stack([
|
||||||
|
landmarks_data['left_hand_valid'],
|
||||||
|
landmarks_data['right_hand_valid'],
|
||||||
|
landmarks_data['face_valid']
|
||||||
|
], axis=1).astype(np.float32)
|
||||||
|
|
||||||
|
# Pad/truncate
|
||||||
|
if T < max_frames:
|
||||||
|
pad = np.zeros((max_frames - T, features.shape[1]), dtype=np.float32)
|
||||||
|
features = np.concatenate([features, pad], axis=0)
|
||||||
|
|
||||||
|
mask_pad = np.zeros((max_frames - T, 3), dtype=np.float32)
|
||||||
|
modality_mask = np.concatenate([modality_mask, mask_pad], axis=0)
|
||||||
|
|
||||||
|
frame_mask = np.zeros(max_frames, dtype=bool)
|
||||||
|
frame_mask[:T] = True
|
||||||
|
else:
|
||||||
|
features = features[:max_frames]
|
||||||
|
modality_mask = modality_mask[:max_frames]
|
||||||
|
frame_mask = np.ones(max_frames, dtype=bool)
|
||||||
|
|
||||||
|
features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
features = np.clip(features, -30, 30)
|
||||||
|
|
||||||
|
return features.astype(np.float32), frame_mask, modality_mask
|
||||||
|
|
||||||
|
|
||||||
# --- MODEL ARCHITECTURE ---
|
def process_row(row_data, base_path, max_frames=100):
|
||||||
class ASLClassifier(nn.Module):
|
"""Process a single row"""
|
||||||
def __init__(self, num_classes, target_frames=TARGET_FRAMES, num_feats=NUM_FEATS):
|
path_rel, sign = row_data
|
||||||
|
path = os.path.join(base_path, path_rel)
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
lm_data = extract_multi_landmarks(path)
|
||||||
|
if lm_data is None:
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
feat, frame_mask, modality_mask = get_features_sequence(lm_data, max_frames)
|
||||||
|
if feat is None:
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
return feat, frame_mask, modality_mask, sign
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
# ===============================
|
||||||
|
# MIXUP AUGMENTATION
|
||||||
|
# ===============================
|
||||||
|
def mixup_data(x, frame_mask, modality_mask, y, alpha=0.2):
|
||||||
|
"""Mixup augmentation"""
|
||||||
|
if alpha > 0:
|
||||||
|
lam = np.random.beta(alpha, alpha)
|
||||||
|
else:
|
||||||
|
lam = 1
|
||||||
|
|
||||||
|
batch_size = x.size(0)
|
||||||
|
index = torch.randperm(batch_size).to(x.device)
|
||||||
|
|
||||||
|
mixed_x = lam * x + (1 - lam) * x[index]
|
||||||
|
mixed_frame_mask = frame_mask | frame_mask[index] # Union of valid frames
|
||||||
|
mixed_modality_mask = torch.max(modality_mask, modality_mask[index])
|
||||||
|
|
||||||
|
y_a, y_b = y, y[index]
|
||||||
|
return mixed_x, mixed_frame_mask, mixed_modality_mask, y_a, y_b, lam
|
||||||
|
|
||||||
|
|
||||||
|
# ===============================
|
||||||
|
# ENHANCED MODEL WITH ATTENTION POOLING
|
||||||
|
# ===============================
|
||||||
|
class PositionalEncoding(nn.Module):
|
||||||
|
def __init__(self, d_model, max_len=128):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv1d(num_feats, 256, kernel_size=3, padding=1)
|
pe = torch.zeros(max_len, d_model)
|
||||||
self.bn1 = nn.BatchNorm1d(256)
|
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||||
self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
|
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||||
self.bn2 = nn.BatchNorm1d(512)
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
self.pool = nn.MaxPool1d(2)
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
self.dropout = nn.Dropout(0.5)
|
self.register_buffer('pe', pe.unsqueeze(0))
|
||||||
self.fc = nn.Linear(512, num_classes)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x.view(x.shape[0], x.shape[1], -1)
|
return x + self.pe[:, :x.size(1)]
|
||||||
x = x.transpose(1, 2)
|
|
||||||
x = F.relu(self.bn1(self.conv1(x)))
|
|
||||||
x = self.pool(x)
|
|
||||||
x = F.relu(self.bn2(self.conv2(x)))
|
|
||||||
x = self.pool(x)
|
|
||||||
x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
|
|
||||||
x = self.dropout(x)
|
|
||||||
return self.fc(x)
|
|
||||||
|
|
||||||
|
|
||||||
# --- TRAINING FUNCTIONS ---
|
class ModalityAwareTransformer(nn.Module):
|
||||||
def train_epoch(model, dataloader, criterion, optimizer, device):
|
def __init__(self, input_dim, num_classes, d_model=512, nhead=8, num_layers=6, dropout=0.15):
|
||||||
model.train()
|
super().__init__()
|
||||||
running_loss = 0.0
|
|
||||||
correct = 0
|
|
||||||
total = 0
|
|
||||||
|
|
||||||
progress_bar = tqdm(dataloader, desc="Training")
|
# Main projection
|
||||||
for inputs, labels in progress_bar:
|
self.proj = nn.Linear(input_dim, d_model)
|
||||||
inputs, labels = inputs.to(device), labels.to(device)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
# Modality embedding (3 modalities: left_hand, right_hand, face)
|
||||||
outputs = model(inputs)
|
self.modality_embed = nn.Linear(3, d_model)
|
||||||
loss = criterion(outputs, labels)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
running_loss += loss.item()
|
self.norm_in = nn.LayerNorm(d_model)
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
self.pos = PositionalEncoding(d_model)
|
||||||
total += labels.size(0)
|
|
||||||
correct += (predicted == labels).sum().item()
|
|
||||||
|
|
||||||
progress_bar.set_postfix({
|
enc_layer = nn.TransformerEncoderLayer(
|
||||||
'loss': running_loss / (progress_bar.n + 1),
|
d_model=d_model,
|
||||||
'acc': 100 * correct / total
|
nhead=nhead,
|
||||||
})
|
dim_feedforward=d_model * 4,
|
||||||
|
dropout=dropout,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
|
||||||
|
|
||||||
epoch_loss = running_loss / len(dataloader)
|
# Attention pooling
|
||||||
epoch_acc = 100 * correct / total
|
self.attention_pool = nn.Linear(d_model, 1)
|
||||||
return epoch_loss, epoch_acc
|
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.LayerNorm(d_model),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(d_model, d_model // 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(d_model // 2, num_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, modality_mask=None, key_padding_mask=None):
|
||||||
|
# Project features
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
# Add modality information
|
||||||
|
if modality_mask is not None:
|
||||||
|
mod_embed = self.modality_embed(modality_mask)
|
||||||
|
x = x + mod_embed
|
||||||
|
|
||||||
|
x = self.norm_in(x)
|
||||||
|
x = self.pos(x)
|
||||||
|
x = self.encoder(x, src_key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
# Attention-based pooling
|
||||||
|
attn_weights = self.attention_pool(x) # (B, T, 1)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(-1), -1e9)
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=1)
|
||||||
|
x = (x * attn_weights).sum(dim=1)
|
||||||
|
|
||||||
|
return self.head(x)
|
||||||
|
|
||||||
|
|
||||||
def validate(model, dataloader, criterion, device):
|
def load_kaggle_asl_data(base_path):
|
||||||
model.eval()
|
"""Load training metadata using Polars"""
|
||||||
running_loss = 0.0
|
train_path = os.path.join(base_path, "train.csv")
|
||||||
correct = 0
|
train_df = pl.read_csv(train_path)
|
||||||
total = 0
|
return train_df, None
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for inputs, labels in tqdm(dataloader, desc="Validation"):
|
|
||||||
inputs, labels = inputs.to(device), labels.to(device)
|
|
||||||
outputs = model(inputs)
|
|
||||||
loss = criterion(outputs, labels)
|
|
||||||
|
|
||||||
running_loss += loss.item()
|
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
|
||||||
total += labels.size(0)
|
|
||||||
correct += (predicted == labels).sum().item()
|
|
||||||
|
|
||||||
val_loss = running_loss / len(dataloader)
|
|
||||||
val_acc = 100 * correct / total
|
|
||||||
return val_loss, val_acc
|
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, val_acc, checkpoint_dir):
|
# ===============================
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
# DATASET WITH AUGMENTATION
|
||||||
checkpoint = {
|
# ===============================
|
||||||
'epoch': epoch,
|
class ASLMultiDataset(Dataset):
|
||||||
'model_state_dict': model.state_dict(),
|
def __init__(self, X, frame_masks, modality_masks, y, training=False, max_frames=100):
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
self.X = X
|
||||||
'train_loss': train_loss,
|
self.frame_masks = frame_masks
|
||||||
'val_loss': val_loss,
|
self.modality_masks = modality_masks
|
||||||
'val_acc': val_acc,
|
self.y = y
|
||||||
}
|
self.training = training
|
||||||
path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
|
self.max_frames = max_frames
|
||||||
torch.save(checkpoint, path)
|
|
||||||
print(f"Checkpoint saved: {path}")
|
def __len__(self):
|
||||||
|
return len(self.X)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
x = self.X[idx].copy()
|
||||||
|
frame_mask = self.frame_masks[idx].copy()
|
||||||
|
modality_mask = self.modality_masks[idx].copy()
|
||||||
|
y = self.y[idx]
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# Apply augmentation
|
||||||
|
x, modality_mask = augment_sequence(x, modality_mask)
|
||||||
|
|
||||||
|
# Re-pad if needed after augmentation
|
||||||
|
if len(x) < self.max_frames:
|
||||||
|
pad = np.zeros((self.max_frames - len(x), x.shape[1]), dtype=np.float32)
|
||||||
|
x = np.concatenate([x, pad], axis=0)
|
||||||
|
|
||||||
|
mask_pad = np.zeros((self.max_frames - len(x), 3), dtype=np.float32)
|
||||||
|
modality_mask = np.concatenate([modality_mask, mask_pad], axis=0)
|
||||||
|
|
||||||
|
frame_mask = np.zeros(self.max_frames, dtype=bool)
|
||||||
|
frame_mask[:len(x)] = True
|
||||||
|
else:
|
||||||
|
x = x[:self.max_frames]
|
||||||
|
modality_mask = modality_mask[:self.max_frames]
|
||||||
|
frame_mask = np.ones(self.max_frames, dtype=bool)
|
||||||
|
|
||||||
|
return (
|
||||||
|
torch.from_numpy(x).float(),
|
||||||
|
torch.from_numpy(frame_mask).bool(),
|
||||||
|
torch.from_numpy(modality_mask).float(),
|
||||||
|
torch.tensor(y, dtype=torch.long)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===============================
|
||||||
|
# TRAINING SINGLE MODEL
|
||||||
|
# ===============================
|
||||||
|
def train_model(X_tr, fm_tr, mm_tr, y_tr, X_te, fm_te, mm_te, y_te,
|
||||||
|
num_classes, input_dim, model_idx=0, epochs=80):
|
||||||
|
"""Train a single model"""
|
||||||
|
|
||||||
|
# Set different seed for each model
|
||||||
|
torch.manual_seed(42 + model_idx)
|
||||||
|
np.random.seed(42 + model_idx)
|
||||||
|
|
||||||
|
batch_size = 64 if device.type == 'cuda' else 32
|
||||||
|
|
||||||
|
train_dataset = ASLMultiDataset(X_tr, fm_tr, mm_tr, y_tr, training=True, max_frames=100)
|
||||||
|
test_dataset = ASLMultiDataset(X_te, fm_te, mm_te, y_te, training=False, max_frames=100)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=batch_size, shuffle=True,
|
||||||
|
num_workers=4, pin_memory=device.type == 'cuda'
|
||||||
|
)
|
||||||
|
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=batch_size * 2, shuffle=False,
|
||||||
|
num_workers=4, pin_memory=device.type == 'cuda'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enhanced model
|
||||||
|
model = ModalityAwareTransformer(
|
||||||
|
input_dim=input_dim,
|
||||||
|
num_classes=num_classes,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
num_layers=6,
|
||||||
|
dropout=0.15
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(f"\n[Model {model_idx + 1}] Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
|
||||||
|
|
||||||
|
# OneCycleLR scheduler
|
||||||
|
scheduler = optim.lr_scheduler.OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=1e-3,
|
||||||
|
steps_per_epoch=len(train_loader),
|
||||||
|
epochs=epochs,
|
||||||
|
pct_start=0.1,
|
||||||
|
anneal_strategy='cos'
|
||||||
|
)
|
||||||
|
|
||||||
|
best_acc = 0.0
|
||||||
|
save_path = f"best_asl_model_{model_idx}.pth"
|
||||||
|
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(f"TRAINING MODEL {model_idx + 1}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
model.train()
|
||||||
|
total_loss = correct = total = 0
|
||||||
|
|
||||||
|
for x, frame_mask, modality_mask, yb in tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False):
|
||||||
|
x = x.to(device)
|
||||||
|
frame_mask = frame_mask.to(device)
|
||||||
|
modality_mask = modality_mask.to(device)
|
||||||
|
yb = yb.to(device)
|
||||||
|
|
||||||
|
# Apply mixup
|
||||||
|
if np.random.rand() < 0.5:
|
||||||
|
x, frame_mask, modality_mask, y_a, y_b, lam = mixup_data(
|
||||||
|
x, frame_mask, modality_mask, yb, alpha=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
key_padding_mask = ~frame_mask
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
|
||||||
|
loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
|
||||||
|
|
||||||
|
# Use original labels for accuracy
|
||||||
|
correct += (logits.argmax(-1) == yb).sum().item()
|
||||||
|
else:
|
||||||
|
key_padding_mask = ~frame_mask
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
|
||||||
|
loss = criterion(logits, yb)
|
||||||
|
correct += (logits.argmax(-1) == yb).sum().item()
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
total += yb.size(0)
|
||||||
|
|
||||||
|
train_acc = correct / total * 100
|
||||||
|
|
||||||
|
# Eval
|
||||||
|
model.eval()
|
||||||
|
correct = total = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for x, frame_mask, modality_mask, yb in test_loader:
|
||||||
|
x = x.to(device)
|
||||||
|
frame_mask = frame_mask.to(device)
|
||||||
|
modality_mask = modality_mask.to(device)
|
||||||
|
yb = yb.to(device)
|
||||||
|
|
||||||
|
key_padding_mask = ~frame_mask
|
||||||
|
logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
|
||||||
|
correct += (logits.argmax(-1) == yb).sum().item()
|
||||||
|
total += yb.size(0)
|
||||||
|
|
||||||
|
test_acc = correct / total * 100
|
||||||
|
|
||||||
|
print(f"[{epoch + 1:2d}/{epochs}] Loss: {total_loss / len(train_loader):.4f} | "
|
||||||
|
f"Train: {train_acc:.2f}% | Test: {test_acc:.2f}%", end="")
|
||||||
|
|
||||||
|
if test_acc > best_acc:
|
||||||
|
best_acc = test_acc
|
||||||
|
torch.save(model.state_dict(), save_path)
|
||||||
|
print(" → saved")
|
||||||
|
else:
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"\nModel {model_idx + 1} - Best test accuracy: {best_acc:.2f}%")
|
||||||
|
return save_path, best_acc
|
||||||
|
|
||||||
|
|
||||||
|
# ===============================
|
||||||
|
# ENSEMBLE PREDICTION
|
||||||
|
# ===============================
|
||||||
|
def ensemble_predict(model_paths, test_loader, num_classes, input_dim):
|
||||||
|
"""Make predictions using ensemble of models"""
|
||||||
|
all_preds = []
|
||||||
|
|
||||||
|
for model_path in model_paths:
|
||||||
|
model = ModalityAwareTransformer(
|
||||||
|
input_dim=input_dim,
|
||||||
|
num_classes=num_classes,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
num_layers=6
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
model.load_state_dict(torch.load(model_path))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
preds = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for x, frame_mask, modality_mask, _ in test_loader:
|
||||||
|
x = x.to(device)
|
||||||
|
frame_mask = frame_mask.to(device)
|
||||||
|
modality_mask = modality_mask.to(device)
|
||||||
|
|
||||||
|
key_padding_mask = ~frame_mask
|
||||||
|
logits = model(x, modality_mask=modality_mask, key_padding_mask=key_padding_mask)
|
||||||
|
preds.append(F.softmax(logits, dim=-1))
|
||||||
|
|
||||||
|
all_preds.append(torch.cat(preds, dim=0))
|
||||||
|
|
||||||
|
# Average predictions
|
||||||
|
ensemble_pred = torch.stack(all_preds).mean(0)
|
||||||
|
return ensemble_pred.argmax(-1).cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
# ===============================
|
||||||
|
# MAIN
|
||||||
|
# ===============================
|
||||||
|
def main():
|
||||||
|
base_path = "asl_kaggle"
|
||||||
|
max_frames = 100
|
||||||
|
MIN_SAMPLES_PER_CLASS = 3 # Relaxed from 5
|
||||||
|
NUM_ENSEMBLE_MODELS = 3
|
||||||
|
EPOCHS = 80
|
||||||
|
|
||||||
|
print("\nLoading metadata...")
|
||||||
|
train_df, _ = load_kaggle_asl_data(base_path)
|
||||||
|
print(f"Total samples in train.csv: {train_df.height}")
|
||||||
|
|
||||||
|
rows = [(row[0], row[1]) for row in train_df.select(["path", "sign"]).iter_rows()]
|
||||||
|
|
||||||
|
print("\nProcessing sequences with BOTH hands + FACE (enhanced)...")
|
||||||
|
print("This may take a few minutes...")
|
||||||
|
|
||||||
|
with Pool(cpu_count()) as pool:
|
||||||
|
results = list(tqdm(
|
||||||
|
pool.imap(
|
||||||
|
partial(process_row, base_path=base_path, max_frames=max_frames),
|
||||||
|
rows,
|
||||||
|
chunksize=80
|
||||||
|
),
|
||||||
|
total=len(rows),
|
||||||
|
desc="Landmarks extraction"
|
||||||
|
))
|
||||||
|
|
||||||
|
X_list, frame_masks_list, modality_masks_list, y_list = [], [], [], []
|
||||||
|
failed_count = 0
|
||||||
|
for feat, frame_mask, modality_mask, sign in results:
|
||||||
|
if feat is not None and frame_mask is not None:
|
||||||
|
X_list.append(feat)
|
||||||
|
frame_masks_list.append(frame_mask)
|
||||||
|
modality_masks_list.append(modality_mask)
|
||||||
|
y_list.append(sign)
|
||||||
|
else:
|
||||||
|
failed_count += 1
|
||||||
|
|
||||||
|
if not X_list:
|
||||||
|
print(f"\n❌ No valid sequences extracted!")
|
||||||
|
print(f"Failed to process: {failed_count}/{len(results)} files")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\n✓ Successfully processed: {len(X_list)}/{len(results)} files")
|
||||||
|
print(f"✗ Failed: {failed_count}/{len(results)} files")
|
||||||
|
|
||||||
|
X = np.stack(X_list)
|
||||||
|
frame_masks = np.stack(frame_masks_list)
|
||||||
|
modality_masks = np.stack(modality_masks_list)
|
||||||
|
|
||||||
|
print(f"\nExtracted {len(X):,} sequences")
|
||||||
|
print(f"Feature shape: {X.shape[1:]} (input_dim = {X.shape[2]})")
|
||||||
|
|
||||||
|
# Global normalization
|
||||||
|
X = np.clip(X, -30, 30)
|
||||||
|
mean = X.mean(axis=(0, 1), keepdims=True)
|
||||||
|
std = X.std(axis=(0, 1), keepdims=True) + 1e-8
|
||||||
|
X = (X - mean) / std
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
le = LabelEncoder()
|
||||||
|
y = le.fit_transform(y_list)
|
||||||
|
|
||||||
|
# Filter rare classes
|
||||||
|
counts = Counter(y)
|
||||||
|
valid = [k for k, v in counts.items() if v >= MIN_SAMPLES_PER_CLASS]
|
||||||
|
mask = np.isin(y, valid)
|
||||||
|
|
||||||
|
X = X[mask]
|
||||||
|
frame_masks = frame_masks[mask]
|
||||||
|
modality_masks = modality_masks[mask]
|
||||||
|
y = y[mask]
|
||||||
|
|
||||||
|
le = LabelEncoder()
|
||||||
|
y = le.fit_transform(y)
|
||||||
|
|
||||||
|
print(f"After filtering: {len(X):,} samples | {len(le.classes_)} classes")
|
||||||
|
|
||||||
|
# Split
|
||||||
|
X_tr, X_te, fm_tr, fm_te, mm_tr, mm_te, y_tr, y_te = train_test_split(
|
||||||
|
X, frame_masks, modality_masks, y, test_size=0.15, stratify=y, random_state=42
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train ensemble of models
|
||||||
|
model_paths = []
|
||||||
|
best_accs = []
|
||||||
|
|
||||||
|
for i in range(NUM_ENSEMBLE_MODELS):
|
||||||
|
model_path, best_acc = train_model(
|
||||||
|
X_tr, fm_tr, mm_tr, y_tr,
|
||||||
|
X_te, fm_te, mm_te, y_te,
|
||||||
|
num_classes=len(le.classes_),
|
||||||
|
input_dim=X.shape[2],
|
||||||
|
model_idx=i,
|
||||||
|
epochs=EPOCHS
|
||||||
|
)
|
||||||
|
model_paths.append(model_path)
|
||||||
|
best_accs.append(best_acc)
|
||||||
|
|
||||||
|
# Ensemble evaluation
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ENSEMBLE EVALUATION")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
test_dataset = ASLMultiDataset(X_te, fm_te, mm_te, y_te, training=False)
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=128,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=4,
|
||||||
|
pin_memory=device.type == 'cuda'
|
||||||
|
)
|
||||||
|
|
||||||
|
ensemble_preds = ensemble_predict(model_paths, test_loader, len(le.classes_), X.shape[2])
|
||||||
|
ensemble_acc = (ensemble_preds == y_te).mean() * 100
|
||||||
|
|
||||||
|
print(f"\nIndividual model accuracies:")
|
||||||
|
for i, acc in enumerate(best_accs):
|
||||||
|
print(f" Model {i + 1}: {acc:.2f}%")
|
||||||
|
|
||||||
|
print(f"\n🎯 Ensemble accuracy: {ensemble_acc:.2f}%")
|
||||||
|
print(f" Improvement: +{ensemble_acc - max(best_accs):.2f}% over best single model")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(f"TRAINING COMPLETE")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
# --- EXECUTION ---
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Load metadata
|
main()
|
||||||
asl_data = load_kaggle_metadata(BASE_PATH)
|
|
||||||
|
|
||||||
# Create label mapping
|
|
||||||
unique_signs = sorted(asl_data["sign"].unique().to_list())
|
|
||||||
label_to_idx = {sign: idx for idx, sign in enumerate(unique_signs)}
|
|
||||||
labels = torch.tensor([label_to_idx[sign] for sign in asl_data["sign"].to_list()])
|
|
||||||
|
|
||||||
print(f"Number of classes: {len(unique_signs)}")
|
|
||||||
|
|
||||||
# Process data in parallel
|
|
||||||
paths = asl_data["path"].to_list()
|
|
||||||
print(f"Processing {len(paths)} files in parallel...")
|
|
||||||
|
|
||||||
with ProcessPoolExecutor() as executor:
|
|
||||||
results = list(tqdm(executor.map(load_and_preprocess, paths), total=len(paths)))
|
|
||||||
|
|
||||||
dataset_tensor = torch.tensor(np.array(results), dtype=torch.float32)
|
|
||||||
print(f"Final Tensor Shape: {dataset_tensor.shape}")
|
|
||||||
|
|
||||||
# Create dataset and split
|
|
||||||
full_dataset = ASLDataset(dataset_tensor, labels)
|
|
||||||
train_size = int(TRAIN_SPLIT * len(full_dataset))
|
|
||||||
val_size = len(full_dataset) - train_size
|
|
||||||
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
|
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
|
|
||||||
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
|
|
||||||
|
|
||||||
print(f"Train samples: {train_size}, Validation samples: {val_size}")
|
|
||||||
|
|
||||||
# Initialize model, loss, optimizer
|
|
||||||
model = ASLClassifier(num_classes=len(unique_signs)).to(device)
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
best_val_acc = 0.0
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("Starting Training")
|
|
||||||
print("=" * 50 + "\n")
|
|
||||||
|
|
||||||
for epoch in range(EPOCHS):
|
|
||||||
print(f"\nEpoch [{epoch + 1}/{EPOCHS}]")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
|
|
||||||
val_loss, val_acc = validate(model, val_loader, criterion, device)
|
|
||||||
|
|
||||||
scheduler.step(val_loss)
|
|
||||||
|
|
||||||
print(f"\nEpoch {epoch + 1} Summary:")
|
|
||||||
print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
|
|
||||||
print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
|
|
||||||
print(f" Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
|
|
||||||
|
|
||||||
# Save checkpoint if validation accuracy improved
|
|
||||||
if val_acc > best_val_acc:
|
|
||||||
best_val_acc = val_acc
|
|
||||||
save_checkpoint(model, optimizer, epoch + 1, train_loss, val_loss, val_acc, CHECKPOINT_DIR)
|
|
||||||
print(f" ✓ New best validation accuracy: {best_val_acc:.2f}%")
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("Training Complete!")
|
|
||||||
print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
|
|
||||||
print("=" * 50)
|
|
||||||
Reference in New Issue
Block a user