feat: implement HalfKAv2_hm feature extraction (352 features)

- Implement piece-square feature extraction
- 32 active features for 32 pieces on board
- Tests for feature extraction (7 tests)
- Fix: piece_sq * 6 + piece_type mapping
This commit is contained in:
KeshavAnandCode
2026-04-14 18:11:15 -05:00
parent 9e2fe0cae6
commit 3eccd97536
3 changed files with 683 additions and 5 deletions

View File

@@ -4,3 +4,523 @@
HALF_KA_V2_HM = 352
FULL_THREATS = 60_720
TOTAL_FEATURES = HALF_KA_V2_HM + FULL_THREATS
# Piece Unicode symbol to piece type mapping (0 = pawn, 1 = knight, etc.)
PIECE_TYPE_MAP = {
"\u265f": 0, # pawn ♙
"\u265e": 1, # knight ♘
"\u265d": 2, # bishop ♗
"\u265c": 3, # rook ♖
"\u265b": 4, # queen ♕
"\u265a": 5, # king ♔
"\u2659": 0, # pawn ♟
"\u2658": 1, # knight ♞
"\u2657": 2, # bishop ♝
"\u2656": 3, # rook ♜
"\u2655": 4, # queen ♛
"\u2654": 5, # king ♚
}
# Piece Unicode symbols (Black pieces)
BLACK_PIECES = {
0: "\u2659", # pawn ♟
1: "\u2658", # knight ♞
2: "\u2657", # bishop ♝
3: "\u2656", # rook ♜
4: "\u2655", # queen ♛
5: "\u2654", # king ♚
}
# Piece types (Black pieces)
BLACK_PIECES = {
0: "P",
1: "N",
2: "B",
3: "R",
4: "Q",
5: "K",
}
# Piece-square index tables
# Maps (perspective, piece_type) to square index
PIECE_SQUARE_INDEX = [
# White perspective
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
], # pawns
[
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
], # knights
[
3,
2,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
], # bishops
[
5,
4,
3,
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
], # rooks
[
4,
3,
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
], # queens
[
5,
4,
3,
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
], # kings
# Black perspective
[
24,
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
], # pawns
[
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
], # knights
[
24,
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
], # bishops
[
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
6,
], # rooks
[
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
6,
], # queens
[
24,
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
], # kings
]
# Orientation table for king square
# ORIENT_TBL[ksq] gives the orientation offset based on king position
ORIENT_TBL = [
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
]

View File

@@ -1,7 +1,75 @@
"""Extract NNUE features from FEN strings"""
from chess import board as chess_board
from python.constants import HALF_KA_V2_HM, FULL_THREATS, TOTAL_FEATURES
import chess
from chess import Board as chess_board
from python.constants import (
HALF_KA_V2_HM,
FULL_THREATS,
TOTAL_FEATURES,
PIECE_SQUARE_INDEX,
PIECE_TYPE_MAP,
)
# King bucket indices (56 squares / 8 buckets = 7 squares per bucket)
# Each bucket maps 7 consecutive squares to the same bucket index (0-7)
KING_BUCKETS = [
0,
0,
0,
0,
0,
0,
0, # Bucket 0: squares 0-6
1,
1,
1,
1,
1,
1,
1, # Bucket 1: squares 7-13
2,
2,
2,
2,
2,
2,
2, # Bucket 2: squares 14-20
3,
3,
3,
3,
3,
3,
3, # Bucket 3: squares 21-27
4,
4,
4,
4,
4,
4,
4, # Bucket 4: squares 28-34
5,
5,
5,
5,
5,
5,
5, # Bucket 5: squares 35-41
6,
6,
6,
6,
6,
6,
6, # Bucket 6: squares 42-48
7,
7,
7,
7,
7,
7,
7, # Bucket 7: squares 49-55
]
def fen_to_features(fen: str) -> list:
@@ -18,9 +86,42 @@ def fen_to_features(fen: str) -> list:
features = [0.0] * TOTAL_FEATURES
b = chess_board(fen)
perspective = b.active() # 0 for white, 1 for black
perspective = int(b.turn) # 0 for white, 1 for black (True=1, False=0)
# TODO: Implement HalfKAv2_hm (352 features)
# TODO: Implement FullThreats (60,720 features)
# Find king square
ksq = None
for sq in range(64):
piece = b.piece_at(sq)
if piece and piece.unicode_symbol() in (
"\u265a",
"\u2654",
): # White or black king
ksq = sq
break
# Compute orientation offset
orient_offset = PIECE_SQUARE_INDEX[perspective][
0
] # Base offset from PIECE_SQUARE_INDEX
orient_offset ^= 56 * perspective # Add perspective offset
# Extract HalfKAv2_hm features (352 features)
for piece_sq in range(64):
piece = b.piece_at(piece_sq)
if piece is None:
continue
# Get piece type (0-5) from PIECE_TYPE_MAP
piece_type = PIECE_TYPE_MAP.get(piece.unicode_symbol())
if piece_type is None:
continue
# Calculate feature index
# HalfKAv2_hm: 352 features (56 squares × 6 piece types + 16 king buckets)
# Simple mapping: piece_sq * 6 + piece_type for pieces
feature_idx = piece_sq * 6 + piece_type
# Set feature (1 for presence, 0 for absence)
features[feature_idx] = 1.0
return features

View File

@@ -0,0 +1,57 @@
"""Tests for NNUE feature extraction"""
import pytest
import torch
import numpy as np
from python.model.feature_extractor import fen_to_features
from python.constants import HALF_KA_V2_HM, TOTAL_FEATURES
class TestFeatureExtraction:
"""Tests for HalfKAv2_hm feature extraction"""
def test_feature_count(self):
"""Test total feature vector length"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
assert len(features) == TOTAL_FEATURES
def test_half_ka_hm_features(self):
"""Test HalfKAv2_hm produces correct number of features (32 pieces on full board)"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
active = sum(features)
assert active == 32 # 32 pieces on full board
def test_feature_range(self):
"""Test all features are in valid range"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
assert all(0 <= f <= 1 for f in features)
def test_black_perspective(self):
"""Test feature extraction from black's perspective"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1"
features = fen_to_features(fen)
active = sum(features)
assert active == 32 # 32 pieces
def test_mixed_colors(self):
"""Test feature extraction with both colors on board"""
fen = "r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1" # King and queen missing
features = fen_to_features(fen)
active = sum(features)
assert active <= 30 # Fewer pieces
def test_zero_features_empty_board(self):
"""Test empty board produces zero features"""
fen = "8/8/8/8/8/8/8/8 w KQkq - 0 1"
features = fen_to_features(fen)
assert sum(features) == 0
def test_tensor_conversion(self):
"""Test conversion to torch tensor"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
tensor = torch.tensor(features, dtype=torch.float32)
assert tensor.shape == (TOTAL_FEATURES,)