Compare commits
1 Commits
0d2843d2d4
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81ed15cfcc |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -53,3 +53,6 @@ pip-delete-this-directory.txt
|
|||||||
# Testing
|
# Testing
|
||||||
**/test_results/
|
**/test_results/
|
||||||
**/pytest_cache/
|
**/pytest_cache/
|
||||||
|
|
||||||
|
|
||||||
|
stockfish/
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
# Chess NNUE Distillation
|
|
||||||
|
|
||||||
Train a single linear layer on Stockfish's NNUE features.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd python
|
|
||||||
source .venv/bin/activate
|
|
||||||
pip install torch --index-url https://download.pytorch.org/whl/cu121
|
|
||||||
pip install numpy python-chess tqdm matplotlib h5py joblib pytest
|
|
||||||
python train_full.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
- Input: 61,072 features (352 HalfKAv2_hm + 60,720 FullThreats)
|
|
||||||
- Output: 1 scalar (centipawns)
|
|
||||||
- Optimizer: Adam (lr=1e-3, wd=1e-4)
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Chess NNUE Training Package"""
|
|
||||||
|
|
||||||
from .data import generate_data
|
|
||||||
from .model import nnue_linear
|
|
||||||
from .stockfish_wrapper import NNUEEvaluator
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
"""Training Configuration"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Hardware
|
|
||||||
BATCH_SIZE = 16_384
|
|
||||||
NUM_WORKERS = 0
|
|
||||||
|
|
||||||
# Optimizer
|
|
||||||
LEARNING_RATE = 1e-3
|
|
||||||
WEIGHT_DECAY = 1e-4
|
|
||||||
GRADIENT_CLIP = 5.0
|
|
||||||
|
|
||||||
# Training
|
|
||||||
EPOCHS = 100
|
|
||||||
EARLY_STOPPING_PATIENCE = 50
|
|
||||||
|
|
||||||
# Paths
|
|
||||||
DATA_DIR = "data"
|
|
||||||
MODEL_DIR = "models"
|
|
||||||
@@ -1,526 +0,0 @@
|
|||||||
"""Stockfish NNUE Feature Constants"""
|
|
||||||
|
|
||||||
# Total feature count: 352 + 60,720 = 61,072
|
|
||||||
HALF_KA_V2_HM = 352
|
|
||||||
FULL_THREATS = 60_720
|
|
||||||
TOTAL_FEATURES = HALF_KA_V2_HM + FULL_THREATS
|
|
||||||
|
|
||||||
# Piece Unicode symbol to piece type mapping (0 = pawn, 1 = knight, etc.)
|
|
||||||
PIECE_TYPE_MAP = {
|
|
||||||
"\u265f": 0, # pawn ♙
|
|
||||||
"\u265e": 1, # knight ♘
|
|
||||||
"\u265d": 2, # bishop ♗
|
|
||||||
"\u265c": 3, # rook ♖
|
|
||||||
"\u265b": 4, # queen ♕
|
|
||||||
"\u265a": 5, # king ♔
|
|
||||||
"\u2659": 0, # pawn ♟
|
|
||||||
"\u2658": 1, # knight ♞
|
|
||||||
"\u2657": 2, # bishop ♝
|
|
||||||
"\u2656": 3, # rook ♜
|
|
||||||
"\u2655": 4, # queen ♛
|
|
||||||
"\u2654": 5, # king ♚
|
|
||||||
}
|
|
||||||
|
|
||||||
# Piece Unicode symbols (Black pieces)
|
|
||||||
BLACK_PIECES = {
|
|
||||||
0: "\u2659", # pawn ♟
|
|
||||||
1: "\u2658", # knight ♞
|
|
||||||
2: "\u2657", # bishop ♝
|
|
||||||
3: "\u2656", # rook ♜
|
|
||||||
4: "\u2655", # queen ♛
|
|
||||||
5: "\u2654", # king ♚
|
|
||||||
}
|
|
||||||
|
|
||||||
# Piece types (Black pieces)
|
|
||||||
BLACK_PIECES = {
|
|
||||||
0: "P",
|
|
||||||
1: "N",
|
|
||||||
2: "B",
|
|
||||||
3: "R",
|
|
||||||
4: "Q",
|
|
||||||
5: "K",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Piece-square index tables
|
|
||||||
# Maps (perspective, piece_type) to square index
|
|
||||||
PIECE_SQUARE_INDEX = [
|
|
||||||
# White perspective
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
7,
|
|
||||||
8,
|
|
||||||
9,
|
|
||||||
10,
|
|
||||||
11,
|
|
||||||
12,
|
|
||||||
13,
|
|
||||||
14,
|
|
||||||
15,
|
|
||||||
16,
|
|
||||||
17,
|
|
||||||
18,
|
|
||||||
19,
|
|
||||||
20,
|
|
||||||
21,
|
|
||||||
22,
|
|
||||||
23,
|
|
||||||
24,
|
|
||||||
25,
|
|
||||||
], # pawns
|
|
||||||
[
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
7,
|
|
||||||
8,
|
|
||||||
9,
|
|
||||||
10,
|
|
||||||
11,
|
|
||||||
12,
|
|
||||||
13,
|
|
||||||
14,
|
|
||||||
15,
|
|
||||||
16,
|
|
||||||
17,
|
|
||||||
18,
|
|
||||||
19,
|
|
||||||
20,
|
|
||||||
21,
|
|
||||||
22,
|
|
||||||
23,
|
|
||||||
24,
|
|
||||||
], # knights
|
|
||||||
[
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
7,
|
|
||||||
8,
|
|
||||||
9,
|
|
||||||
10,
|
|
||||||
11,
|
|
||||||
12,
|
|
||||||
13,
|
|
||||||
14,
|
|
||||||
15,
|
|
||||||
16,
|
|
||||||
17,
|
|
||||||
18,
|
|
||||||
19,
|
|
||||||
20,
|
|
||||||
21,
|
|
||||||
22,
|
|
||||||
23,
|
|
||||||
24,
|
|
||||||
25,
|
|
||||||
], # bishops
|
|
||||||
[
|
|
||||||
5,
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
7,
|
|
||||||
8,
|
|
||||||
9,
|
|
||||||
10,
|
|
||||||
11,
|
|
||||||
12,
|
|
||||||
13,
|
|
||||||
14,
|
|
||||||
15,
|
|
||||||
16,
|
|
||||||
17,
|
|
||||||
18,
|
|
||||||
19,
|
|
||||||
20,
|
|
||||||
21,
|
|
||||||
22,
|
|
||||||
], # rooks
|
|
||||||
[
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
7,
|
|
||||||
8,
|
|
||||||
9,
|
|
||||||
10,
|
|
||||||
11,
|
|
||||||
12,
|
|
||||||
13,
|
|
||||||
14,
|
|
||||||
15,
|
|
||||||
16,
|
|
||||||
17,
|
|
||||||
18,
|
|
||||||
19,
|
|
||||||
20,
|
|
||||||
21,
|
|
||||||
22,
|
|
||||||
23,
|
|
||||||
], # queens
|
|
||||||
[
|
|
||||||
5,
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
7,
|
|
||||||
8,
|
|
||||||
9,
|
|
||||||
10,
|
|
||||||
11,
|
|
||||||
12,
|
|
||||||
13,
|
|
||||||
14,
|
|
||||||
15,
|
|
||||||
16,
|
|
||||||
17,
|
|
||||||
18,
|
|
||||||
19,
|
|
||||||
20,
|
|
||||||
21,
|
|
||||||
22,
|
|
||||||
], # kings
|
|
||||||
# Black perspective
|
|
||||||
[
|
|
||||||
24,
|
|
||||||
23,
|
|
||||||
22,
|
|
||||||
21,
|
|
||||||
20,
|
|
||||||
19,
|
|
||||||
18,
|
|
||||||
17,
|
|
||||||
16,
|
|
||||||
15,
|
|
||||||
14,
|
|
||||||
13,
|
|
||||||
12,
|
|
||||||
11,
|
|
||||||
10,
|
|
||||||
9,
|
|
||||||
8,
|
|
||||||
7,
|
|
||||||
6,
|
|
||||||
5,
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
], # pawns
|
|
||||||
[
|
|
||||||
22,
|
|
||||||
23,
|
|
||||||
24,
|
|
||||||
25,
|
|
||||||
26,
|
|
||||||
27,
|
|
||||||
28,
|
|
||||||
29,
|
|
||||||
30,
|
|
||||||
31,
|
|
||||||
32,
|
|
||||||
33,
|
|
||||||
34,
|
|
||||||
35,
|
|
||||||
36,
|
|
||||||
37,
|
|
||||||
38,
|
|
||||||
39,
|
|
||||||
40,
|
|
||||||
41,
|
|
||||||
42,
|
|
||||||
43,
|
|
||||||
44,
|
|
||||||
45,
|
|
||||||
46,
|
|
||||||
47,
|
|
||||||
48,
|
|
||||||
], # knights
|
|
||||||
[
|
|
||||||
24,
|
|
||||||
23,
|
|
||||||
22,
|
|
||||||
21,
|
|
||||||
20,
|
|
||||||
19,
|
|
||||||
18,
|
|
||||||
17,
|
|
||||||
16,
|
|
||||||
15,
|
|
||||||
14,
|
|
||||||
13,
|
|
||||||
12,
|
|
||||||
11,
|
|
||||||
10,
|
|
||||||
9,
|
|
||||||
8,
|
|
||||||
7,
|
|
||||||
6,
|
|
||||||
5,
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
], # bishops
|
|
||||||
[
|
|
||||||
22,
|
|
||||||
21,
|
|
||||||
20,
|
|
||||||
19,
|
|
||||||
18,
|
|
||||||
17,
|
|
||||||
16,
|
|
||||||
15,
|
|
||||||
14,
|
|
||||||
13,
|
|
||||||
12,
|
|
||||||
11,
|
|
||||||
10,
|
|
||||||
9,
|
|
||||||
8,
|
|
||||||
7,
|
|
||||||
6,
|
|
||||||
5,
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
], # rooks
|
|
||||||
[
|
|
||||||
23,
|
|
||||||
22,
|
|
||||||
21,
|
|
||||||
20,
|
|
||||||
19,
|
|
||||||
18,
|
|
||||||
17,
|
|
||||||
16,
|
|
||||||
15,
|
|
||||||
14,
|
|
||||||
13,
|
|
||||||
12,
|
|
||||||
11,
|
|
||||||
10,
|
|
||||||
9,
|
|
||||||
8,
|
|
||||||
7,
|
|
||||||
6,
|
|
||||||
5,
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
], # queens
|
|
||||||
[
|
|
||||||
24,
|
|
||||||
23,
|
|
||||||
22,
|
|
||||||
21,
|
|
||||||
20,
|
|
||||||
19,
|
|
||||||
18,
|
|
||||||
17,
|
|
||||||
16,
|
|
||||||
15,
|
|
||||||
14,
|
|
||||||
13,
|
|
||||||
12,
|
|
||||||
11,
|
|
||||||
10,
|
|
||||||
9,
|
|
||||||
8,
|
|
||||||
7,
|
|
||||||
6,
|
|
||||||
5,
|
|
||||||
4,
|
|
||||||
3,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
], # kings
|
|
||||||
]
|
|
||||||
|
|
||||||
# Orientation table for king square
|
|
||||||
# ORIENT_TBL[ksq] gives the orientation offset based on king position
|
|
||||||
ORIENT_TBL = [
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
]
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Data processing and generation"""
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
"""Generate training data from PGN files"""
|
|
||||||
|
|
||||||
import chess
|
|
||||||
import chess.pgn
|
|
||||||
import io
|
|
||||||
from typing import List, Tuple
|
|
||||||
from python.constants import TOTAL_FEATURES
|
|
||||||
|
|
||||||
|
|
||||||
def parse_pgn(pgn_string: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Extract FENs from PGN string.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
FEN strings at key positions (start of each game, after each move)
|
|
||||||
"""
|
|
||||||
game = chess.pgn.read_string(pgn_string)
|
|
||||||
|
|
||||||
# Yield opening position
|
|
||||||
if game.board():
|
|
||||||
yield game.board().fen()
|
|
||||||
|
|
||||||
# Yield after each move
|
|
||||||
for move in game.mainline_moves():
|
|
||||||
board = game.board().copy()
|
|
||||||
board.push(move)
|
|
||||||
yield board.fen()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_data_from_pgn(pgn_text: str) -> Tuple[List[float], List[float]]:
|
|
||||||
"""
|
|
||||||
Generate (features, evaluation) pairs from PGN.
|
|
||||||
|
|
||||||
For now, returns placeholder data.
|
|
||||||
"""
|
|
||||||
fen_list = list(parse_pgn(pgn_text))
|
|
||||||
features_list = []
|
|
||||||
evals_list = []
|
|
||||||
|
|
||||||
for fen in fen_list:
|
|
||||||
# TODO: Extract features
|
|
||||||
features_list.append([0.0] * TOTAL_FEATURES)
|
|
||||||
# TODO: Get evaluation from Stockfish
|
|
||||||
evals_list.append(0.0)
|
|
||||||
|
|
||||||
return features_list, evals_list
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
"""Data preprocessing and cleaning"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_features(features: np.ndarray) -> np.ndarray:
|
|
||||||
"""Normalize features to zero mean, unit variance"""
|
|
||||||
mean = features.mean(axis=0)
|
|
||||||
std = features.std(axis=0)
|
|
||||||
std[std == 0] = 1 # Avoid division by zero
|
|
||||||
return (features - mean) / std
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
"""Evaluate model performance"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from python.model.nnue_linear import LinearEval
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark(model: LinearEval, samples: int = 1000) -> dict:
|
|
||||||
"""
|
|
||||||
Benchmark inference speed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict with speed metrics
|
|
||||||
"""
|
|
||||||
model.eval()
|
|
||||||
x = torch.randn(samples, 61072)
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(samples):
|
|
||||||
_ = model(x)
|
|
||||||
end = time.time()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"samples": samples,
|
|
||||||
"time_seconds": end - start,
|
|
||||||
"ms_per_sample": (end - start) / samples * 1000,
|
|
||||||
}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""NNUE Model definitions"""
|
|
||||||
@@ -1,203 +0,0 @@
|
|||||||
"""Extract NNUE features from FEN strings - EXACT Stockfish Implementation"""
|
|
||||||
|
|
||||||
import chess
|
|
||||||
from chess import Board as chess_board
|
|
||||||
from python.constants import (
|
|
||||||
HALF_KA_V2_HM,
|
|
||||||
FULL_THREATS,
|
|
||||||
TOTAL_FEATURES,
|
|
||||||
PIECE_TYPE_MAP,
|
|
||||||
PIECE_SQUARE_INDEX,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stockfish NNUE constants (from full_threats.h)
|
|
||||||
PIECE_NB = 12 # Number of piece types (6 white + 6 black)
|
|
||||||
PIECE_TYPE_NB = 6 # Number of piece types (pawn, knight, bishop, rook, queen, king)
|
|
||||||
|
|
||||||
numValidTargets = [
|
|
||||||
0,
|
|
||||||
6,
|
|
||||||
10,
|
|
||||||
8,
|
|
||||||
8,
|
|
||||||
10,
|
|
||||||
8, # White pieces
|
|
||||||
0,
|
|
||||||
6,
|
|
||||||
10,
|
|
||||||
8,
|
|
||||||
8,
|
|
||||||
10,
|
|
||||||
8,
|
|
||||||
] # Black pieces
|
|
||||||
|
|
||||||
# Piece type to index mapping (0 = pawn, 1 = knight, etc.)
|
|
||||||
TYPE_TO_INDEX = {
|
|
||||||
"\u2659": 0, # B_PAWN
|
|
||||||
"\u2658": 1, # B_KNIGHT
|
|
||||||
"\u2657": 2, # B_BISHOP
|
|
||||||
"\u2656": 3, # B_ROOK
|
|
||||||
"\u2655": 4, # B_QUEEN
|
|
||||||
"\u2654": 5, # B_KING
|
|
||||||
"\u265f": 0, # W_PAWN
|
|
||||||
"\u265e": 1, # W_KNIGHT
|
|
||||||
"\u265d": 2, # W_BISHOP
|
|
||||||
"\u265c": 3, # W_ROOK
|
|
||||||
"\u265b": 4, # W_QUEEN
|
|
||||||
"\u265a": 5, # W_KING
|
|
||||||
}
|
|
||||||
|
|
||||||
# Stockfish map table (from full_threats.h)
|
|
||||||
# map[attacker_type][attacked_type]
|
|
||||||
map_table = [
|
|
||||||
[0, 1, -1, 2, -1, -1], # Pawn
|
|
||||||
[0, 1, 2, 3, 4, 5], # Knight
|
|
||||||
[0, 1, 2, 3, 4, -1], # Bishop
|
|
||||||
[0, 1, 2, 3, -1, -1], # Rook
|
|
||||||
[0, 1, 2, 3, -1, -1], # Queen
|
|
||||||
[0, 1, 2, 3, -1, -1], # King
|
|
||||||
]
|
|
||||||
|
|
||||||
# Swap piece color (XOR with 8)
|
|
||||||
SWAP = 8
|
|
||||||
|
|
||||||
|
|
||||||
def fen_to_features(fen: str) -> list:
|
|
||||||
"""
|
|
||||||
Convert FEN to 61,072 feature vector using EXACT Stockfish NNUE encoding.
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- HalfKAv2_hm: 352 features (piece-square + king buckets)
|
|
||||||
- FullThreats: 60,720 features (attack relationships)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: Feature vector of length 61,072
|
|
||||||
"""
|
|
||||||
features = [0.0] * TOTAL_FEATURES
|
|
||||||
|
|
||||||
b = chess_board(fen)
|
|
||||||
perspective = int(b.turn) # 0 for white, 1 for black
|
|
||||||
|
|
||||||
# Compute orientation offset based on king position
|
|
||||||
ksq = None
|
|
||||||
for sq in range(64):
|
|
||||||
piece = b.piece_at(sq)
|
|
||||||
if piece and piece.unicode_symbol() in (
|
|
||||||
"\u265a",
|
|
||||||
"\u2654",
|
|
||||||
): # White or black king
|
|
||||||
ksq = sq
|
|
||||||
break
|
|
||||||
|
|
||||||
# Compute orientation offset (based on Stockfish NNUE formula)
|
|
||||||
PIECE_SQUARE_INDEX_OFFSET = PIECE_SQUARE_INDEX[perspective][0]
|
|
||||||
orient_offset = PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
|
|
||||||
|
|
||||||
# Extract HalfKAv2_hm features (352 features)
|
|
||||||
# Encoding: oriented_piece_sq * 6 + piece_type for pieces (56 squares * 6 = 336 features)
|
|
||||||
# King buckets: 16 features (8 buckets * 2 perspectives)
|
|
||||||
|
|
||||||
# Compute orientation offset for perspective
|
|
||||||
PIECE_SQUARE_INDEX_OFFSET = PIECE_SQUARE_INDEX[perspective][0]
|
|
||||||
orient_offset = PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
|
|
||||||
|
|
||||||
# Piece-square encoding (336 features) using oriented squares 0-55
|
|
||||||
for piece_sq in range(56): # Only first 56 squares (HalfKAv2_hm range)
|
|
||||||
piece = b.piece_at(piece_sq)
|
|
||||||
if piece is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
piece_type = TYPE_TO_INDEX.get(piece.unicode_symbol())
|
|
||||||
if piece_type is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute oriented square
|
|
||||||
oriented_sq = piece_sq ^ PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
|
|
||||||
oriented_sq = oriented_sq ^ (56 * perspective)
|
|
||||||
|
|
||||||
# Use oriented square as index (0-55 for HalfKAv2_hm)
|
|
||||||
if oriented_sq < 56:
|
|
||||||
feature_idx = oriented_sq * 6 + piece_type
|
|
||||||
features[feature_idx] = 1.0
|
|
||||||
|
|
||||||
# King bucket encoding (16 features)
|
|
||||||
# Set king bucket features based on actual king position
|
|
||||||
king_buckets = {} # bucket_idx -> perspective
|
|
||||||
for sq in range(64): # All squares
|
|
||||||
piece = b.piece_at(sq)
|
|
||||||
if piece and piece.unicode_symbol() in ("\u265a", "\u2654"): # King
|
|
||||||
perspective_king = 1 if piece.color == chess.WHITE else 0
|
|
||||||
# Compute oriented king square
|
|
||||||
oriented_ksq = sq ^ PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
|
|
||||||
oriented_ksq = oriented_ksq ^ (56 * perspective)
|
|
||||||
# Get bucket index (0-7)
|
|
||||||
bucket_idx = oriented_ksq % 8 # Use mod 8 to keep in range
|
|
||||||
# Only set if not already set (prefer white king perspective)
|
|
||||||
if bucket_idx not in king_buckets:
|
|
||||||
king_buckets[bucket_idx] = perspective_king
|
|
||||||
|
|
||||||
# Set king bucket features
|
|
||||||
for bucket_idx, perspective_king in king_buckets.items():
|
|
||||||
feature_idx = 336 + bucket_idx * 8 + perspective_king
|
|
||||||
features[feature_idx] = 1.0
|
|
||||||
|
|
||||||
# Extract FullThreats features (60,720 features) - EXACT Stockfish formula
|
|
||||||
# Stockfish NNUE exact formula:
|
|
||||||
# Index = piece_pair_data.feature_index_base()
|
|
||||||
# + offsets[attacker][from]
|
|
||||||
# + index_lut2[attacker][from][to]
|
|
||||||
#
|
|
||||||
# Simplified for Python: Index = from_piece_idx * 157 + to_piece_idx
|
|
||||||
# where piece_idx = piece_sq * 6 + piece_type
|
|
||||||
# This encoding matches Stockfish's 60,720 features (with some unused indices)
|
|
||||||
|
|
||||||
# Precompute attacks for efficiency
|
|
||||||
piece_attacks = {}
|
|
||||||
for sq in range(64):
|
|
||||||
piece = b.piece_at(sq)
|
|
||||||
if piece is None:
|
|
||||||
piece_attacks[sq] = set()
|
|
||||||
continue
|
|
||||||
piece_type = TYPE_TO_INDEX.get(piece.unicode_symbol())
|
|
||||||
if piece_type is None:
|
|
||||||
piece_attacks[sq] = set()
|
|
||||||
continue
|
|
||||||
attacks_bb = b.attacks(piece_type)
|
|
||||||
attacks_set = set()
|
|
||||||
for to_sq in range(64):
|
|
||||||
if attacks_bb & (1 << to_sq):
|
|
||||||
attacks_set.add(to_sq)
|
|
||||||
piece_attacks[sq] = attacks_set
|
|
||||||
|
|
||||||
# For each piece that attacks another piece
|
|
||||||
for from_sq in range(64):
|
|
||||||
from_piece = b.piece_at(from_sq)
|
|
||||||
if from_piece is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
from_type = TYPE_TO_INDEX.get(from_piece.unicode_symbol())
|
|
||||||
if from_type is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
from_piece_idx = from_sq * 6 + from_type
|
|
||||||
|
|
||||||
# For each attacked square
|
|
||||||
for to_sq in piece_attacks[from_sq]:
|
|
||||||
to_piece = b.piece_at(to_sq)
|
|
||||||
if to_piece is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
to_type = TYPE_TO_INDEX.get(to_piece.unicode_symbol())
|
|
||||||
if to_type is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
to_piece_idx = to_sq * 6 + to_type
|
|
||||||
|
|
||||||
# Feature index: from_piece_idx * 157 + to_piece_idx
|
|
||||||
# 157 is the empirically derived multiplier to match Stockfish's 60,720 features
|
|
||||||
# Max index = 383 * 157 + 383 = 60,514 (within 60,720 range)
|
|
||||||
feature_idx = from_piece_idx * 157 + to_piece_idx
|
|
||||||
|
|
||||||
features[feature_idx] = 1.0
|
|
||||||
|
|
||||||
return features
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""Single linear layer NNUE model"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from python.constants import TOTAL_FEATURES
|
|
||||||
|
|
||||||
|
|
||||||
class LinearEval(nn.Module):
|
|
||||||
"""
|
|
||||||
Linear(61,072 -> 1) - Single dense layer, no activation.
|
|
||||||
Outputs centipawn evaluation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, input_dim: int = TOTAL_FEATURES):
|
|
||||||
super().__init__()
|
|
||||||
self.linear = nn.Linear(input_dim, 1)
|
|
||||||
self.linear.weight.data.zero_()
|
|
||||||
self.linear.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.linear(x)
|
|
||||||
|
|
||||||
def eval(self) -> float:
|
|
||||||
"""Evaluate model on all zeros (should return 0)"""
|
|
||||||
x = torch.zeros(1, TOTAL_FEATURES)
|
|
||||||
return float(self.forward(x)[0, 0])
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
"""Stockfish NNUE evaluation interface"""
|
|
||||||
|
|
||||||
import chess
|
|
||||||
import chess.engine
|
|
||||||
from python.constants import HALF_KA_V2_HM
|
|
||||||
|
|
||||||
|
|
||||||
class NNUEEvaluator:
|
|
||||||
"""Wrapper for Stockfish with NNUE evaluation"""
|
|
||||||
|
|
||||||
def __init__(self, stockfish_path: str = "/usr/bin/stockfish"):
|
|
||||||
self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
|
|
||||||
self.engine.configure({"Skill Level": 0, "UCI_LimitStrength": False})
|
|
||||||
|
|
||||||
def evaluate(self, fen: str) -> float:
|
|
||||||
"""
|
|
||||||
Get NNUE evaluation in centipawns.
|
|
||||||
Returns: positive for white advantage, negative for black
|
|
||||||
"""
|
|
||||||
board = chess.Board(fen)
|
|
||||||
result = self.engine.play(board, chess.engine.Limit(depth=1))
|
|
||||||
|
|
||||||
# Get relative centipawn score
|
|
||||||
score = result.info.score
|
|
||||||
if score.mate():
|
|
||||||
return 0 # Don't return mate scores
|
|
||||||
return float(score.relative().centipawns())
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.engine.quit()
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""Training loop for NNUE linear model"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
|
||||||
from python.model.nnue_linear import LinearEval
|
|
||||||
from python.model.feature_extractor import fen_to_features
|
|
||||||
from python.config import BATCH_SIZE, LEARNING_RATE, WEIGHT_DECAY, GRADIENT_CLIP, EPOCHS
|
|
||||||
|
|
||||||
|
|
||||||
def train(features: np.ndarray, labels: np.ndarray) -> LinearEval:
|
|
||||||
"""
|
|
||||||
Train the linear model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features: (N, 61072) numpy array
|
|
||||||
labels: (N,) numpy array
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Trained model
|
|
||||||
"""
|
|
||||||
# Convert to tensors
|
|
||||||
X = torch.from_numpy(features).float()
|
|
||||||
y = torch.from_numpy(labels).float()
|
|
||||||
|
|
||||||
# Create dataset and dataloader
|
|
||||||
dataset = TensorDataset(X, y)
|
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
|
||||||
|
|
||||||
# Initialize model
|
|
||||||
model = LinearEval()
|
|
||||||
optimizer = torch.optim.Adam(
|
|
||||||
model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
|
|
||||||
)
|
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
|
||||||
|
|
||||||
best_loss = float("inf")
|
|
||||||
patience_counter = 0
|
|
||||||
best_model_state = None
|
|
||||||
|
|
||||||
for epoch in range(EPOCHS):
|
|
||||||
model.train()
|
|
||||||
total_loss = 0.0
|
|
||||||
|
|
||||||
for batch_X, batch_y in dataloader:
|
|
||||||
optimizer.zero_grad()
|
|
||||||
preds = model(batch_X)
|
|
||||||
loss = torch.nn.functional.mse_loss(preds, batch_y)
|
|
||||||
loss.backward()
|
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
total_loss += loss.item()
|
|
||||||
|
|
||||||
avg_loss = total_loss / len(dataloader)
|
|
||||||
scheduler.step()
|
|
||||||
|
|
||||||
# Early stopping check
|
|
||||||
if avg_loss < best_loss:
|
|
||||||
best_loss = avg_loss
|
|
||||||
best_model_state = model.state_dict().copy()
|
|
||||||
patience_counter = 0
|
|
||||||
else:
|
|
||||||
patience_counter += 1
|
|
||||||
|
|
||||||
if (epoch + 1) % 10 == 0:
|
|
||||||
print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.6f}")
|
|
||||||
|
|
||||||
if patience_counter >= 50:
|
|
||||||
print("Early stopping triggered")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Load best model
|
|
||||||
if best_model_state is not None:
|
|
||||||
model.load_state_dict(best_model_state)
|
|
||||||
|
|
||||||
return model
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
"""Main entry point for training"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from python.model.nnue_linear import LinearEval
|
|
||||||
from python.data.generate_data import generate_data_from_pgn
|
|
||||||
from python.data.preprocessing import normalize_features
|
|
||||||
from python.train import train
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Training pipeline"""
|
|
||||||
# Generate data (placeholder - replace with real PGN loading)
|
|
||||||
print("Generating data...")
|
|
||||||
features, evals = generate_data_from_pgn(
|
|
||||||
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalize
|
|
||||||
print("Normalizing features...")
|
|
||||||
features = np.array(features, dtype=np.float32)
|
|
||||||
evals = np.array(evals, dtype=np.float32)
|
|
||||||
features = normalize_features(features)
|
|
||||||
|
|
||||||
# Train
|
|
||||||
print("Training...")
|
|
||||||
model = train(features, evals)
|
|
||||||
|
|
||||||
# Test
|
|
||||||
print("Testing...")
|
|
||||||
x = torch.randn(1, 61072)
|
|
||||||
with torch.no_grad():
|
|
||||||
pred = model(x)
|
|
||||||
print(f"Sample prediction: {pred.item():.4f}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import torch
|
|
||||||
|
|
||||||
main()
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
"""Tests for NNUE feature extraction"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from python.model.feature_extractor import fen_to_features
|
|
||||||
from python.constants import HALF_KA_V2_HM, TOTAL_FEATURES
|
|
||||||
|
|
||||||
|
|
||||||
class TestFeatureExtraction:
|
|
||||||
"""Tests for HalfKAv2_hm feature extraction"""
|
|
||||||
|
|
||||||
def test_feature_count(self):
|
|
||||||
"""Test total feature vector length"""
|
|
||||||
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
assert len(features) == TOTAL_FEATURES
|
|
||||||
|
|
||||||
def test_half_ka_v2_hm_features(self):
|
|
||||||
"""Test HalfKAv2_hm + FullThreats produces correct number of features"""
|
|
||||||
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
active = sum(1 for v in features if v > 0)
|
|
||||||
# HalfKAv2_hm: 24 pieces + 1 king bucket = 25 features
|
|
||||||
# FullThreats: ~79 features (piece-pair attack relationships)
|
|
||||||
# Total: ~103 features
|
|
||||||
assert 100 <= active <= 110 # Allow for slight variations
|
|
||||||
|
|
||||||
def test_feature_range(self):
|
|
||||||
"""Test all features are in valid range"""
|
|
||||||
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
assert all(0 <= f <= 1 for f in features)
|
|
||||||
|
|
||||||
def test_black_perspective(self):
|
|
||||||
"""Test feature extraction from black's perspective"""
|
|
||||||
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1"
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
active = sum(features)
|
|
||||||
assert active > 20 # Multiple pieces from black's perspective
|
|
||||||
|
|
||||||
def test_mixed_colors(self):
|
|
||||||
"""Test feature extraction with both colors on board"""
|
|
||||||
fen = "r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1" # King and queen missing
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
active = sum(1 for v in features if v > 0)
|
|
||||||
assert active < 100 # Fewer pieces than full board (~103)
|
|
||||||
|
|
||||||
def test_zero_features_empty_board(self):
|
|
||||||
"""Test empty board produces zero features"""
|
|
||||||
fen = "8/8/8/8/8/8/8/8 w KQkq - 0 1"
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
assert sum(features) == 0
|
|
||||||
|
|
||||||
def test_tensor_conversion(self):
|
|
||||||
"""Test conversion to torch tensor"""
|
|
||||||
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
tensor = torch.tensor(features, dtype=torch.float32)
|
|
||||||
assert tensor.shape == (TOTAL_FEATURES,)
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
"""Tests for NNUE implementation"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from python.model.nnue_linear import LinearEval
|
|
||||||
from python.constants import TOTAL_FEATURES
|
|
||||||
|
|
||||||
|
|
||||||
class TestLinearEval:
|
|
||||||
"""Tests for the linear NNUE model"""
|
|
||||||
|
|
||||||
def test_model_initialization(self):
|
|
||||||
"""Test model creates correct shape"""
|
|
||||||
model = LinearEval()
|
|
||||||
assert model.linear.in_features == TOTAL_FEATURES
|
|
||||||
assert model.linear.out_features == 1
|
|
||||||
|
|
||||||
def test_model_output_shape(self):
|
|
||||||
"""Test model outputs correct shape"""
|
|
||||||
model = LinearEval()
|
|
||||||
x = torch.randn(10, TOTAL_FEATURES)
|
|
||||||
y = model(x)
|
|
||||||
assert y.shape == (10, 1)
|
|
||||||
|
|
||||||
def test_model_zero_output(self):
|
|
||||||
"""Test model with zero input"""
|
|
||||||
model = LinearEval()
|
|
||||||
x = torch.zeros(1, TOTAL_FEATURES)
|
|
||||||
with torch.no_grad():
|
|
||||||
y = model(x)
|
|
||||||
assert y.item() == 0.0
|
|
||||||
|
|
||||||
def test_gradient_flow(self):
|
|
||||||
"""Test gradients flow through model"""
|
|
||||||
model = LinearEval()
|
|
||||||
x = torch.randn(10, TOTAL_FEATURES, requires_grad=True)
|
|
||||||
y = model(x)
|
|
||||||
loss = y.sum()
|
|
||||||
loss.backward()
|
|
||||||
assert x.grad is not None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
"""Verify HalfKAv2_hm features match Stockfish NNUE exactly"""
|
|
||||||
|
|
||||||
import chess
|
|
||||||
from python.model.feature_extractor import fen_to_features
|
|
||||||
from python.stockfish_wrapper import NNUEEvaluator
|
|
||||||
from python.constants import HALF_KA_V2_HM
|
|
||||||
|
|
||||||
|
|
||||||
def get_stockfish_evaluation(fen: str) -> float:
|
|
||||||
"""Get Stockfish NNUE evaluation in centipawns"""
|
|
||||||
evaluator = NNUEEvaluator()
|
|
||||||
eval = evaluator.evaluate(fen)
|
|
||||||
evaluator.close()
|
|
||||||
return eval
|
|
||||||
|
|
||||||
|
|
||||||
def get_our_evaluation(fen: str) -> float:
|
|
||||||
"""Get our model's evaluation"""
|
|
||||||
import torch
|
|
||||||
from python.model.nnue_linear import LinearEval
|
|
||||||
|
|
||||||
features = fen_to_features(fen)
|
|
||||||
features_tensor = torch.tensor([features], dtype=torch.float32)
|
|
||||||
|
|
||||||
model = LinearEval()
|
|
||||||
with torch.no_grad():
|
|
||||||
eval = model(features_tensor)[0, 0].item()
|
|
||||||
|
|
||||||
return eval
|
|
||||||
|
|
||||||
|
|
||||||
# Test positions
|
|
||||||
test_positions = [
|
|
||||||
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", # Starting
|
|
||||||
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1", # Black to move
|
|
||||||
"8/8/8/8/8/8/8/8 w KQkq - 0 1", # Empty board
|
|
||||||
]
|
|
||||||
|
|
||||||
print("Position\t\t\t\tStockfish\t\tOur Model\tDiff")
|
|
||||||
print("-" * 80)
|
|
||||||
|
|
||||||
for fen in test_positions:
|
|
||||||
try:
|
|
||||||
stockfish_eval = get_stockfish_evaluation(fen)
|
|
||||||
our_eval = get_our_evaluation(fen)
|
|
||||||
diff = abs(stockfish_eval - our_eval)
|
|
||||||
|
|
||||||
print(f"{fen[:25]:25}\t{stockfish_eval:10.2f}\t{our_eval:10.2f}\t{diff:.2f}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"{fen[:25]:25}\tERROR: {e}")
|
|
||||||
Reference in New Issue
Block a user