print("Loading Kaggle ASL dataset...")
base_path = 'asl_kaggle'
train_df, sign_to_idx = load_kaggle_asl_data(base_path)

# Process landmarks with parallel processing
from multiprocessing import Pool, cpu_count
from functools import partial


def process_single_sequence(row, base_path):
    """Process one sequence row — designed for parallel execution.

    Args:
        row: mapping with at least the keys 'path' (parquet file path
            relative to base_path) and 'sign' (label). A pandas Series
            or a plain dict both work.
        base_path: root directory the relative parquet paths hang off.

    Returns:
        (features, sign) on success, or (None, None) when the file is
        missing, yields no landmarks/features, or raises — workers must
        never crash the pool, so per-file failures are skipped.
    """
    parquet_path = os.path.join(base_path, row['path'])
    if not os.path.exists(parquet_path):
        return None, None

    try:
        landmarks = extract_hand_landmarks_from_parquet(parquet_path)
        if landmarks is None:
            return None, None

        features = get_optimized_features(landmarks)
        if features is None:
            return None, None

        return features, row['sign']
    except Exception as exc:
        # Best-effort: skip unreadable/corrupt files, but leave a trace
        # instead of silently swallowing the error.
        print(f"Skipping {parquet_path}: {exc!r}")
        return None, None


print("\nProcessing landmark files with parallel processing...")
print(f"Using {cpu_count()} CPU cores")

# Ship only the two fields the worker needs as plain dicts — pickling a
# full pandas Series per task is needlessly heavy for the Pool.
rows_list = train_df[['path', 'sign']].to_dict('records')

# Bind base_path once so workers receive a single-argument callable.
process_func = partial(process_single_sequence, base_path=base_path)

X = []
y = []
batch_size = 1000

# NOTE(review): creating a Pool at module level requires this script to
# run under an `if __name__ == "__main__":` guard on spawn-based
# platforms (Windows/macOS) — confirm the file's entry point provides one.
with Pool(processes=cpu_count()) as pool:
    for i in range(0, len(rows_list), batch_size):
        batch = rows_list[i:i + batch_size]
        results = pool.map(process_func, batch)

        # Keep only rows where the worker produced both a feature vector
        # and a label.
        for features, sign in results:
            if features is not None and sign is not None:
                X.append(features)
                y.append(sign)

        print(f"Processed {min(i + batch_size, len(rows_list))}/{len(rows_list)} sequences... (Valid: {len(X)})")

print(f"\nSuccessfully processed {len(X)} sequences")