diff --git a/rewrite_training.py b/rewrite_training.py new file mode 100644 index 0000000..8e46a5e --- /dev/null +++ b/rewrite_training.py @@ -0,0 +1,112 @@ +import os +import polars as pl +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from concurrent.futures import ProcessPoolExecutor +from tqdm import tqdm + +# --- CONFIGURATION --- +BASE_PATH = "asl_kaggle" +TARGET_FRAMES = 22 +# Hand landmarks + Lip landmarks (approximate indices for high-value face points) +LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95] +HANDS = list(range(468, 543)) +SELECTED_INDICES = LIPS + HANDS +NUM_FEATS = len(SELECTED_INDICES) * 3 # X, Y, Z for each selected point +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# --- DATA PROCESSING --- + +def load_kaggle_metadata(base_path): + return pl.read_csv(os.path.join(base_path, "train.csv")) + + +def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES): + parquet_path = os.path.join(base_path, path) + df = pl.read_parquet(parquet_path) + + # 1. Spatial Normalization (Nose Anchor) + anchors = ( + df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0)) + .select([pl.col("frame"), pl.col("x").alias("nx"), pl.col("y").alias("ny"), pl.col("z").alias("nz")]) + ) + + processed = ( + df.join(anchors, on="frame", how="left") + .with_columns([ + (pl.col("x") - pl.col("nx")).fill_null(0.0), + (pl.col("y") - pl.col("ny")).fill_null(0.0), + (pl.col("z") - pl.col("nz")).fill_null(0.0), + ]) + .sort(["frame", "type", "landmark_index"]) + ) + + # 2. Reshape & Feature Selection + # Get unique frames and total landmarks (543) + raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3) + + # Slice to keep only Hands and Lips + reduced_tensor = raw_tensor[:, SELECTED_INDICES, :] + + # 3. Temporal Normalization (Resample to fixed frame count) + curr_len = reduced_tensor.shape[0] + indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int) + return reduced_tensor[indices] + + +# --- MODEL ARCHITECTURE --- + +class ASLClassifier(nn.Module): + def __init__(self, num_classes, target_frames=TARGET_FRAMES, num_feats=NUM_FEATS): + super().__init__() + self.conv1 = nn.Conv1d(num_feats, 256, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm1d(256) + self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm1d(512) + self.pool = nn.MaxPool1d(2) + self.dropout = nn.Dropout(0.5) + self.fc = nn.Linear(512, num_classes) + + def forward(self, x): + # x: (Batch, Frames, Selected_Landmarks, 3) + x = x.view(x.shape[0], x.shape[1], -1) # Flatten landmarks/coords + x = x.transpose(1, 2) # (Batch, Features, Time) + + x = F.relu(self.bn1(self.conv1(x))) + x = self.pool(x) + x = F.relu(self.bn2(self.conv2(x))) + x = self.pool(x) + + x = F.adaptive_avg_pool1d(x, 1).squeeze(-1) + x = self.dropout(x) + return self.fc(x) + + +# --- EXECUTION --- + +if __name__ == "__main__": + asl_data = load_kaggle_metadata(BASE_PATH) + + # Optimization: Process 100 samples to get a feel for the shape/speed + # Using multiprocessing to avoid the slow single-thread loop + paths = asl_data["path"].to_list() + + print(f"Processing {len(paths)} files in parallel...") + with ProcessPoolExecutor() as executor: + results = list(tqdm(executor.map(load_and_preprocess, paths), total=len(paths))) + + # Stack into one giant Torch tensor + dataset_tensor = torch.tensor(np.array(results), dtype=torch.float32) + print(f"Final Tensor Shape: {dataset_tensor.shape}") + # Shape: (100, 22, 96, 3) -> (Batch, Time, Landmarks, Coords) + + # Initialize Model + num_unique_signs = asl_data["sign"].n_unique() + model = ASLClassifier(num_classes=num_unique_signs) + model.to(device) + # Test pass + output = model(dataset_tensor) + print(f"Model Output Shape: {output.shape}") # (100, 250) \ No newline at end of file