Files
chess-engine/python/python/train.py
KeshavAnandCode 9e2fe0cae6 feat: add project structure and basic NNUE model
- Create python directory with data/, model/ subdirectories
- Implement LinearEval(61072->1) model
- Add config, constants, feature_extractor
- Add tests with 4 passing test cases
2026-04-14 18:03:42 -05:00

78 lines
2.2 KiB
Python

"""Training loop for NNUE linear model"""
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from python.model.nnue_linear import LinearEval
from python.model.feature_extractor import fen_to_features
from python.config import BATCH_SIZE, LEARNING_RATE, WEIGHT_DECAY, GRADIENT_CLIP, EPOCHS
def train(
    features: np.ndarray, labels: np.ndarray, patience: int = 50
) -> LinearEval:
    """Train the NNUE linear model with early stopping.

    Args:
        features: (N, 61072) numpy array of position features.
        labels: (N,) numpy array of target evaluations.
        patience: Number of consecutive non-improving epochs tolerated
            before training stops early (default 50, the previous
            hard-coded value).

    Returns:
        The model, restored to the weights of the best (lowest average
        training loss) epoch.
    """
    # Convert to tensors
    X = torch.from_numpy(features).float()
    y = torch.from_numpy(labels).float()

    # Create dataset and dataloader
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, optimizer, and cosine LR schedule over all epochs
    model = LinearEval()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_loss = float("inf")
    patience_counter = 0
    best_model_state = None

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for batch_X, batch_y in dataloader:
            optimizer.zero_grad()
            preds = model(batch_X)
            # Flatten both sides before MSE: a (B, 1) prediction against a
            # (B,) target would otherwise broadcast to a (B, B) loss matrix
            # and silently optimize the wrong objective. view(-1) is a
            # no-op if the shapes already agree.
            loss = torch.nn.functional.mse_loss(preds.view(-1), batch_y.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        scheduler.step()

        # Early stopping check
        if avg_loss < best_loss:
            best_loss = avg_loss
            # state_dict() returns references to the live parameter
            # tensors, and dict.copy() is shallow -- later optimizer steps
            # would mutate the "snapshot" in place. Clone each tensor so
            # the best weights are truly frozen.
            best_model_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.6f}")
        if patience_counter >= patience:
            print("Early stopping triggered")
            break

    # Load best model (None only if the loop never ran, e.g. EPOCHS == 0)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model