Let's try this architecture
@@ -15,13 +15,21 @@ CACHE_DIR = "asl_cache"
 TARGET_FRAMES = 22
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Landmark indices for each body part (based on MediaPipe Holistic)
+LANDMARK_RANGES = {
+    'left_hand': list(range(468, 489)),   # 21 landmarks
+    'right_hand': list(range(522, 543)),  # 21 landmarks
+    'pose': list(range(489, 522)),        # 33 landmarks
+    'face': list(range(0, 468))           # 468 landmarks
+}
+
 
 # --- PREPROCESSING (RUN ONCE) ---
 
 def process_single_file(args):
     """Process a single file - designed for multiprocessing"""
     i, path, base_path, cache_dir = args
-    cache_path = os.path.join(cache_dir, f"sample_{i}.npy")
+    cache_path = os.path.join(cache_dir, f"sample_{i}.npz")
 
     if os.path.exists(cache_path):
         return  # Skip if already cached
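Editor's note: the new LANDMARK_RANGES table should partition all 543 MediaPipe Holistic landmarks exactly once (face 0-467, left hand 468-488, pose 489-521, right hand 522-542). A minimal standalone sanity check, not part of the commit:

# Sanity check (editor's sketch): the four ranges must tile 0..542
# with no gaps or overlaps.
LANDMARK_RANGES = {
    'left_hand': list(range(468, 489)),
    'right_hand': list(range(522, 543)),
    'pose': list(range(489, 522)),
    'face': list(range(0, 468)),
}
all_indices = sorted(i for idx in LANDMARK_RANGES.values() for i in idx)
assert all_indices == list(range(543)), "ranges must partition 0..542"
assert len(LANDMARK_RANGES['left_hand']) == 21
assert len(LANDMARK_RANGES['right_hand']) == 21
assert len(LANDMARK_RANGES['pose']) == 33
assert len(LANDMARK_RANGES['face']) == 468
print("LANDMARK_RANGES partition all 543 landmarks")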
@@ -41,7 +49,7 @@ def process_single_file(args):
             ])
         )
 
-        # Local Anchors (Wrists)
+        # Local Anchors (Wrists - landmark_index 468 and 522)
         wrists = (
             df.filter(pl.col("landmark_index").is_in([468, 522]))
             .select([
@@ -67,17 +75,37 @@ def process_single_file(args):
         )
 
         n_frames = processed["frame"].n_unique()
-        tensor = processed.select(["x_g", "y_g", "z_g", "x_l", "y_l"]).to_numpy().reshape(n_frames, 543, 5)
+
+        # Full tensor for indexing
+        full_tensor = processed.select(["landmark_index", "x_g", "y_g", "z_g", "x_l", "y_l"]).to_numpy()
+        full_tensor = full_tensor.reshape(n_frames, 543, 6)  # 6 = 1 (index) + 5 (features)
+
+        # Extract landmark_index and features separately
+        full_data = full_tensor[:, :, 1:]  # Remove landmark_index column (features only)
 
         # Temporal Resampling
         indices = np.linspace(0, n_frames - 1, num=TARGET_FRAMES).round().astype(int)
-        result = tensor[indices]
+        resampled = full_data[indices]  # (22, 543, 5)
 
-        # Save to cache
-        np.save(cache_path, result)
-    except Exception:
-        # Save zero tensor for failed files
-        np.save(cache_path, np.zeros((TARGET_FRAMES, 543, 5)))
+        # Split by body part
+        left_hand = resampled[:, LANDMARK_RANGES['left_hand'], :]    # (22, 21, 5)
+        right_hand = resampled[:, LANDMARK_RANGES['right_hand'], :]  # (22, 21, 5)
+        pose = resampled[:, LANDMARK_RANGES['pose'], :]              # (22, 33, 5)
+        face = resampled[:, LANDMARK_RANGES['face'], :]              # (22, 468, 5)
+
+        # Save as compressed npz
+        np.savez_compressed(cache_path,
+                            left_hand=left_hand.astype(np.float32),
+                            right_hand=right_hand.astype(np.float32),
+                            pose=pose.astype(np.float32),
+                            face=face.astype(np.float32))
+    except Exception as e:
+        # Save zero tensors for failed files
+        np.savez_compressed(cache_path,
+                            left_hand=np.zeros((TARGET_FRAMES, 21, 5), dtype=np.float32),
+                            right_hand=np.zeros((TARGET_FRAMES, 21, 5), dtype=np.float32),
+                            pose=np.zeros((TARGET_FRAMES, 33, 5), dtype=np.float32),
+                            face=np.zeros((TARGET_FRAMES, 468, 5), dtype=np.float32))
 
 
 def preprocess_and_cache(paths, base_path=BASE_PATH, cache_dir=CACHE_DIR):
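Editor's note: after this change each cached sample is a compressed .npz archive holding one float32 array per body part, instead of a single (22, 543, 5) .npy tensor. A minimal sketch for verifying one cached file, assuming preprocessing has already run and CACHE_DIR is "asl_cache":

# Editor's sketch (not in the commit): read one cached sample back and
# check the per-part shapes the new npz format promises.
import numpy as np

data = np.load("asl_cache/sample_0.npz")
assert data['left_hand'].shape == (22, 21, 5)
assert data['right_hand'].shape == (22, 21, 5)
assert data['pose'].shape == (22, 33, 5)
assert data['face'].shape == (22, 468, 5)
assert all(data[k].dtype == np.float32 for k in data.files)
print("cached sample OK")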
@@ -85,7 +113,7 @@ def preprocess_and_cache(paths, base_path=BASE_PATH, cache_dir=CACHE_DIR):
     os.makedirs(cache_dir, exist_ok=True)
 
     # Check if already cached
-    all_cached = all(os.path.exists(os.path.join(cache_dir, f"sample_{i}.npy")) for i in range(len(paths)))
+    all_cached = all(os.path.exists(os.path.join(cache_dir, f"sample_{i}.npz")) for i in range(len(paths)))
     if all_cached:
         print("All files already cached, skipping preprocessing...")
         return
@@ -117,50 +145,122 @@ class CachedASLDataset(Dataset):
 
     def __getitem__(self, idx):
         sample_idx = self.indices[idx]
-        cache_path = os.path.join(self.cache_dir, f"sample_{sample_idx}.npy")
+        cache_path = os.path.join(self.cache_dir, f"sample_{sample_idx}.npz")
 
         # Fast numpy load
         data = np.load(cache_path)
-        label = self.labels[idx]
 
-        return torch.tensor(data, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
+        # Load each body part
+        left_hand = torch.tensor(data['left_hand'], dtype=torch.float32)
+        right_hand = torch.tensor(data['right_hand'], dtype=torch.float32)
+        pose = torch.tensor(data['pose'], dtype=torch.float32)
+        face = torch.tensor(data['face'], dtype=torch.float32)
+
+        label = torch.tensor(self.labels[idx], dtype=torch.long)
+
+        return (left_hand, right_hand, pose, face), label
 
 
-# --- MODEL ---
+# --- IMPROVED MODEL WITH SEPARATE STREAMS ---
 
-class ASLClassifier(nn.Module):
-    def __init__(self, num_classes):
+class ASLClassifierSeparateStreams(nn.Module):
+    def __init__(self, num_classes, dropout_rate=0.3):
         super().__init__()
-        self.feat_dim = 543 * 5
 
-        self.conv1 = nn.Conv1d(self.feat_dim, 512, kernel_size=3, padding=1)
-        self.bn1 = nn.BatchNorm1d(512)
-        self.conv2 = nn.Conv1d(512, 512, kernel_size=3, padding=1)
-        self.bn2 = nn.BatchNorm1d(512)
+        # Feature dimensions for each body part
+        self.left_hand_dim = 21 * 5   # 21 landmarks * 5 features
+        self.right_hand_dim = 21 * 5  # 21 landmarks * 5 features
+        self.pose_dim = 33 * 5        # 33 landmarks * 5 features
+        self.face_dim = 468 * 5       # 468 landmarks * 5 features
 
-        self.pool = nn.MaxPool1d(2)
-        self.dropout = nn.Dropout(0.4)
+        # Separate convolutional streams for each body part
+        self.left_hand_stream = self._make_conv_stream(self.left_hand_dim, 128, dropout_rate)
+        self.right_hand_stream = self._make_conv_stream(self.right_hand_dim, 128, dropout_rate)
+        self.pose_stream = self._make_conv_stream(self.pose_dim, 128, dropout_rate)
+        self.face_stream = self._make_conv_stream(self.face_dim, 256, dropout_rate)
 
-        self.fc = nn.Sequential(
-            nn.Linear(512, 1024),
+        # Combined features dimension
+        combined_dim = 128 + 128 + 128 + 256  # 640
+
+        # Bidirectional LSTM for temporal modeling
+        self.lstm = nn.LSTM(
+            input_size=combined_dim,
+            hidden_size=256,
+            num_layers=2,
+            batch_first=True,
+            bidirectional=True,
+            dropout=dropout_rate
+        )
+
+        # Attention mechanism
+        self.attention = nn.Sequential(
+            nn.Linear(512, 128),  # 512 from bidirectional LSTM (256*2)
+            nn.Tanh(),
+            nn.Linear(128, 1)
+        )
+
+        # Classification head
+        self.classifier = nn.Sequential(
+            nn.Linear(512, 512),
             nn.ReLU(),
-            nn.Dropout(0.2),
-            nn.Linear(1024, num_classes)
+            nn.Dropout(dropout_rate),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(256, num_classes)
+        )
+
+    def _make_conv_stream(self, input_dim, output_dim, dropout_rate):
+        """Create a convolutional stream for processing a body part"""
+        return nn.Sequential(
+            nn.Conv1d(input_dim, output_dim * 2, kernel_size=3, padding=1),
+            nn.BatchNorm1d(output_dim * 2),
+            nn.ReLU(),
+            nn.Dropout(dropout_rate * 0.5),
+            nn.Conv1d(output_dim * 2, output_dim, kernel_size=3, padding=1),
+            nn.BatchNorm1d(output_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout_rate * 0.5)
         )
 
     def forward(self, x):
-        b, t, l, f = x.shape
-        x = x.view(b, t, -1).transpose(1, 2)
+        # x is a tuple: (left_hand, right_hand, pose, face)
+        left_hand, right_hand, pose, face = x
 
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = self.pool(x)
+        b = left_hand.shape[0]  # batch size
 
-        x = F.relu(self.bn2(self.conv2(x)))
-        x = self.pool(x)
+        # Flatten landmarks and features for each body part
+        # Shape: (batch, time, landmarks, features) -> (batch, landmarks*features, time)
+        left_hand = left_hand.view(b, TARGET_FRAMES, -1).transpose(1, 2)
+        right_hand = right_hand.view(b, TARGET_FRAMES, -1).transpose(1, 2)
+        pose = pose.view(b, TARGET_FRAMES, -1).transpose(1, 2)
+        face = face.view(b, TARGET_FRAMES, -1).transpose(1, 2)
 
-        x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
+        # Process each body part through its stream
+        left_hand_feat = self.left_hand_stream(left_hand)     # (batch, 128, time)
+        right_hand_feat = self.right_hand_stream(right_hand)  # (batch, 128, time)
+        pose_feat = self.pose_stream(pose)                    # (batch, 128, time)
+        face_feat = self.face_stream(face)                    # (batch, 256, time)
 
-        return self.fc(self.dropout(x))
+        # Transpose back: (batch, features, time) -> (batch, time, features)
+        left_hand_feat = left_hand_feat.transpose(1, 2)
+        right_hand_feat = right_hand_feat.transpose(1, 2)
+        pose_feat = pose_feat.transpose(1, 2)
+        face_feat = face_feat.transpose(1, 2)
+
+        # Concatenate all features
+        combined = torch.cat([left_hand_feat, right_hand_feat, pose_feat, face_feat], dim=2)
+        # combined shape: (batch, time, 640)
+
+        # LSTM processing
+        lstm_out, _ = self.lstm(combined)  # (batch, time, 512)
+
+        # Attention mechanism
+        attention_weights = F.softmax(self.attention(lstm_out), dim=1)  # (batch, time, 1)
+        attended = torch.sum(attention_weights * lstm_out, dim=1)  # (batch, 512)
+
+        # Classification
+        return self.classifier(attended)
 
 
 # --- EXECUTION ---
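Editor's note: a quick shape smoke test for the new model (per-part Conv1d streams -> concat -> BiLSTM -> attention pooling -> classifier). This sketch is not part of the commit; it assumes the class definition above with TARGET_FRAMES = 22, and 250 classes is only a placeholder:

# Editor's smoke test: feed random per-part tensors and check the output shape.
import torch

num_classes, batch = 250, 4
model = ASLClassifierSeparateStreams(num_classes)
model.eval()  # use BatchNorm running stats, disable dropout
dummy = (
    torch.randn(batch, 22, 21, 5),   # left_hand
    torch.randn(batch, 22, 21, 5),   # right_hand
    torch.randn(batch, 22, 33, 5),   # pose
    torch.randn(batch, 22, 468, 5),  # face
)
with torch.no_grad():
    out = model(dummy)
assert out.shape == (batch, num_classes)
print("forward pass OK:", tuple(out.shape))

Splitting the streams keeps the 2,340 face input channels from swamping the 105-channel hand inputs before fusion, which is presumably why the face stream alone gets the wider 256-dim output.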
@@ -187,20 +287,21 @@ if __name__ == "__main__":
     train_dataset = CachedASLDataset(train_indices, train_labels)
     val_dataset = CachedASLDataset(val_indices, val_labels)
 
-    # Increase batch size and workers since loading is now fast
-    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
-    val_loader = DataLoader(val_dataset, batch_size=64, num_workers=4, pin_memory=True)
+    # Adjust batch size based on GPU memory (separate streams use more memory)
+    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
+    val_loader = DataLoader(val_dataset, batch_size=32, num_workers=4, pin_memory=True)
 
     print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
 
     # 5. Train
-    model = ASLClassifier(len(unique_signs)).to(device)
+    model = ASLClassifierSeparateStreams(len(unique_signs)).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
     criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
 
     best_acc = 0.0
     print(f"Starting training on {device}...")
+    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
 
     for epoch in range(25):
         # Training
@@ -210,11 +311,13 @@ if __name__ == "__main__":
         train_total = 0
 
         pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/25 [Train]")
-        for batch_x, batch_y in pbar:
-            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
+        for batch_data, batch_y in pbar:
+            # Move all body parts to device
+            batch_data = tuple(d.to(device) for d in batch_data)
+            batch_y = batch_y.to(device)
 
             optimizer.zero_grad()
-            output = model(batch_x)
+            output = model(batch_data)
             loss = criterion(output, batch_y)
             loss.backward()
             optimizer.step()
@@ -232,9 +335,11 @@ if __name__ == "__main__":
         val_loss = 0
 
         with torch.no_grad():
-            for vx, vy in tqdm(val_loader, desc=f"Epoch {epoch + 1}/25 [Val]"):
-                vx, vy = vx.to(device), vy.to(device)
-                output = model(vx)
+            for batch_data, vy in tqdm(val_loader, desc=f"Epoch {epoch + 1}/25 [Val]"):
+                batch_data = tuple(d.to(device) for d in batch_data)
+                vy = vy.to(device)
+
+                output = model(batch_data)
                 val_loss += criterion(output, vy).item()
                 pred = output.argmax(1)
                 val_correct += (pred == vy).sum().item()
@@ -254,11 +359,11 @@ if __name__ == "__main__":
             # Save best model
             if val_acc > best_acc:
                 best_acc = val_acc
-                torch.save(model.state_dict(), "best_asl_model.pth")
+                torch.save(model.state_dict(), "best_asl_model_separate.pth")
                 print(f" ✓ Best model saved! (Val Acc: {val_acc:.2f}%)\n")
 
             # Checkpoint every 5 epochs
             if (epoch + 1) % 5 == 0:
-                torch.save(model.state_dict(), f"asl_model_e{epoch + 1}.pth")
+                torch.save(model.state_dict(), f"asl_model_separate_e{epoch + 1}.pth")
 
     print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")
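Editor's note: the checkpoints above store only the state_dict, so reloading requires re-instantiating the class first. A minimal sketch, assuming unique_signs comes from the same training script:

# Editor's sketch: restore the best checkpoint for inference.
import torch

model = ASLClassifierSeparateStreams(len(unique_signs))
model.load_state_dict(torch.load("best_asl_model_separate.pth", map_location="cpu"))
model.eval()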