diff --git a/rewrite_training.py b/rewrite_training.py
index 27bbe2f..df84352 100644
--- a/rewrite_training.py
+++ b/rewrite_training.py
@@ -4,166 +4,173 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from sklearn.model_selection import train_test_split
+from torch.utils.data import TensorDataset, DataLoader
+from concurrent.futures import ProcessPoolExecutor
 from tqdm import tqdm
+from sklearn.model_selection import train_test_split
 
-# --- CONFIGURATION ---
+# --- CONFIG ---
 BASE_PATH = "asl_kaggle"
 TARGET_FRAMES = 22
-LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
-HANDS = list(range(468, 543))
-SELECTED_INDICES = LIPS + HANDS
-NUM_FEATS = len(SELECTED_INDICES) * 3
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# --- DATASET ENGINE ---
+# --- DATA LOADING WITH RELATIVE FEATURES ---
 
-def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES):
-    parquet_path = os.path.join(base_path, path)
-    df = pl.read_parquet(parquet_path)
+def load_file_to_memory(path, base_path=BASE_PATH):
+    try:
+        parquet_path = os.path.join(base_path, path)
+        df = pl.read_parquet(parquet_path)
 
-    # 1. Spatial Normalization
-    anchors = (
-        df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
-        .select([pl.col("frame"), pl.col("x").alias("nx"), pl.col("y").alias("ny"), pl.col("z").alias("nz")])
-    )
+        # 1. Global Anchor (Nose)
+        anchors = (
+            df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
+            .select([
+                pl.col("frame"),
+                pl.col("x").alias("nx"),
+                pl.col("y").alias("ny"),
+                pl.col("z").alias("nz")
+            ])
+        )
 
-    processed = (
-        df.join(anchors, on="frame", how="left")
-        .with_columns([
-            (pl.col("x") - pl.col("nx")).fill_null(0.0),
-            (pl.col("y") - pl.col("ny")).fill_null(0.0),
-            (pl.col("z") - pl.col("nz")).fill_null(0.0),
-        ])
-        .sort(["frame", "type", "landmark_index"])
-    )
+        # 2. Local Anchors (Wrists)
+        # Wrist = landmark_index 0 within each hand type (indices restart per type, like the nose above)
+        wrists = (
+            df.filter(pl.col("type").is_in(["left_hand", "right_hand"]) & (pl.col("landmark_index") == 0))
+            .select([
+                pl.col("frame"),
+                pl.col("type"),
+                pl.col("x").alias("wx"),
+                pl.col("y").alias("wy")
+            ])
+        )
 
-    # 2. Slice and Reshape
-    raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3)
-    reduced_tensor = raw_tensor[:, SELECTED_INDICES, :]
+        processed = df.join(anchors, on="frame", how="left")
 
-    # 3. Temporal Resampling
-    curr_len = reduced_tensor.shape[0]
-    indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int)
-    return reduced_tensor[indices]
+        # Join wrist data to the main frame
+        # Left join on frame and type so every hand landmark row gets its own hand's wrist coords
+        processed = (
+            processed.join(wrists, on=["frame", "type"], how="left")
+            .with_columns([
+                # Global (Nose-relative)
+                (pl.col("x") - pl.col("nx")).alias("x_g"),
+                (pl.col("y") - pl.col("ny")).alias("y_g"),
+                (pl.col("z") - pl.col("nz")).alias("z_g"),
+                # Local (Wrist-relative - defaults to global if not a hand point)
+                (pl.col("x") - pl.col("wx")).fill_null(pl.col("x") - pl.col("nx")).alias("x_l"),
+                (pl.col("y") - pl.col("wy")).fill_null(pl.col("y") - pl.col("ny")).alias("y_l"),
+            ])
+            .sort(["frame", "type", "landmark_index"])
+        )
+
+        # We now have 5 channels: (x_g, y_g, z_g, x_l, y_l)
+        n_frames = processed["frame"].n_unique()
+        # Reshape to (Frames, 543 landmarks, 5 features)
+        tensor = processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"]).to_numpy().reshape(n_frames, 543, 5)
+
+        # Temporal Resampling
+        indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int)
+        return tensor[indices]
+    except Exception:
+        return np.zeros((TARGET_FRAMES, 543, 5))
 
-class ASLDataset(Dataset):
-    def __init__(self, paths, labels):
-        self.paths = paths
-        self.labels = labels
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __getitem__(self, idx):
-        try:
-            x = load_and_preprocess(self.paths[idx])
-            y = self.labels[idx]
-            return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)
-        except Exception as e:
-            # Return a zero tensor if a file is corrupted to prevent crash
-            return torch.zeros((TARGET_FRAMES, len(SELECTED_INDICES), 3)), torch.tensor(self.labels[idx],
-                                                                                       dtype=torch.long)
-
-
-# --- MODEL ---
+# --- DUAL-STREAM MODEL ---
 
 class ASLClassifier(nn.Module):
     def __init__(self, num_classes):
         super().__init__()
-        self.conv1 = nn.Conv1d(NUM_FEATS, 256, kernel_size=3, padding=1)
-        self.bn1 = nn.BatchNorm1d(256)
-        self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
+        # 543 landmarks * 5 features per landmark = 2715
+        self.feat_dim = 543 * 5
+
+        self.conv1 = nn.Conv1d(self.feat_dim, 512, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm1d(512)
+        self.conv2 = nn.Conv1d(512, 512, kernel_size=3, padding=1)
         self.bn2 = nn.BatchNorm1d(512)
+
         self.pool = nn.MaxPool1d(2)
-        self.dropout = nn.Dropout(0.5)
-        self.fc = nn.Linear(512, num_classes)
+        self.dropout = nn.Dropout(0.4)
+
+        self.fc = nn.Sequential(
+            nn.Linear(512, 1024),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(1024, num_classes)
+        )
 
     def forward(self, x):
-        # x shape: (Batch, 22, 96, 3)
-        b, t, l, c = x.shape
-        x = x.view(b, t, -1).transpose(1, 2)  # (Batch, Features, Time)
+        # x shape: (Batch, 22, 543, 5)
+        b, t, l, f = x.shape
+        # Flatten landmarks and features into one vector, then transpose for Conv1d
+        x = x.view(b, t, -1).transpose(1, 2)  # (Batch, 2715, 22)
         x = F.relu(self.bn1(self.conv1(x)))
         x = self.pool(x)
+
         x = F.relu(self.bn2(self.conv2(x)))
         x = self.pool(x)
+
         # Global Average Pool across the time dimension
         x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
-        x = self.dropout(x)
-        return self.fc(x)
+
+        return self.fc(self.dropout(x))
 
-# --- TRAINING LOOP ---
+# --- EXECUTION ---
 
-def run_training():
-    # 1. Prepare Metadata
-    train_df = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
-    unique_signs = sorted(train_df["sign"].unique().to_list())
+if __name__ == "__main__":
+    # 1. Setup Data
+    metadata = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
+    unique_signs = sorted(metadata["sign"].unique().to_list())
     sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)}
+    labels = [sign_to_idx[s] for s in metadata["sign"].to_list()]
 
-    paths = train_df["path"].to_list()
-    labels = [sign_to_idx[s] for s in train_df["sign"].to_list()]
+    paths = metadata["path"].to_list()
 
-    # 2. Split
-    p_train, p_val, l_train, l_val = train_test_split(paths, labels, test_size=0.15, stratify=labels)
+    # 2. Load to RAM (Parallelized)
+    print(f"Loading {len(paths)} files into RAM with 5-channel features...")
+    with ProcessPoolExecutor() as executor:
+        data_list = list(tqdm(executor.map(load_file_to_memory, paths), total=len(paths)))
 
-    # 3. Loaders
-    train_ds = ASLDataset(p_train, l_train)
-    val_ds = ASLDataset(p_val, l_val)
+    X = torch.tensor(np.array(data_list), dtype=torch.float32)
+    y = torch.tensor(labels, dtype=torch.long)
 
-    # num_workers=4 allows the CPU to preprocess the next batch while GPU trains
-    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
-    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
+    # 3. Split
+    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, stratify=y, random_state=42)
 
-    # 4. Init Model & Optim
+    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
+    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=64)
+
+    # 4. Train
     model = ASLClassifier(len(unique_signs)).to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
-    scaler = torch.amp.GradScaler(enabled=(device.type == 'cuda'))
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Helps prevent over-confidence
 
-    # 5. Loop
-    for epoch in range(30):
+    print(f"Starting training on {device}...")
+    for epoch in range(25):
         model.train()
-        t_correct, t_total = 0, 0
-
-        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
-        for x, y in pbar:
-            x, y = x.to(device), y.to(device)
+        train_loss = 0
+        for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
+            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
             optimizer.zero_grad()
-
-            # Use Mixed Precision for speed
-            with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
-                outputs = model(x)
-                loss = criterion(outputs, y)
-
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-
-            _, pred = torch.max(outputs, 1)
-            t_total += y.size(0)
-            t_correct += (pred == y).sum().item()
-            pbar.set_postfix(acc=f"{(t_correct / t_total) * 100:.1f}%")
+            output = model(batch_x)
+            loss = criterion(output, batch_y)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item()
 
         # Validation
         model.eval()
-        v_correct, v_total = 0, 0
+        correct, total = 0, 0
         with torch.no_grad():
-            for x, y in val_loader:
-                x, y = x.to(device), y.to(device)
-                outputs = model(x)
-                _, pred = torch.max(outputs, 1)
-                v_total += y.size(0)
-                v_correct += (pred == y).sum().item()
+            for vx, vy in val_loader:
+                vx, vy = vx.to(device), vy.to(device)
+                pred = model(vx).argmax(1)
+                correct += (pred == vy).sum().item()
+                total += vy.size(0)
 
-        print(f"Validation Accuracy: {(v_correct / v_total) * 100:.2f}%")
-        torch.save(model.state_dict(), f"asl_model_epoch_{epoch}.pth")
+        print(f"Epoch {epoch + 1} | Loss: {train_loss / len(train_loader):.4f} | Val Acc: {100 * correct / total:.2f}%")
 
-
-if __name__ == "__main__":
-    run_training()
\ No newline at end of file
+        if (epoch + 1) % 5 == 0:
+            torch.save(model.state_dict(), f"asl_model_v2_e{epoch + 1}.pth")
\ No newline at end of file