Battle against a True Hero pt 2
This commit is contained in:
@@ -4,166 +4,173 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
# --- CONFIGURATION ---
|
||||
# --- CONFIG ---
|
||||
BASE_PATH = "asl_kaggle"
|
||||
TARGET_FRAMES = 22
|
||||
LIPS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
|
||||
HANDS = list(range(468, 543))
|
||||
SELECTED_INDICES = LIPS + HANDS
|
||||
NUM_FEATS = len(SELECTED_INDICES) * 3
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
# --- DATASET ENGINE ---
|
||||
# --- DATA LOADING WITH RELATIVE FEATURES ---
|
||||
|
||||
def load_and_preprocess(path, base_path=BASE_PATH, target_frames=TARGET_FRAMES):
|
||||
def load_file_to_memory(path, base_path=BASE_PATH):
|
||||
try:
|
||||
parquet_path = os.path.join(base_path, path)
|
||||
df = pl.read_parquet(parquet_path)
|
||||
|
||||
# 1. Spatial Normalization
|
||||
# 1. Global Anchor (Nose)
|
||||
anchors = (
|
||||
df.filter((pl.col("type") == "face") & (pl.col("landmark_index") == 0))
|
||||
.select([pl.col("frame"), pl.col("x").alias("nx"), pl.col("y").alias("ny"), pl.col("z").alias("nz")])
|
||||
.select([
|
||||
pl.col("frame"),
|
||||
pl.col("x").alias("nx"),
|
||||
pl.col("y").alias("ny"),
|
||||
pl.col("z").alias("nz")
|
||||
])
|
||||
)
|
||||
|
||||
# 2. Local Anchors (Wrists)
|
||||
# Left: 468, Right: 522
|
||||
wrists = (
|
||||
df.filter(pl.col("landmark_index").is_in([468, 522]))
|
||||
.select([
|
||||
pl.col("frame"),
|
||||
pl.col("landmark_index"),
|
||||
pl.col("x").alias("wx"),
|
||||
pl.col("y").alias("wy")
|
||||
])
|
||||
)
|
||||
|
||||
processed = df.join(anchors, on="frame", how="left")
|
||||
|
||||
# Join wrist data to the main frame
|
||||
# We use a left join on frame and landmark_index to align wrist coords with their rows
|
||||
processed = (
|
||||
df.join(anchors, on="frame", how="left")
|
||||
processed.join(wrists, on=["frame", "landmark_index"], how="left")
|
||||
.with_columns([
|
||||
(pl.col("x") - pl.col("nx")).fill_null(0.0),
|
||||
(pl.col("y") - pl.col("ny")).fill_null(0.0),
|
||||
(pl.col("z") - pl.col("nz")).fill_null(0.0),
|
||||
# Global (Nose-relative)
|
||||
(pl.col("x") - pl.col("nx")).alias("x_g"),
|
||||
(pl.col("y") - pl.col("ny")).alias("y_g"),
|
||||
(pl.col("z") - pl.col("nz")).alias("z_g"),
|
||||
# Local (Wrist-relative - defaults to global if not a hand point)
|
||||
(pl.col("x") - pl.col("wx")).fill_null(pl.col("x") - pl.col("nx")).alias("x_l"),
|
||||
(pl.col("y") - pl.col("wy")).fill_null(pl.col("y") - pl.col("ny")).alias("y_l"),
|
||||
])
|
||||
.sort(["frame", "type", "landmark_index"])
|
||||
)
|
||||
|
||||
# 2. Slice and Reshape
|
||||
raw_tensor = processed.select(["x", "y", "z"]).to_numpy().reshape(-1, 543, 3)
|
||||
reduced_tensor = raw_tensor[:, SELECTED_INDICES, :]
|
||||
# We now have 5 channels: (x_g, y_g, z_g, x_l, y_l)
|
||||
n_frames = processed["frame"].n_unique()
|
||||
# Reshape to (Frames, 543 landmarks, 5 features)
|
||||
tensor = processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"]).to_numpy().reshape(n_frames, 543, 5)
|
||||
|
||||
# 3. Temporal Resampling
|
||||
curr_len = reduced_tensor.shape[0]
|
||||
indices = np.linspace(0, curr_len - 1, num=target_frames).round().astype(int)
|
||||
return reduced_tensor[indices]
|
||||
# Temporal Resampling
|
||||
indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int)
|
||||
return tensor[indices]
|
||||
except Exception:
|
||||
return np.zeros((TARGET_FRAMES, 543, 5))
|
||||
|
||||
|
||||
class ASLDataset(Dataset):
|
||||
def __init__(self, paths, labels):
|
||||
self.paths = paths
|
||||
self.labels = labels
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
x = load_and_preprocess(self.paths[idx])
|
||||
y = self.labels[idx]
|
||||
return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)
|
||||
except Exception as e:
|
||||
# Return a zero tensor if a file is corrupted to prevent crash
|
||||
return torch.zeros((TARGET_FRAMES, len(SELECTED_INDICES), 3)), torch.tensor(self.labels[idx],
|
||||
dtype=torch.long)
|
||||
|
||||
|
||||
# --- MODEL ---
|
||||
# --- DUAL-STREAM MODEL ---
|
||||
|
||||
class ASLClassifier(nn.Module):
|
||||
def __init__(self, num_classes):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv1d(NUM_FEATS, 256, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm1d(256)
|
||||
self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
|
||||
# 543 landmarks * 5 features per landmark = 2715
|
||||
self.feat_dim = 543 * 5
|
||||
|
||||
self.conv1 = nn.Conv1d(self.feat_dim, 512, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm1d(512)
|
||||
self.conv2 = nn.Conv1d(512, 512, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm1d(512)
|
||||
|
||||
self.pool = nn.MaxPool1d(2)
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.fc = nn.Linear(512, num_classes)
|
||||
self.dropout = nn.Dropout(0.4)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(512, 1024),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(1024, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# x shape: (Batch, 22, 96, 3)
|
||||
b, t, l, c = x.shape
|
||||
x = x.view(b, t, -1).transpose(1, 2) # (Batch, Features, Time)
|
||||
# x shape: (Batch, 22, 543, 5)
|
||||
b, t, l, f = x.shape
|
||||
# Flatten landmarks and features into one vector, then transpose for Conv1d
|
||||
x = x.view(b, t, -1).transpose(1, 2) # (Batch, 2715, 22)
|
||||
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = self.pool(x)
|
||||
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = self.pool(x)
|
||||
|
||||
# Global Average Pool across the time dimension
|
||||
x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
|
||||
x = self.dropout(x)
|
||||
return self.fc(x)
|
||||
|
||||
return self.fc(self.dropout(x))
|
||||
|
||||
|
||||
# --- TRAINING LOOP ---
|
||||
# --- EXECUTION ---
|
||||
|
||||
def run_training():
|
||||
# 1. Prepare Metadata
|
||||
train_df = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
|
||||
unique_signs = sorted(train_df["sign"].unique().to_list())
|
||||
if __name__ == "__main__":
|
||||
# 1. Setup Data
|
||||
metadata = pl.read_csv(os.path.join(BASE_PATH, "train.csv"))
|
||||
unique_signs = sorted(metadata["sign"].unique().to_list())
|
||||
sign_to_idx = {sign: i for i, sign in enumerate(unique_signs)}
|
||||
labels = [sign_to_idx[s] for s in metadata["sign"].to_list()]
|
||||
|
||||
paths = train_df["path"].to_list()
|
||||
labels = [sign_to_idx[s] for s in train_df["sign"].to_list()]
|
||||
paths = metadata["path"].to_list()
|
||||
|
||||
# 2. Split
|
||||
p_train, p_val, l_train, l_val = train_test_split(paths, labels, test_size=0.15, stratify=labels)
|
||||
# 2. Load to RAM (Parallelized)
|
||||
print(f"Loading {len(paths)} files into RAM with 5-channel features...")
|
||||
with ProcessPoolExecutor() as executor:
|
||||
data_list = list(tqdm(executor.map(load_file_to_memory, paths), total=len(paths)))
|
||||
|
||||
# 3. Loaders
|
||||
train_ds = ASLDataset(p_train, l_train)
|
||||
val_ds = ASLDataset(p_val, l_val)
|
||||
X = torch.tensor(np.array(data_list), dtype=torch.float32)
|
||||
y = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
# num_workers=4 allows the CPU to preprocess the next batch while GPU trains
|
||||
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
|
||||
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
|
||||
# 3. Split
|
||||
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, stratify=y, random_state=42)
|
||||
|
||||
# 4. Init Model & Optim
|
||||
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
|
||||
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=64)
|
||||
|
||||
# 4. Train
|
||||
model = ASLClassifier(len(unique_signs)).to(device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||
scaler = torch.amp.GradScaler(enabled=(device.type == 'cuda'))
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # Helps prevent over-confidence
|
||||
|
||||
# 5. Loop
|
||||
for epoch in range(30):
|
||||
print(f"Starting training on {device}...")
|
||||
for epoch in range(25):
|
||||
model.train()
|
||||
t_correct, t_total = 0, 0
|
||||
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
|
||||
for x, y in pbar:
|
||||
x, y = x.to(device), y.to(device)
|
||||
train_loss = 0
|
||||
for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
|
||||
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Use Mixed Precision for speed
|
||||
with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
|
||||
outputs = model(x)
|
||||
loss = criterion(outputs, y)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
_, pred = torch.max(outputs, 1)
|
||||
t_total += y.size(0)
|
||||
t_correct += (pred == y).sum().item()
|
||||
pbar.set_postfix(acc=f"{(t_correct / t_total) * 100:.1f}%")
|
||||
output = model(batch_x)
|
||||
loss = criterion(output, batch_y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
train_loss += loss.item()
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
v_correct, v_total = 0, 0
|
||||
correct, total = 0, 0
|
||||
with torch.no_grad():
|
||||
for x, y in val_loader:
|
||||
x, y = x.to(device), y.to(device)
|
||||
outputs = model(x)
|
||||
_, pred = torch.max(outputs, 1)
|
||||
v_total += y.size(0)
|
||||
v_correct += (pred == y).sum().item()
|
||||
for vx, vy in val_loader:
|
||||
vx, vy = vx.to(device), vy.to(device)
|
||||
pred = model(vx).argmax(1)
|
||||
correct += (pred == vy).sum().item()
|
||||
total += vy.size(0)
|
||||
|
||||
print(f"Validation Accuracy: {(v_correct / v_total) * 100:.2f}%")
|
||||
torch.save(model.state_dict(), f"asl_model_epoch_{epoch}.pth")
|
||||
print(f"Epoch {epoch + 1} | Loss: {train_loss / len(train_loader):.4f} | Val Acc: {100 * correct / total:.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_training()
|
||||
if (epoch + 1) % 5 == 0:
|
||||
torch.save(model.state_dict(), f"asl_model_v2_e{epoch + 1}.pth")
|
||||
Reference in New Issue
Block a user