Battle against a True Hero pt 2
This commit is contained in:
@@ -4,25 +4,33 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import TensorDataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
# --- CONFIG ---
|
# --- CONFIG ---
|
||||||
BASE_PATH = "asl_kaggle"
|
BASE_PATH = "asl_kaggle"
|
||||||
|
CACHE_DIR = "asl_cache"
|
||||||
TARGET_FRAMES = 22
|
TARGET_FRAMES = 22
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
# --- DATA LOADING WITH RELATIVE FEATURES ---
|
# --- PREPROCESSING (RUN ONCE) ---
|
||||||
|
|
||||||
|
def process_single_file(args):
|
||||||
|
"""Process a single file - designed for multiprocessing"""
|
||||||
|
i, path, base_path, cache_dir = args
|
||||||
|
cache_path = os.path.join(cache_dir, f"sample_{i}.npy")
|
||||||
|
|
||||||
|
if os.path.exists(cache_path):
|
||||||
|
return # Skip if already cached
|
||||||
|
|
||||||
def load_file_to_memory(path, base_path=BASE_PATH):
|
|
||||||
try:
|
try:
|
||||||
parquet_path = os.path.join(base_path, path)
|
parquet_path = os.path.join(base_path, path)
|
||||||
df = pl.read_parquet(parquet_path)
|
df = pl.read_parquet(parquet_path)
|
||||||
|
|
||||||
# 1. Global Anchor (Nose)
|
# Global Anchor (Nose)
|
||||||
anchors = (
|
anchors = (
|
||||||
df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
|
df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
|
||||||
.select([
|
.select([
|
||||||
@@ -33,8 +41,7 @@ def load_file_to_memory(path, base_path=BASE_PATH):
|
|||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Local Anchors (Wrists)
|
# Local Anchors (Wrists)
|
||||||
# Left: 468, Right: 522
|
|
||||||
wrists = (
|
wrists = (
|
||||||
df.filter(pl.col("landmark_index").is_in([468, 522]))
|
df.filter(pl.col("landmark_index").is_in([468, 522]))
|
||||||
.select([
|
.select([
|
||||||
@@ -47,40 +54,83 @@ def load_file_to_memory(path, base_path=BASE_PATH):
|
|||||||
|
|
||||||
processed = df.join(anchors, on="frame", how="left")
|
processed = df.join(anchors, on="frame", how="left")
|
||||||
|
|
||||||
# Join wrist data to the main frame
|
|
||||||
# We use a left join on frame and landmark_index to align wrist coords with their rows
|
|
||||||
processed = (
|
processed = (
|
||||||
processed.join(wrists, on=["frame", "landmark_index"], how="left")
|
processed.join(wrists, on=["frame", "landmark_index"], how="left")
|
||||||
.with_columns([
|
.with_columns([
|
||||||
# Global (Nose-relative)
|
|
||||||
(pl.col("x") - pl.col("nx")).alias("x_g"),
|
(pl.col("x") - pl.col("nx")).alias("x_g"),
|
||||||
(pl.col("y") - pl.col("ny")).alias("y_g"),
|
(pl.col("y") - pl.col("ny")).alias("y_g"),
|
||||||
(pl.col("z") - pl.col("nz")).alias("z_g"),
|
(pl.col("z") - pl.col("nz")).alias("z_g"),
|
||||||
# Local (Wrist-relative - defaults to global if not a hand point)
|
|
||||||
(pl.col("x") - pl.col("wx")).fill_null(pl.col("x") - pl.col("nx")).alias("x_l"),
|
(pl.col("x") - pl.col("wx")).fill_null(pl.col("x") - pl.col("nx")).alias("x_l"),
|
||||||
(pl.col("y") - pl.col("wy")).fill_null(pl.col("y") - pl.col("ny")).alias("y_l"),
|
(pl.col("y") - pl.col("wy")).fill_null(pl.col("y") - pl.col("ny")).alias("y_l"),
|
||||||
])
|
])
|
||||||
.sort(["frame", "type", "landmark_index"])
|
.sort(["frame", "type", "landmark_index"])
|
||||||
)
|
)
|
||||||
|
|
||||||
# We now have 5 channels: (x_g, y_g, z_g, x_l, y_l)
|
|
||||||
n_frames = processed["frame"].n_unique()
|
n_frames = processed["frame"].n_unique()
|
||||||
# Reshape to (Frames, 543 landmarks, 5 features)
|
|
||||||
tensor = processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"]).to_numpy().reshape(n_frames, 543, 5)
|
tensor = processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"]).to_numpy().reshape(n_frames, 543, 5)
|
||||||
|
|
||||||
# Temporal Resampling
|
# Temporal Resampling
|
||||||
indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int)
|
indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int)
|
||||||
return tensor[indices]
|
result = tensor[indices]
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
np.save(cache_path, result)
|
||||||
except Exception:
|
except Exception:
|
||||||
return np.zeros((TARGET_FRAMES, 543, 5))
|
# Save zero tensor for failed files
|
||||||
|
np.save(cache_path, np.zeros((TARGET_FRAMES, 543, 5)))
|
||||||
|
|
||||||
|
|
||||||
# --- DUAL-STREAM MODEL ---
|
def preprocess_and_cache(paths, base_path=BASE_PATH, cache_dir=CACHE_DIR):
|
||||||
|
"""Preprocess all files in parallel and save as numpy arrays"""
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Check if already cached
|
||||||
|
all_cached = all(os.path.exists(os.path.join(cache_dir, f"sample_{i}.npy")) for i in range(len(paths)))
|
||||||
|
if all_cached:
|
||||||
|
print("All files already cached, skipping preprocessing...")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Preprocessing {len(paths)} files in parallel...")
|
||||||
|
|
||||||
|
# Create arguments for each file
|
||||||
|
args_list = [(i, path, base_path, cache_dir) for i, path in enumerate(paths)]
|
||||||
|
|
||||||
|
# Process in parallel
|
||||||
|
with ProcessPoolExecutor() as executor:
|
||||||
|
list(tqdm(executor.map(process_single_file, args_list), total=len(args_list)))
|
||||||
|
|
||||||
|
print("Preprocessing complete!")
|
||||||
|
|
||||||
|
|
||||||
|
# --- FAST DATASET (LOADS FROM CACHE) ---
|
||||||
|
|
||||||
|
class CachedASLDataset(Dataset):
|
||||||
|
"""Fast dataset that loads from preprocessed numpy files"""
|
||||||
|
|
||||||
|
def __init__(self, indices, labels, cache_dir=CACHE_DIR):
|
||||||
|
self.indices = indices
|
||||||
|
self.labels = labels
|
||||||
|
self.cache_dir = cache_dir
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.indices)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
sample_idx = self.indices[idx]
|
||||||
|
cache_path = os.path.join(self.cache_dir, f"sample_{sample_idx}.npy")
|
||||||
|
|
||||||
|
# Fast numpy load
|
||||||
|
data = np.load(cache_path)
|
||||||
|
label = self.labels[idx]
|
||||||
|
|
||||||
|
return torch.tensor(data, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
# --- MODEL ---
|
||||||
|
|
||||||
class ASLClassifier(nn.Module):
|
class ASLClassifier(nn.Module):
|
||||||
def __init__(self, num_classes):
|
def __init__(self, num_classes):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 543 landmarks * 5 features per landmark = 2715
|
|
||||||
self.feat_dim = 543 * 5
|
self.feat_dim = 543 * 5
|
||||||
|
|
||||||
self.conv1 = nn.Conv1d(self.feat_dim, 512, kernel_size=3, padding=1)
|
self.conv1 = nn.Conv1d(self.feat_dim, 512, kernel_size=3, padding=1)
|
||||||
@@ -99,10 +149,8 @@ class ASLClassifier(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# x shape: (Batch, 22, 543, 5)
|
|
||||||
b, t, l, f = x.shape
|
b, t, l, f = x.shape
|
||||||
# Flatten landmarks and features into one vector, then transpose for Conv1d
|
x = x.view(b, t, -1).transpose(1, 2)
|
||||||
x = x.view(b, t, -1).transpose(1, 2) # (Batch, 2715, 22)
|
|
||||||
|
|
||||||
x = F.relu(self.bn1(self.conv1(x)))
|
x = F.relu(self.bn1(self.conv1(x)))
|
||||||
x = self.pool(x)
|
x = self.pool(x)
|
||||||
@@ -110,7 +158,6 @@ class ASLClassifier(nn.Module):
|
|||||||
x = F.relu(self.bn2(self.conv2(x)))
|
x = F.relu(self.bn2(self.conv2(x)))
|
||||||
x = self.pool(x)
|
x = self.pool(x)
|
||||||
|
|
||||||
# Global Average Pool across the time dimension
|
|
||||||
x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
|
x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
|
||||||
|
|
||||||
return self.fc(self.dropout(x))
|
return self.fc(self.dropout(x))
|
||||||
@@ -119,38 +166,51 @@ class ASLClassifier(nn.Module):
|
|||||||
# --- EXECUTION ---
|
# --- EXECUTION ---
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 1. Setup Data
|
# 1. Setup Metadata
|
||||||
metadata = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
|
metadata = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
|
||||||
unique_signs = sorted(metadata["sign"].unique().to_list())
|
unique_signs = sorted(metadata["sign"].unique().to_list())
|
||||||
sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)}
|
sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)}
|
||||||
labels = [sign_to_idx[s] for s in metadata["sign"].to_list()]
|
|
||||||
|
|
||||||
paths = metadata["path"].to_list()
|
paths = metadata["path"].to_list()
|
||||||
|
labels = [sign_to_idx[s] for s in metadata["sign"].to_list()]
|
||||||
|
|
||||||
# 2. Load to RAM (Parallelized)
|
# 2. Preprocess and cache (parallelized, only runs if cache doesn't exist)
|
||||||
print(f"Loading {len(paths)} files into RAM with 5-channel features...")
|
preprocess_and_cache(paths)
|
||||||
with ProcessPoolExecutor() as executor:
|
|
||||||
data_list = list(tqdm(executor.map(load_file_to_memory, paths), total=len(paths)))
|
|
||||||
|
|
||||||
X = torch.tensor(np.array(data_list), dtype=torch.float32)
|
# 3. Create index mapping for train/val split
|
||||||
y = torch.tensor(labels, dtype=torch.long)
|
all_indices = list(range(len(paths)))
|
||||||
|
train_indices, val_indices, train_labels, val_labels = train_test_split(
|
||||||
|
all_indices, labels, test_size=0.1, stratify=labels, random_state=42
|
||||||
|
)
|
||||||
|
|
||||||
# 3. Split
|
# 4. Create datasets from cached files
|
||||||
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, stratify=y, random_state=42)
|
train_dataset = CachedASLDataset(train_indices, train_labels)
|
||||||
|
val_dataset = CachedASLDataset(val_indices, val_labels)
|
||||||
|
|
||||||
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
|
# Increase batch size and workers since loading is now fast
|
||||||
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=64)
|
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=4, pin_memory=True)
|
||||||
|
|
||||||
# 4. Train
|
print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
|
||||||
|
|
||||||
|
# 5. Train
|
||||||
model = ASLClassifier(len(unique_signs)).to(device)
|
model = ASLClassifier(len(unique_signs)).to(device)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||||
criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # Helps prevent over-confidence
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
|
||||||
|
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||||
|
|
||||||
|
best_acc = 0.0
|
||||||
print(f"Starting training on {device}...")
|
print(f"Starting training on {device}...")
|
||||||
|
|
||||||
for epoch in range(25):
|
for epoch in range(25):
|
||||||
|
# Training
|
||||||
model.train()
|
model.train()
|
||||||
train_loss = 0
|
train_loss = 0
|
||||||
for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
|
train_correct = 0
|
||||||
|
train_total = 0
|
||||||
|
|
||||||
|
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/25 [Train]")
|
||||||
|
for batch_x, batch_y in pbar:
|
||||||
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
|
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -158,19 +218,47 @@ if __name__ == "__main__":
|
|||||||
loss = criterion(output, batch_y)
|
loss = criterion(output, batch_y)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
train_loss += loss.item()
|
train_loss += loss.item()
|
||||||
|
_, predicted = torch.max(output, 1)
|
||||||
|
train_total += batch_y.size(0)
|
||||||
|
train_correct += (predicted == batch_y).sum().item()
|
||||||
|
|
||||||
|
pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100 * train_correct / train_total:.1f}%'})
|
||||||
|
|
||||||
# Validation
|
# Validation
|
||||||
model.eval()
|
model.eval()
|
||||||
correct, total = 0, 0
|
val_correct, val_total = 0, 0
|
||||||
|
val_loss = 0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for vx, vy in val_loader:
|
for vx, vy in tqdm(val_loader, desc=f"Epoch {epoch + 1}/25 [Val]"):
|
||||||
vx, vy = vx.to(device), vy.to(device)
|
vx, vy = vx.to(device), vy.to(device)
|
||||||
pred = model(vx).argmax(1)
|
output = model(vx)
|
||||||
correct += (pred == vy).sum().item()
|
val_loss += criterion(output, vy).item()
|
||||||
total += vy.size(0)
|
pred = output.argmax(1)
|
||||||
|
val_correct += (pred == vy).sum().item()
|
||||||
|
val_total += vy.size(0)
|
||||||
|
|
||||||
print(f"Epoch {epoch + 1} | Loss: {train_loss / len(train_loader):.4f} | Val Acc: {100 * correct / total:.2f}%")
|
avg_train_loss = train_loss / len(train_loader)
|
||||||
|
avg_val_loss = val_loss / len(val_loader)
|
||||||
|
train_acc = 100 * train_correct / train_total
|
||||||
|
val_acc = 100 * val_correct / val_total
|
||||||
|
|
||||||
|
scheduler.step(avg_val_loss)
|
||||||
|
|
||||||
|
print(f"\nEpoch {epoch + 1}/25:")
|
||||||
|
print(f" Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
|
||||||
|
print(f" Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%")
|
||||||
|
|
||||||
|
# Save best model
|
||||||
|
if val_acc > best_acc:
|
||||||
|
best_acc = val_acc
|
||||||
|
torch.save(model.state_dict(), "best_asl_model.pth")
|
||||||
|
print(f" ✓ Best model saved! (Val Acc: {val_acc:.2f}%)\n")
|
||||||
|
|
||||||
|
# Checkpoint every 5 epochs
|
||||||
if (epoch + 1) % 5 == 0:
|
if (epoch + 1) % 5 == 0:
|
||||||
torch.save(model.state_dict(), f"asl_model_v2_e{epoch + 1}.pth")
|
torch.save(model.state_dict(), f"asl_model_e{epoch + 1}.pth")
|
||||||
|
|
||||||
|
print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")
|
||||||
Reference in New Issue
Block a user