soy cooked pt 2

2026-01-19 23:21:35 -06:00
parent 16a4f1d2b9
commit 2cbd178ba8


@@ -2,7 +2,7 @@ import os
 import json
 import math
 import numpy as np
-import pandas as pd
+import polars as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -56,7 +56,7 @@ TOTAL_POINTS_PER_FRAME = NUM_HAND_POINTS + NUM_FACE_POINTS
 # ===============================
-# ENHANCED DATA EXTRACTION (FIXED)
+# ENHANCED DATA EXTRACTION (POLARS)
 # ===============================
 def extract_multi_landmarks(path, min_valid_frames=3):
     """
@@ -64,13 +64,13 @@ def extract_multi_landmarks(path, min_valid_frames=3):
     Returns: dict with 'landmarks', 'left_hand_valid', 'right_hand_valid', 'face_valid'
     """
     try:
-        df = pd.read_parquet(path)
+        df = pl.read_parquet(path)
         seq = []
         left_valid_frames = []
         right_valid_frames = []
         face_valid_frames = []
-        all_types = df["type"].unique()
+        all_types = df.select("type").unique().to_series().to_list()
         # Check if we have at least one of the required types
         has_data = any(t in all_types for t in ["left_hand", "right_hand", "face"])
@@ -78,13 +78,13 @@ def extract_multi_landmarks(path, min_valid_frames=3):
             return None
         # Get all frames (might not start at 0)
-        frames = sorted(df["frame"].unique())
+        frames = sorted(df.select("frame").unique().to_series().to_list())
         if len(frames) < min_valid_frames:
             return None
         for frame in frames:
-            frame_df = df[df["frame"] == frame]
+            frame_df = df.filter(pl.col("frame") == frame)
             frame_points = np.full((TOTAL_POINTS_PER_FRAME, 3), np.nan, dtype=np.float32)
             pos = 0
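
For reference, a minimal sketch (toy values, same column names as the parquet files this function reads) of the two Polars idioms introduced above: select(...).unique().to_series().to_list() to get a column's distinct values as a plain Python list, and filter(pl.col(...) == ...) in place of pandas boolean-mask indexing.

import polars as pl

# Toy data: only the "frame" and "type" columns used above; values are illustrative.
df = pl.DataFrame({
    "frame": [0, 0, 1, 1, 2],
    "type": ["left_hand", "face", "left_hand", "face", "right_hand"],
})

# Distinct values of one column as a Python list (order is not guaranteed).
all_types = df.select("type").unique().to_series().to_list()
print(sorted(all_types))   # ['face', 'left_hand', 'right_hand']

# Distinct frame ids, sorted as in extract_multi_landmarks().
frames = sorted(df.select("frame").unique().to_series().to_list())
print(frames)              # [0, 1, 2]

# df[df["frame"] == frame] becomes an expression-based filter.
frame_df = df.filter(pl.col("frame") == 1)
print(frame_df.height)     # 2
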
@@ -93,48 +93,54 @@ def extract_multi_landmarks(path, min_valid_frames=3):
             face_valid = False
             # Left hand (need at least 10 valid points)
-            left = frame_df[frame_df["type"] == "left_hand"]
-            if len(left) > 0:
+            left = frame_df.filter(pl.col("type") == "left_hand")
+            if left.height > 0:
                 valid_count = 0
                 for i in range(21):
-                    row = left[left["landmark_index"] == i]
-                    if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all():
-                        frame_points[pos] = row[['x', 'y', 'z']].values[0]
-                        valid_count += 1
+                    row = left.filter(pl.col("landmark_index") == i)
+                    if row.height > 0:
+                        coords = row.select(["x", "y", "z"]).row(0)
+                        if all(c is not None for c in coords):
+                            frame_points[pos] = coords
+                            valid_count += 1
                     pos += 1
-                left_valid = (valid_count >= 10) # Relaxed from 15
+                left_valid = (valid_count >= 10)
             else:
                 pos += 21
             # Right hand (need at least 10 valid points)
-            right = frame_df[frame_df["type"] == "right_hand"]
-            if len(right) > 0:
+            right = frame_df.filter(pl.col("type") == "right_hand")
+            if right.height > 0:
                 valid_count = 0
                 for i in range(21):
-                    row = right[right["landmark_index"] == i]
-                    if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all():
-                        frame_points[pos] = row[['x', 'y', 'z']].values[0]
-                        valid_count += 1
+                    row = right.filter(pl.col("landmark_index") == i)
+                    if row.height > 0:
+                        coords = row.select(["x", "y", "z"]).row(0)
+                        if all(c is not None for c in coords):
+                            frame_points[pos] = coords
+                            valid_count += 1
                     pos += 1
-                right_valid = (valid_count >= 10) # Relaxed from 15
+                right_valid = (valid_count >= 10)
             else:
                 pos += 21
             # Face (need at least 30% of selected landmarks)
-            face = frame_df[frame_df["type"] == "face"]
-            if len(face) > 0:
+            face = frame_df.filter(pl.col("type") == "face")
+            if face.height > 0:
                 valid_count = 0
                 for idx in IMPORTANT_FACE_INDICES:
-                    row = face[face["landmark_index"] == idx]
-                    if len(row) > 0 and row[['x', 'y', 'z']].notna().all().all():
-                        frame_points[pos] = row[['x', 'y', 'z']].values[0]
-                        valid_count += 1
+                    row = face.filter(pl.col("landmark_index") == idx)
+                    if row.height > 0:
+                        coords = row.select(["x", "y", "z"]).row(0)
+                        if all(c is not None for c in coords):
+                            frame_points[pos] = coords
+                            valid_count += 1
                     pos += 1
-                face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.3) # Relaxed from 0.5
+                face_valid = (valid_count >= len(IMPORTANT_FACE_INDICES) * 0.3)
             # Accept frame if we have at least 20% valid data overall
             valid_ratio = 1 - np.isnan(frame_points).mean()
-            if valid_ratio >= 0.20: # Relaxed from 0.40
+            if valid_ratio >= 0.20:
                 frame_points = np.nan_to_num(frame_points, nan=0.0)
                 seq.append(frame_points)
                 left_valid_frames.append(left_valid)
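
The validity check changes shape because DataFrame.row(0) in Polars returns a plain tuple in which nulls come back as None, so the pandas row[['x', 'y', 'z']].notna().all().all() guard becomes all(c is not None for c in coords). A minimal sketch with made-up coordinates:

import numpy as np
import polars as pl

# One landmark with complete coordinates, one with a missing x (null).
toy = pl.DataFrame({
    "landmark_index": [0, 1],
    "x": [0.10, None],
    "y": [0.20, 0.40],
    "z": [0.30, 0.50],
})

for i in range(2):
    row = toy.filter(pl.col("landmark_index") == i)
    coords = row.select(["x", "y", "z"]).row(0)         # tuple; nulls are None
    if all(c is not None for c in coords):
        print(i, np.asarray(coords, dtype=np.float32))  # only landmark 0 passes
    else:
        print(i, "skipped: incomplete coordinates")
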
@@ -328,9 +334,9 @@ class ModalityAwareTransformer(nn.Module):
 def load_kaggle_asl_data(base_path):
-    """Load training metadata"""
+    """Load training metadata using Polars"""
     train_path = os.path.join(base_path, "train.csv")
-    train_df = pd.read_csv(train_path)
+    train_df = pl.read_csv(train_path)
     return train_df, None
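
pl.read_csv returns an eager DataFrame whose row count is exposed as .height (len(df) gives the same number), which is what main() prints below. A short sketch with a hypothetical local train.csv:

import polars as pl

# Hypothetical path; the script builds the real one with os.path.join(base_path, "train.csv").
train_df = pl.read_csv("train.csv")
print(train_df.height, len(train_df))   # same row count either way
print(train_df.columns)                 # expected to include 'path' and 'sign'
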
@@ -345,10 +351,10 @@ def main():
     print("\nLoading metadata...")
     train_df, _ = load_kaggle_asl_data(base_path)
-    print(f"Total samples in train.csv: {len(train_df)}")
+    print(f"Total samples in train.csv: {train_df.height}")
     # Convert to simple tuples for multiprocessing compatibility
-    rows = [(row['path'], row['sign']) for _, row in train_df.iterrows()]
+    rows = [(row[0], row[1]) for row in train_df.select(["path", "sign"]).iter_rows()]
     print("\nProcessing sequences with BOTH hands + FACE (enhanced)...")
     print("This may take a few minutes...")
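
iter_rows() over the selected columns yields plain tuples, which pickle cleanly for the multiprocessing step that follows (pandas' iterrows() yields (index, Series) pairs instead). A minimal sketch with made-up metadata rows:

import polars as pl

# Made-up rows containing only the two columns the script selects.
train_df = pl.DataFrame({
    "path": ["data/a.parquet", "data/b.parquet"],
    "sign": ["hello", "thanks"],
})

rows = [(row[0], row[1]) for row in train_df.select(["path", "sign"]).iter_rows()]
print(rows)   # [('data/a.parquet', 'hello'), ('data/b.parquet', 'thanks')]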