import os

import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# --- CONFIGURATION ---
BASE_PATH = "asl_kaggle"
TARGET_FRAMES = 22
LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
        308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
# Indices 468-542 cover the left-hand, pose, and right-hand blocks that follow
# the 468 face points once rows are sorted by landmark type.
HANDS = list(range(468, 543))
SELECTED_INDICES = LIPS + HANDS
NUM_FEATS = len(SELECTED_INDICES) * 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- DATASET ENGINE ---
def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES):
    parquet_path = os.path.join(base_path, path)
    df = pl.read_parquet(parquet_path)

    # 1. Spatial Normalization: recentre every frame on face landmark 0
    anchors = (
        df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
        .select([
            pl.col("frame"),
            pl.col("x").alias("nx"),
            pl.col("y").alias("ny"),
            pl.col("z").alias("nz"),
        ])
    )
    processed = (
        df.join(anchors, on="frame", how="left")
        .with_columns([
            (pl.col("x") - pl.col("nx")).fill_null(0.0),
            (pl.col("y") - pl.col("ny")).fill_null(0.0),
            (pl.col("z") - pl.col("nz")).fill_null(0.0),
        ])
        .sort(["frame", "type", "landmark_index"])
    )

    # 2. Slice and Reshape
    raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3)
    # Missing landmarks (e.g. an undetected hand) are stored as NaN, which
    # fill_null above does not touch; zero them so they cannot poison training.
    raw_tensor = np.nan_to_num(raw_tensor, nan=0.0)
    reduced_tensor = raw_tensor[:, SELECTED_INDICES, :]

    # 3. Temporal Resampling: pick a fixed number of evenly spaced frames
    curr_len = reduced_tensor.shape[0]
    indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int)
    return reduced_tensor[indices]


class ASLDataset(Dataset):
    def __init__(self, paths, labels):
        self.paths = paths
        self.labels = labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        try:
            x = load_and_preprocess(self.paths[idx])
            y = self.labels[idx]
            return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)
        except Exception:
            # Return a zero tensor if a file is corrupted to prevent a crash
            return (
                torch.zeros((TARGET_FRAMES, len(SELECTED_INDICES), 3), dtype=torch.float32),
                torch.tensor(self.labels[idx], dtype=torch.long),
            )


# --- MODEL ---
class ASLClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv1d(NUM_FEATS, 256, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(256)
        self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(512)
        self.pool = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # x shape: (batch, TARGET_FRAMES, landmarks, 3)
        b, t, _, _ = x.shape
        x = x.view(b, t, -1).transpose(1, 2)  # (batch, features, time)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        x = self.dropout(x)
        return self.fc(x)


# --- TRAINING LOOP ---
def run_training():
    # 1. Prepare Metadata
    train_df = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
    unique_signs = sorted(train_df["sign"].unique().to_list())
    sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)}
    paths = train_df["path"].to_list()
    labels = [sign_to_idx[s] for s in train_df["sign"].to_list()]

    # 2. Stratified Split (fixed seed keeps the split reproducible across runs)
    p_train, p_val, l_train, l_val = train_test_split(
        paths, labels, test_size=0.15, stratify=labels, random_state=42
    )

    # 3. Data Loaders
    train_ds = ASLDataset(p_train, l_train)
    val_ds = ASLDataset(p_val, l_val)
    # num_workers=4 lets the CPU preprocess the next batch while the GPU trains
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            num_workers=4, pin_memory=True)

    # 4. Init Model & Optimizer
    model = ASLClassifier(len(unique_signs)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # GradScaler is effectively a no-op on CPU because enabled=False
    scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

    # 5. Training Loop
    for epoch in range(30):
        model.train()
        t_correct, t_total = 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            # Mixed-precision forward pass for speed on CUDA
            with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
                outputs = model(x)
                loss = criterion(outputs, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            _, pred = torch.max(outputs, 1)
            t_total += y.size(0)
            t_correct += (pred == y).sum().item()
            pbar.set_postfix(acc=f"{(t_correct / t_total) * 100:.1f}%")

        # Validation
        model.eval()
        v_correct, v_total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                _, pred = torch.max(outputs, 1)
                v_total += y.size(0)
                v_correct += (pred == y).sum().item()
        print(f"Validation Accuracy: {(v_correct / v_total) * 100:.2f}%")
        torch.save(model.state_dict(), f"asl_model_epoch_{epoch}.pth")


if __name__ == "__main__":
    run_training()
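

# --- EXAMPLE INFERENCE (illustrative sketch) ---
# A minimal sketch of how a saved checkpoint could be used on a single parquet
# file, kept commented out so the script's behaviour is unchanged. The
# checkpoint filename and the sample path are assumptions for illustration,
# not artifacts guaranteed to exist after a given run.
#
# def predict_sign(parquet_rel_path, checkpoint="asl_model_epoch_29.pth"):
#     train_df = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
#     unique_signs = sorted(train_df["sign"].unique().to_list())
#     model = ASLClassifier(len(unique_signs)).to(device)
#     model.load_state_dict(torch.load(checkpoint, map_location=device))
#     model.eval()
#     x = torch.tensor(load_and_preprocess(parquet_rel_path), dtype=torch.float32)
#     with torch.no_grad():
#         logits = model(x.unsqueeze(0).to(device))  # add a batch dimension
#     return unique_signs[int(logits.argmax(dim=1).item())]
#
# print(predict_sign("train_landmark_files/<participant_id>/<sequence_id>.parquet"))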