From 2cbd178ba83f11a0b2d82661bb06f671488a94c2 Mon Sep 17 00:00:00 2001 From: Stupdi Go Date: Mon, 19 Jan 2026 23:21:35 -0600 Subject: [PATCH] soy cooked pt 2 --- training.py | 70 +++++++++++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 32 deletions(-) diff --git a/training.py b/training.py index b76ee76..03aaec9 100644 --- a/training.py +++ b/training.py @@ -2,7 +2,7 @@ import os import json import math import numpy as np -import pandas as pd +import polars as pl import torch import torch.nn as nn import torch.nn.functional as F @@ -56,7 +56,7 @@ TOTAL_POINTS_PER_FRAME = NUM_HAND_POINTS + NUM_FACE_POINTS # =============================== -# ENHANCED DATA EXTRACTION (FIXED) +# ENHANCED DATA EXTRACTION (POLARS) # =============================== def extract_multi_landmarks(path, min_valid_frames=3): """ @@ -64,13 +64,13 @@ def extract_multi_landmarks(path, min_valid_frames=3): Returns: dict with 'landmarks', 'left_hand_valid', 'right_hand_valid', 'face_valid' """ try: - df = pd.read_parquet(path) + df = pl.read_parquet(path) seq = [] left_valid_frames = [] right_valid_frames = [] face_valid_frames = [] - all_types = df["type"].unique() + all_types = df.select("type").unique().to_series().to_list() # Check if we have at least one of the required types has_data = any(t in all_types for t in ["left_hand", "right_hand", "face"]) @@ -78,13 +78,13 @@ def extract_multi_landmarks(path, min_valid_frames=3): return None # Get all frames (might not start at 0) - frames = sorted(df["frame"].unique()) + frames = sorted(df.select("frame").unique().to_series().to_list()) if len(frames) < min_valid_frames: return None for frame in frames: - frame_df = df[df["frame"] == frame] + frame_df = df.filter(pl.col("frame") == frame) frame_points = np.full((TOTAL_POINTS_PER_FRAME, 3), np.nan, dtype=np.float32) pos = 0 @@ -93,48 +93,54 @@ def extract_multi_landmarks(path, min_valid_frames=3): face_valid = False # Left hand (need at least 10 valid points) - left = frame_df[frame_df["type"] == "left_hand"] - if len(left) > 0: + left = frame_df.filter(pl.col("type") == "left_hand") + if left.height > 0: valid_count = 0 for i in range(21): - row = left[left["landmark_index"] == i] - if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all(): - frame_points[pos] = row[['x', 'y', 'z']].values[0] - valid_count += 1 + row = left.filter(pl.col("landmark_index") == i) + if row.height > 0: + coords = row.select(["x", "y", "z"]).row(0) + if all(c is not None for c in coords): + frame_points[pos] = coords + valid_count += 1 pos += 1 - left_valid = (valid_count >= 10) # Relaxed from 15 + left_valid = (valid_count >= 10) else: pos += 21 # Right hand (need at least 10 valid points) - right = frame_df[frame_df["type"] == "right_hand"] - if len(right) > 0: + right = frame_df.filter(pl.col("type") == "right_hand") + if right.height > 0: valid_count = 0 for i in range(21): - row = right[right["landmark_index"] == i] - if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all(): - frame_points[pos] = row[['x', 'y', 'z']].values[0] - valid_count += 1 + row = right.filter(pl.col("landmark_index") == i) + if row.height > 0: + coords = row.select(["x", "y", "z"]).row(0) + if all(c is not None for c in coords): + frame_points[pos] = coords + valid_count += 1 pos += 1 - right_valid = (valid_count >= 10) # Relaxed from 15 + right_valid = (valid_count >= 10) else: pos += 21 # Face (need at least 30% of selected landmarks) - face = frame_df[frame_df["type"] == "face"] - if len(face) > 0: + face = frame_df.filter(pl.col("type") == "face") + if face.height > 0: valid_count = 0 for idx in IMPORTANT_FACE_INDICES: - row = face[face["landmark_index"] == idx] - if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all(): - frame_points[pos] = row[['x', 'y', 'z']].values[0] - valid_count += 1 + row = face.filter(pl.col("landmark_index") == idx) + if row.height > 0: + coords = row.select(["x", "y", "z"]).row(0) + if all(c is not None for c in coords): + frame_points[pos] = coords + valid_count += 1 pos += 1 - face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.3) # Relaxed from 0.5 + face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.3) # Accept frame if we have at least 20% valid data overall valid_ratio = 1 - np.isnan(frame_points).mean() - if valid_ratio >= 0.20: # Relaxed from 0.40 + if valid_ratio >= 0.20: frame_points = np.nan_to_num(frame_points, nan=0.0) seq.append(frame_points) left_valid_frames.append(left_valid) @@ -328,9 +334,9 @@ class ModalityAwareTransformer(nn.Module): def load_kaggle_asl_data(base_path): - """Load training metadata""" + """Load training metadata using Polars""" train_path = os.path.join(base_path, "train.csv") - train_df = pd.read_csv(train_path) + train_df = pl.read_csv(train_path) return train_df, None @@ -345,10 +351,10 @@ def main(): print("\nLoading metadata...") train_df, _ = load_kaggle_asl_data(base_path) - print(f"Total samples in train.csv: {len(train_df)}") + print(f"Total samples in train.csv: {train_df.height}") # Convert to simple tuples for multiprocessing compatibility - rows = [(row['path'], row['sign']) for _, row in train_df.iterrows()] + rows = [(row[0], row[1]) for row in train_df.select(["path", "sign"]).iter_rows()] print("\nProcessing sequences with BOTH hands + FACE (enhanced)...") print("This may take a few minutes...")