"""ASL sign classification from MediaPipe holistic landmarks.

Builds 5-channel landmark features — nose-relative (x, y, z) plus
wrist-relative (x, y) for the hands — resamples every sequence to a fixed
number of frames, and trains a 1D-CNN classifier over the time axis.

NOTE(review): this file arrived with all newlines stripped; it has been
re-formatted.  Logic is unchanged except for the wrist-relative feature
fix documented in ``load_file_to_memory``.
"""
import os
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# --- CONFIG ---
BASE_PATH = "asl_kaggle"
TARGET_FRAMES = 22    # every sequence is resampled to this many frames
NUM_LANDMARKS = 543   # rows per frame in the parquet files (see reshape below)
NUM_FEATURES = 5      # (x_g, y_g, z_g, x_l, y_l)

# Landmark-index layout.  The original code already relied on 468 / 522
# being the left / right wrists; the hand ranges below follow the same
# convention (face 0-467, left hand 468-488, pose 489-521,
# right hand 522-542) — TODO(review): confirm against the dataset docs.
LEFT_WRIST, LEFT_HAND_END = 468, 488
RIGHT_WRIST, RIGHT_HAND_END = 522, 542

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- DATA LOADING WITH RELATIVE FEATURES ---
def load_file_to_memory(path, base_path=BASE_PATH):
    """Load one landmark parquet file into a (TARGET_FRAMES, 543, 5) array.

    Channels per landmark:
      x_g, y_g, z_g -- coordinates relative to the nose (global anchor)
      x_l, y_l      -- for hand landmarks, coordinates relative to that
                       hand's wrist; for all other landmarks (or frames
                       where the wrist is missing) falls back to the
                       nose-relative value.

    Returns an all-zeros array of the target shape on any failure — a
    deliberate best-effort policy so one bad file cannot kill the
    parallel loading pool.
    """
    try:
        df = pl.read_parquet(os.path.join(base_path, path))

        # 1. Global anchor: the nose (face landmark 0), one row per frame.
        anchors = (
            df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
            .select([
                pl.col("frame"),
                pl.col("x").alias("nx"),
                pl.col("y").alias("ny"),
                pl.col("z").alias("nz"),
            ])
        )

        # 2. Local anchors: one wrist row per frame per hand.
        #
        # BUGFIX: the original joined the wrist rows back on
        # ["frame", "landmark_index"], which attached wrist coordinates
        # only to the wrist rows themselves — so x_l/y_l came out as 0
        # for the wrists and nose-relative for every other hand point.
        # Broadcasting each wrist per frame and subtracting it from all
        # landmarks of the matching hand realizes the stated intent.
        left_wrist = (
            df.filter(pl.col("landmark_index") == LEFT_WRIST)
            .select([
                pl.col("frame"),
                pl.col("x").alias("lwx"),
                pl.col("y").alias("lwy"),
            ])
        )
        right_wrist = (
            df.filter(pl.col("landmark_index") == RIGHT_WRIST)
            .select([
                pl.col("frame"),
                pl.col("x").alias("rwx"),
                pl.col("y").alias("rwy"),
            ])
        )

        is_left = pl.col("landmark_index").is_between(LEFT_WRIST, LEFT_HAND_END)
        is_right = pl.col("landmark_index").is_between(RIGHT_WRIST, RIGHT_HAND_END)

        processed = (
            df.join(anchors, on="frame", how="left")
            .join(left_wrist, on="frame", how="left")
            .join(right_wrist, on="frame", how="left")
            .with_columns([
                # Global (nose-relative) channels.
                (pl.col("x") - pl.col("nx")).alias("x_g"),
                (pl.col("y") - pl.col("ny")).alias("y_g"),
                (pl.col("z") - pl.col("nz")).alias("z_g"),
                # Local (wrist-relative) channels.  fill_null covers hand
                # rows whose wrist is null in that frame; NOTE(review):
                # missing hands stored as NaN (not null) would pass
                # through — confirm the file encoding.
                pl.when(is_left).then(pl.col("x") - pl.col("lwx"))
                .when(is_right).then(pl.col("x") - pl.col("rwx"))
                .otherwise(pl.col("x") - pl.col("nx"))
                .fill_null(pl.col("x") - pl.col("nx"))
                .alias("x_l"),
                pl.when(is_left).then(pl.col("y") - pl.col("lwy"))
                .when(is_right).then(pl.col("y") - pl.col("rwy"))
                .otherwise(pl.col("y") - pl.col("ny"))
                .fill_null(pl.col("y") - pl.col("ny"))
                .alias("y_l"),
            ])
            # Deterministic row order so the reshape below is stable.
            .sort(["frame", "type", "landmark_index"])
        )

        # Reshape to (frames, 543 landmarks, 5 features).  Assumes every
        # frame carries exactly NUM_LANDMARKS rows; if not, reshape raises
        # and the zeros fallback below applies.
        n_frames = processed["frame"].n_unique()
        tensor = (
            processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"])
            .to_numpy()
            .reshape(n_frames, NUM_LANDMARKS, NUM_FEATURES)
        )

        # Temporal resampling: TARGET_FRAMES evenly spaced frame indices.
        indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int)
        return tensor[indices]
    except Exception:
        # Best-effort: swallow per-file failures, return a zero sample.
        return np.zeros((TARGET_FRAMES, NUM_LANDMARKS, NUM_FEATURES))


# --- DUAL-STREAM MODEL ---
class ASLClassifier(nn.Module):
    """1D-CNN over time: each frame's 543x5 landmark grid is flattened into
    one 2715-dim vector, convolved along the 22-frame axis, globally
    average-pooled, and classified by a small MLP head."""

    def __init__(self, num_classes):
        super().__init__()
        # 543 landmarks * 5 features per landmark = 2715 input channels.
        self.feat_dim = NUM_LANDMARKS * NUM_FEATURES
        self.conv1 = nn.Conv1d(self.feat_dim, 512, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(512)
        self.conv2 = nn.Conv1d(512, 512, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(512)
        self.pool = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        # x shape: (Batch, 22, 543, 5)
        b, t, l, f = x.shape
        # Flatten landmarks+features, transpose to (Batch, 2715, 22) for Conv1d.
        x = x.view(b, t, -1).transpose(1, 2)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        # Global average pool across the (pooled) time dimension.
        x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        return self.fc(self.dropout(x))


# --- EXECUTION ---
if __name__ == "__main__":
    # 1. Setup data: map each sign label to a contiguous class index.
    metadata = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
    unique_signs = sorted(metadata["sign"].unique().to_list())
    sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)}
    labels = [sign_to_idx[s] for s in metadata["sign"].to_list()]
    paths = metadata["path"].to_list()

    # 2. Load everything to RAM, one process per worker.
    print(f"Loading {len(paths)} files into RAM with 5-channel features...")
    with ProcessPoolExecutor() as executor:
        data_list = list(tqdm(executor.map(load_file_to_memory, paths), total=len(paths)))

    X = torch.tensor(np.array(data_list), dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.long)

    # 3. Split (stratified so val keeps the class distribution).
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.1, stratify=y, random_state=42
    )
    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=64)

    # 4. Train.
    model = ASLClassifier(len(unique_signs)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Helps prevent over-confidence

    print(f"Starting training on {device}...")
    for epoch in range(25):
        model.train()
        train_loss = 0
        for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation: accuracy over the held-out split.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for vx, vy in val_loader:
                vx, vy = vx.to(device), vy.to(device)
                pred = model(vx).argmax(1)
                correct += (pred == vy).sum().item()
                total += vy.size(0)

        print(
            f"Epoch {epoch + 1} | Loss: {train_loss / len(train_loader):.4f} | "
            f"Val Acc: {100 * correct / total:.2f}%"
        )

        # Periodic checkpointing every 5 epochs.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"asl_model_v2_e{epoch + 1}.pth")