- Create python directory with data/, model/ subdirectories - Implement LinearEval(61072->1) model - Add config, constants, feature_extractor - Add tests with 4 passing test cases
78 lines
2.2 KiB
Python
78 lines
2.2 KiB
Python
"""Training loop for NNUE linear model"""
|
|
|
|
import torch
|
|
import numpy as np
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
from python.model.nnue_linear import LinearEval
|
|
from python.model.feature_extractor import fen_to_features
|
|
from python.config import BATCH_SIZE, LEARNING_RATE, WEIGHT_DECAY, GRADIENT_CLIP, EPOCHS
|
|
|
|
|
|
def train(features: np.ndarray, labels: np.ndarray, *, patience: int = 50) -> LinearEval:
    """
    Train the linear NNUE model with Adam, cosine LR annealing, and early stopping.

    Args:
        features: (N, 61072) numpy array of board feature vectors.
        labels: (N,) numpy array of target evaluations.
        patience: Epochs without improvement in average training loss before
            early stopping (default 50, matching prior hard-coded behavior).

    Returns:
        Trained model, restored to the weights of its best-loss epoch.

    Raises:
        ValueError: If features/labels are empty or their lengths differ.
    """
    if len(features) != len(labels):
        raise ValueError(
            f"features and labels length mismatch: {len(features)} vs {len(labels)}"
        )
    if len(features) == 0:
        # Previously this crashed with ZeroDivisionError at avg_loss; fail clearly.
        raise ValueError("cannot train on an empty dataset")

    # Convert to tensors
    X = torch.from_numpy(features).float()
    y = torch.from_numpy(labels).float()

    # Create dataset and dataloader
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, optimizer, and LR schedule
    model = LinearEval()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_loss = float("inf")
    patience_counter = 0
    best_model_state = None

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0

        for batch_X, batch_y in dataloader:
            optimizer.zero_grad()
            # Flatten both sides: if the model emits (B, 1) while targets are
            # (B,), raw mse_loss would silently broadcast to (B, B).
            preds = model(batch_X).view(-1)
            loss = torch.nn.functional.mse_loss(preds, batch_y.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
            optimizer.step()

            # Weight by batch size so an uneven final batch doesn't skew the mean.
            total_loss += loss.item() * batch_X.size(0)

        avg_loss = total_loss / len(dataset)
        scheduler.step()

        # Early stopping check
        if avg_loss < best_loss:
            best_loss = avg_loss
            # Deep-copy the state: state_dict() tensors alias the live
            # parameters, so a shallow dict .copy() would later "restore"
            # the final weights instead of the best ones.
            best_model_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.6f}")

        if patience_counter >= patience:
            print("Early stopping triggered")
            break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model
|