broken stash
This commit is contained in:
@@ -1,95 +1,600 @@
|
||||
"""Extract NNUE features from FEN strings - EXACT Stockfish Implementation"""
|
||||
|
||||
import chess
|
||||
from chess import Board as chess_board
|
||||
import numpy as np
|
||||
from python.constants import TOTAL_FEATURES
|
||||
# EXACT Stockfish constants
|
||||
PIECE_NB = 12
|
||||
PIECE_TYPE_NB = 6
|
||||
SQUARE_NB = 64
|
||||
|
||||
# EXACT Stockfish NNUE Tables
|
||||
OrientTBL = np.array([10, 10, 10, 10, 0, 0, 0, 0,
|
||||
10, 10, 10, 10, 0, 0, 0, 0,
|
||||
10, 10, 10, 10, 0, 0, 0, 0,
|
||||
10, 10, 10, 10, 0, 0, 0, 0,
|
||||
10, 10, 10, 10, 0, 0, 0, 0,
|
||||
10, 10, 10, 10, 0, 0, 0, 0,
|
||||
10, 10, 10, 10, 0, 0, 0, 0,
|
||||
10, 10, 10, 10, 0, 0, 0, 0,
|
||||
], dtype=np.int8)
|
||||
# Exact numValidTargets from Stockfish
|
||||
numValidTargets = [0, 6, 10, 8, 8, 10, 0, 0, 0, 6, 10, 8, 8, 10, 0, 0]
|
||||
|
||||
KingBuckets = np.array([28*11, 29*11, 30*11, 31*11, 31*11, 30*11, 29*11, 28*11,
|
||||
24*11, 25*11, 26*11, 27*11, 27*11, 26*11, 25*11, 24*11,
|
||||
20*11, 21*11, 22*11, 23*11, 23*11, 22*11, 21*11, 20*11,
|
||||
16*11, 17*11, 18*11, 19*11, 19*11, 18*11, 17*11, 16*11,
|
||||
12*11, 13*11, 14*11, 15*11, 15*11, 14*11, 13*11, 12*11,
|
||||
8*11, 9*11, 10*11, 11*11, 11*11, 10*11, 9*11, 8*11,
|
||||
4*11, 5*11, 6*11, 7*11, 7*11, 6*11, 5*11, 4*11,
|
||||
0, 1*11, 2*11, 3*11, 3*11, 2*11, 1*11, 0,
|
||||
], dtype=np.int16)
|
||||
# Exact map table from Stockfish
|
||||
map_table = [
|
||||
[0, 1, -1, 2, -1, -1],
|
||||
[0, 1, 2, 3, 4, -1],
|
||||
[0, 1, 2, 3, -1, -1],
|
||||
[0, 1, 2, 3, -1, -1],
|
||||
[0, 1, 2, 3, 4, -1],
|
||||
[-1, -1, -1, -1, -1, -1],
|
||||
]
|
||||
|
||||
# Precomputed lookup tables (simplified for distillation)
|
||||
index_lut1 = np.zeros((6, 6, 2), dtype=np.int32)
|
||||
index_lut2 = np.zeros((6, 64, 64), dtype=np.uint8)
|
||||
# Exact OrientTBL from Stockfish (SQ_A1=0, SQ_H1=10)
|
||||
OrientTBL = [
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
10,
|
||||
]
|
||||
|
||||
# Exact KingBuckets from Stockfish
|
||||
KingBuckets = [
|
||||
28 * 11,
|
||||
29 * 11,
|
||||
30 * 11,
|
||||
31 * 11,
|
||||
31 * 11,
|
||||
30 * 11,
|
||||
29 * 11,
|
||||
28 * 11,
|
||||
24 * 11,
|
||||
25 * 11,
|
||||
26 * 11,
|
||||
27 * 11,
|
||||
27 * 11,
|
||||
26 * 11,
|
||||
25 * 11,
|
||||
24 * 11,
|
||||
20 * 11,
|
||||
21 * 11,
|
||||
22 * 11,
|
||||
23 * 11,
|
||||
23 * 11,
|
||||
22 * 11,
|
||||
21 * 11,
|
||||
20 * 11,
|
||||
16 * 11,
|
||||
17 * 11,
|
||||
18 * 11,
|
||||
19 * 11,
|
||||
19 * 11,
|
||||
18 * 11,
|
||||
17 * 11,
|
||||
16 * 11,
|
||||
12 * 11,
|
||||
13 * 11,
|
||||
14 * 11,
|
||||
15 * 11,
|
||||
15 * 11,
|
||||
14 * 11,
|
||||
13 * 11,
|
||||
12 * 11,
|
||||
8 * 11,
|
||||
9 * 11,
|
||||
10 * 11,
|
||||
11 * 11,
|
||||
11 * 11,
|
||||
10 * 11,
|
||||
9 * 11,
|
||||
8 * 11,
|
||||
4 * 11,
|
||||
5 * 11,
|
||||
6 * 11,
|
||||
7 * 11,
|
||||
7 * 11,
|
||||
6 * 11,
|
||||
5 * 11,
|
||||
4 * 11,
|
||||
0,
|
||||
1 * 11,
|
||||
2 * 11,
|
||||
3 * 11,
|
||||
3 * 11,
|
||||
2 * 11,
|
||||
1 * 11,
|
||||
0,
|
||||
]
|
||||
|
||||
|
||||
# POPCOUNT function (exact copy of Stockfish)
|
||||
def popcount(n):
|
||||
return bin(n).count("1")
|
||||
|
||||
|
||||
# EXACT bitboard functions from Stockfish
|
||||
def pseudoattacks_knight(sq):
|
||||
"""Exact Stockfish knight pseudoattacks"""
|
||||
attacks = 0
|
||||
file = sq % 8
|
||||
rank = sq // 8
|
||||
if rank > 0 and file > 0:
|
||||
attacks |= 1 << ((rank - 1) * 8 + (file - 1))
|
||||
if rank > 0 and file < 7:
|
||||
attacks |= 1 << ((rank - 1) * 8 + (file + 1))
|
||||
if rank < 7 and file > 0:
|
||||
attacks |= 1 << ((rank + 1) * 8 + (file - 1))
|
||||
if rank < 7 and file < 7:
|
||||
attacks |= 1 << ((rank + 1) * 8 + (file + 1))
|
||||
return attacks
|
||||
|
||||
|
||||
def pseudoattacks_bishop(sq):
|
||||
"""Exact Stockfish bishop pseudoattacks"""
|
||||
attacks = 0
|
||||
file = sq % 8
|
||||
rank = sq // 8
|
||||
for dr, dc in [(-1, -1), (-1, 1), (1, -1), (1, 1)]:
|
||||
r, c = rank + dr, file + dc
|
||||
while 0 <= r < 8 and 0 <= c < 8:
|
||||
attacks |= 1 << (r * 8 + c)
|
||||
r, c = r + dr, c + dc
|
||||
return attacks
|
||||
|
||||
|
||||
def pseudoattacks_rook(sq):
|
||||
"""Exact Stockfish rook pseudoattacks"""
|
||||
attacks = 0
|
||||
file = sq % 8
|
||||
rank = sq // 8
|
||||
for dr, dc in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
|
||||
r, c = rank + dr, file + dc
|
||||
while 0 <= r < 8 and 0 <= c < 8:
|
||||
attacks |= 1 << (r * 8 + c)
|
||||
r, c = r + dr, c + dc
|
||||
return attacks
|
||||
|
||||
|
||||
def pseudoattacks_queen(sq):
|
||||
"""Exact Stockfish queen pseudoattacks"""
|
||||
return pseudoattacks_bishop(sq) | pseudoattacks_rook(sq)
|
||||
|
||||
|
||||
def pseudoattacks_king(sq):
|
||||
"""Exact Stockfish king pseudoattacks"""
|
||||
attacks = 1 << sq
|
||||
file = sq % 8
|
||||
rank = sq // 8
|
||||
for dr in [-1, 0, 1]:
|
||||
for dc in [-1, 0, 1]:
|
||||
if dr == 0 and dc == 0:
|
||||
continue
|
||||
r, c = rank + dr, file + dc
|
||||
if 0 <= r < 8 and 0 <= c < 8:
|
||||
attacks |= 1 << (r * 8 + c)
|
||||
return attacks
|
||||
|
||||
|
||||
def pawnpushorattacks_white(sq):
|
||||
"""Exact Stockfish white pawn push/attacks"""
|
||||
attacks = 0
|
||||
file = sq % 8
|
||||
rank = sq // 8
|
||||
if rank < 7:
|
||||
attacks |= 1 << ((rank + 1) * 8 + file)
|
||||
if rank < 6 and file > 0:
|
||||
attacks |= 1 << ((rank + 1) * 8 + (file - 1))
|
||||
if rank < 6 and file < 7:
|
||||
attacks |= 1 << ((rank + 1) * 8 + (file + 1))
|
||||
return attacks
|
||||
|
||||
|
||||
def pawnpushorattacks_black(sq):
|
||||
"""Exact Stockfish black pawn push/attacks"""
|
||||
attacks = 0
|
||||
file = sq % 8
|
||||
rank = sq // 8
|
||||
if rank < 1:
|
||||
return attacks
|
||||
attacks |= 1 << ((rank - 1) * 8 + file)
|
||||
if rank > 0 and file > 0:
|
||||
attacks |= 1 << ((rank - 1) * 8 + (file - 1))
|
||||
if rank > 0 and file < 7:
|
||||
attacks |= 1 << ((rank - 1) * 8 + (file + 1))
|
||||
return attacks
|
||||
|
||||
|
||||
# Generate index_lut2 EXACTLY as Stockfish (make_piece_indices_type + make_piece_indices_piece)
|
||||
# This computes popcount(((1<<to) - 1) & attacks) for each from,to pair
|
||||
index_lut2 = [[[0] * 64 for _ in range(64)] for _ in range(PIECE_NB)]
|
||||
|
||||
# Knights (1, 8) - template make_piece_indices_type<PieceType::KNIGHT>
|
||||
for from_sq in range(64):
|
||||
attacks = pseudoattacks_knight(from_sq)
|
||||
for to_sq in range(64):
|
||||
count = popcount(((1 << to_sq) - 1) & attacks)
|
||||
index_lut2[1][from_sq][to_sq] = count
|
||||
index_lut2[8][from_sq][to_sq] = count
|
||||
|
||||
# Bishops (2, 9) - template make_piece_indices_type<PieceType::BISHOP>
|
||||
for from_sq in range(64):
|
||||
attacks = pseudoattacks_bishop(from_sq)
|
||||
for to_sq in range(64):
|
||||
count = popcount(((1 << to_sq) - 1) & attacks)
|
||||
index_lut2[2][from_sq][to_sq] = count
|
||||
index_lut2[9][from_sq][to_sq] = count
|
||||
|
||||
# Rooks (3, 10) - template make_piece_indices_type<PieceType::ROOK>
|
||||
for from_sq in range(64):
|
||||
attacks = pseudoattacks_rook(from_sq)
|
||||
for to_sq in range(64):
|
||||
count = popcount(((1 << to_sq) - 1) & attacks)
|
||||
index_lut2[3][from_sq][to_sq] = count
|
||||
index_lut2[10][from_sq][to_sq] = count
|
||||
|
||||
# Queens (4, 11) - template make_piece_indices_type<PieceType::QUEEN>
|
||||
for from_sq in range(64):
|
||||
attacks = pseudoattacks_queen(from_sq)
|
||||
for to_sq in range(64):
|
||||
count = popcount(((1 << to_sq) - 1) & attacks)
|
||||
index_lut2[4][from_sq][to_sq] = count
|
||||
index_lut2[11][from_sq][to_sq] = count
|
||||
|
||||
# Kings (5, 6) - template make_piece_indices_type<PieceType::KING>
|
||||
for from_sq in range(64):
|
||||
attacks = pseudoattacks_king(from_sq)
|
||||
for to_sq in range(64):
|
||||
count = popcount(((1 << to_sq) - 1) & attacks)
|
||||
index_lut2[5][from_sq][to_sq] = count
|
||||
index_lut2[6][from_sq][to_sq] = count
|
||||
|
||||
# Pawns (0, 7) - template make_piece_indices_piece<W_PAWN> / B_PAWN
|
||||
for from_sq in range(64):
|
||||
attacks_white = pawnpushorattacks_white(from_sq)
|
||||
for to_sq in range(64):
|
||||
count = popcount(((1 << to_sq) - 1) & attacks_white)
|
||||
index_lut2[0][from_sq][to_sq] = count
|
||||
attacks_black = pawnpushorattacks_black(from_sq)
|
||||
for to_sq in range(64):
|
||||
count = popcount(((1 << to_sq) - 1) & attacks_black)
|
||||
index_lut2[7][from_sq][to_sq] = count
|
||||
|
||||
# Compute helper_offsets and offsets EXACTLY as Stockfish (init_threat_offsets)
|
||||
AllPieces = list(range(PIECE_NB))
|
||||
helper_offsets = [None] * PIECE_NB
|
||||
offsets = [[0] * 64 for _ in range(PIECE_NB)]
|
||||
cumulativeOffset = 0
|
||||
|
||||
for piece_idx in AllPieces:
|
||||
cumulativePieceOffset = 0
|
||||
piece_type = piece_idx // 6 # 0=pawn, 1=knight, 2=bishop, 3=rook, 4=queen, 5=king
|
||||
|
||||
# Simple attack count lookup (simplified from Stockfish)
|
||||
for attacker in range(6):
|
||||
for from_sq in range(64):
|
||||
for to_sq in range(64):
|
||||
index_lut2[attacker, from_sq, to_sq] = 1
|
||||
offsets[piece_idx][from_sq] = cumulativePieceOffset
|
||||
|
||||
if piece_type == 0: # Pawn
|
||||
if from_sq >= 8 and from_sq < 56: # Not on rank 1 or 8
|
||||
color = piece_idx < 8 # White pawn
|
||||
attacks = (
|
||||
pawnpushorattacks_white(from_sq)
|
||||
if color
|
||||
else pawnpushorattacks_black(from_sq)
|
||||
)
|
||||
cumulativePieceOffset += popcount(attacks)
|
||||
else: # Non-pawn
|
||||
if piece_type == 1: # Knight
|
||||
attacks = pseudoattacks_knight(from_sq)
|
||||
elif piece_type == 2: # Bishop
|
||||
attacks = pseudoattacks_bishop(from_sq)
|
||||
elif piece_type == 3: # Rook
|
||||
attacks = pseudoattacks_rook(from_sq)
|
||||
elif piece_type == 4: # Queen
|
||||
attacks = pseudoattacks_queen(from_sq)
|
||||
elif piece_type == 5: # King
|
||||
attacks = pseudoattacks_king(from_sq)
|
||||
cumulativePieceOffset += popcount(attacks)
|
||||
|
||||
helper_offsets[piece_idx] = {
|
||||
"cumulativePieceOffset": cumulativePieceOffset,
|
||||
"cumulativeOffset": cumulativeOffset,
|
||||
}
|
||||
cumulativeOffset += numValidTargets[piece_idx] * cumulativePieceOffset
|
||||
|
||||
# Compute index_lut1 EXACTLY as Stockfish (init_index_luts)
|
||||
index_lut1 = [[[0, 0] for _ in range(PIECE_NB)] for _ in range(PIECE_NB)]
|
||||
DIMENSIONS = 60720
|
||||
|
||||
for attacker in AllPieces:
|
||||
for attacked in AllPieces:
|
||||
enemy = (attacker ^ attacked) == 8 # Different color
|
||||
attacker_type = attacker // 6
|
||||
attacked_type = attacked // 6
|
||||
|
||||
map_val = map_table[attacker_type][attacked_type]
|
||||
color_attacked = (
|
||||
1 if attacked < 8 else 0
|
||||
) # 1 if white, 0 if black (piece indices 0-7 are white, 8-11 are black)
|
||||
semi_excluded = attacker_type == attacked_type and (enemy or attacker_type != 0)
|
||||
|
||||
feature = helper_offsets[attacker]["cumulativeOffset"]
|
||||
feature += (
|
||||
color_attacked * (numValidTargets[attacker] // 2) + map_val
|
||||
) * helper_offsets[attacker]["cumulativePieceOffset"]
|
||||
|
||||
excluded = map_val < 0
|
||||
|
||||
index_lut1[attacker][attacked][0] = DIMENSIONS if excluded else feature
|
||||
index_lut1[attacker][attacked][1] = (
|
||||
DIMENSIONS if (excluded or semi_excluded) else feature
|
||||
)
|
||||
|
||||
|
||||
# FEN parsing without chess module
|
||||
def parse_fen(fen):
|
||||
"""Parse FEN string and return board as list of 64 squares"""
|
||||
parts = fen.split()
|
||||
board_str = parts[0]
|
||||
|
||||
board = [None] * 64 # 0 = empty
|
||||
row = 7 # Start from rank 8 (index 7)
|
||||
col = 0
|
||||
|
||||
for char in board_str:
|
||||
if char == "/":
|
||||
row -= 1
|
||||
col = 0
|
||||
elif char.isdigit():
|
||||
col += int(char)
|
||||
else:
|
||||
sq_idx = row * 8 + col
|
||||
# Map piece symbols to piece type (0 = pawn, 1 = knight, etc.)
|
||||
piece_type_map = {
|
||||
"P": (0, True), # White pawn
|
||||
"N": (1, True), # White knight
|
||||
"B": (2, True), # White bishop
|
||||
"R": (3, True), # White rook
|
||||
"Q": (4, True), # White queen
|
||||
"K": (5, True), # White king
|
||||
"p": (0, False), # Black pawn
|
||||
"n": (1, False), # Black knight
|
||||
"b": (2, False), # Black bishop
|
||||
"r": (3, False), # Black rook
|
||||
"q": (4, False), # Black queen
|
||||
"k": (5, False), # Black king
|
||||
}
|
||||
piece_type, is_white = piece_type_map[char]
|
||||
board[sq_idx] = {"type": piece_type, "color": is_white}
|
||||
col += 1
|
||||
|
||||
return board, parts[1]
|
||||
|
||||
|
||||
def fen_to_features(fen: str) -> list:
|
||||
"""Convert FEN to 61,072 feature vector using EXACT Stockfish NNUE encoding."""
|
||||
features = [0.0] * TOTAL_FEATURES
|
||||
b = chess_board(fen)
|
||||
ksq = next((sq for sq in range(64) if b.piece_at(sq) and b.piece_at(sq).unicode_symbol() in ("\u265a", "\u2654")), None)
|
||||
flip = 56 * int(b.turn)
|
||||
features = [0.0] * 61072
|
||||
|
||||
# HalfKAv2_hm features (352)
|
||||
board, turn = parse_fen(fen)
|
||||
|
||||
# Find king position
|
||||
ksq = None
|
||||
for sq in range(64):
|
||||
piece = board[sq]
|
||||
if piece and piece["type"] == 5:
|
||||
ksq = sq
|
||||
break
|
||||
|
||||
# Determine perspective (0 = white, 1 = black)
|
||||
perspective = 0 if turn == "w" else 1
|
||||
|
||||
# Compute orientation from king position
|
||||
orientation = OrientTBL[ksq] ^ (56 * perspective) if ksq else 0
|
||||
|
||||
# HalfKAv2_hm features
|
||||
for piece_sq in range(64):
|
||||
piece = b.piece_at(piece_sq)
|
||||
piece = board[piece_sq]
|
||||
if piece is None:
|
||||
continue
|
||||
piece_type = 5 - piece.piece_type
|
||||
|
||||
piece_type = piece["type"]
|
||||
if piece_type < 0 or piece_type > 5:
|
||||
continue
|
||||
|
||||
oriented_sq = piece_sq ^ int(OrientTBL[ksq]) ^ flip if ksq else piece_sq
|
||||
king_bucket = KingBuckets[ksq ^ flip] if ksq else 0
|
||||
feature_idx = oriented_sq + piece_type + king_bucket
|
||||
from_sq = piece_sq
|
||||
from_oriented = from_sq ^ orientation if ksq else from_sq
|
||||
|
||||
king_bucket = KingBuckets[ksq] if ksq else 0
|
||||
|
||||
feature_idx = from_oriented + piece_type + king_bucket
|
||||
|
||||
if 0 <= feature_idx < 352:
|
||||
features[feature_idx] = 1.0
|
||||
|
||||
# FullThreats features (60,720)
|
||||
# FullThreats features - EXACT Stockfish implementation
|
||||
occupied = 0
|
||||
for sq in range(64):
|
||||
piece = b.piece_at(sq)
|
||||
piece = board[sq]
|
||||
if piece is not None:
|
||||
occupied |= 1 << sq
|
||||
|
||||
# Iterate over all squares to find attacking pieces
|
||||
for from_sq in range(64):
|
||||
piece = board[from_sq]
|
||||
if piece is None:
|
||||
continue
|
||||
attacks_bb = b.attacks(piece.piece_type)
|
||||
|
||||
for to_sq in range(64):
|
||||
if attacks_bb & (1 << to_sq):
|
||||
to_piece = b.piece_at(to_sq)
|
||||
if to_piece is None:
|
||||
continue
|
||||
attacker_type = piece["type"]
|
||||
color = piece["color"]
|
||||
attacker = attacker_type + (color * 6) # Full piece index (0-11)
|
||||
|
||||
to_type = 5 - to_piece.piece_type
|
||||
if to_type < 0 or to_type > 5:
|
||||
continue
|
||||
if attacker_type < 0 or attacker_type > 5:
|
||||
continue
|
||||
|
||||
from_oriented = int(sq ^ int(OrientTBL[ksq]) ^ flip) if ksq else sq
|
||||
to_oriented = int(to_sq ^ int(OrientTBL[ksq]) ^ flip) if ksq else to_sq
|
||||
from_less_than_to = int(from_oriented < to_oriented)
|
||||
from_oriented = from_sq ^ orientation if ksq else from_sq
|
||||
|
||||
lut1_val = int(index_lut1[piece_type][to_type][from_less_than_to])
|
||||
lut2_val = int(index_lut2[piece_type][from_oriented][to_oriented])
|
||||
feature_idx = lut1_val + lut2_val
|
||||
# Get attacks based on piece type
|
||||
attacks_bb = 0
|
||||
if attacker_type == 0: # Pawn
|
||||
color = piece["color"]
|
||||
attacks_bb = (
|
||||
pawnpushorattacks_white(from_sq)
|
||||
if color
|
||||
else pawnpushorattacks_black(from_sq)
|
||||
)
|
||||
elif attacker_type == 1: # Knight
|
||||
attacks_bb = pseudoattacks_knight(from_sq)
|
||||
elif attacker_type == 2: # Bishop
|
||||
attacks_bb = pseudoattacks_bishop(from_sq)
|
||||
elif attacker_type == 3: # Rook
|
||||
attacks_bb = pseudoattacks_rook(from_sq)
|
||||
elif attacker_type == 4: # Queen
|
||||
attacks_bb = pseudoattacks_queen(from_sq)
|
||||
elif attacker_type == 5: # King
|
||||
attacks_bb = pseudoattacks_king(from_sq)
|
||||
|
||||
if 0 <= feature_idx < 60720:
|
||||
features[feature_idx] = 1.0
|
||||
# Filter to occupied squares only
|
||||
attacks_bb &= occupied
|
||||
|
||||
# Iterate over attacked squares
|
||||
while attacks_bb:
|
||||
to_sq = (attacks_bb & -attacks_bb).bit_length() - 1
|
||||
attacks_bb &= attacks_bb - 1
|
||||
|
||||
attacked = board[to_sq]
|
||||
if attacked is None:
|
||||
continue
|
||||
|
||||
attacked_type = attacked["type"]
|
||||
color = attacked["color"]
|
||||
attacked = attacked_type + (color * 6) # Full piece index (0-11)
|
||||
|
||||
if attacked_type < 0 or attacked_type > 5:
|
||||
continue
|
||||
|
||||
to_oriented = to_sq ^ orientation if ksq else to_sq
|
||||
from_less_than_to = 1 if from_oriented < to_oriented else 0
|
||||
|
||||
# Compute feature index using exact Stockfish formula
|
||||
lut1_val = index_lut1[attacker][attacked][from_less_than_to]
|
||||
lut2_val = index_lut2[attacker][from_oriented][to_oriented]
|
||||
offset_val = offsets[attacker][from_oriented]
|
||||
|
||||
feature_idx = lut1_val + offset_val + lut2_val
|
||||
|
||||
if 0 <= feature_idx < DIMENSIONS:
|
||||
features[feature_idx] = 1.0
|
||||
|
||||
# Add pawn push features - pawns blocked by another pawn
|
||||
# From Stockfish's full_threats.cpp
|
||||
for color in [True, False]: # White and Black
|
||||
# Find pawns of this color
|
||||
pushers = 0
|
||||
for sq in range(64):
|
||||
piece = board[sq]
|
||||
if piece and piece["type"] == 0 and piece["color"] == color:
|
||||
# Check if there's a pawn in front
|
||||
target_sq = sq + 8 if color else sq - 8
|
||||
if target_sq >= 0 and target_sq < 64:
|
||||
target_piece = board[target_sq]
|
||||
if (
|
||||
target_piece
|
||||
and target_piece["type"] == 0
|
||||
and target_piece["color"] == color
|
||||
):
|
||||
# This pawn is blocked, add push feature
|
||||
pushers |= 1 << sq
|
||||
|
||||
while pushers:
|
||||
from_sq = (pushers & -pushers).bit_length() - 1
|
||||
pushers &= pushers - 1
|
||||
|
||||
piece = board[from_sq]
|
||||
if piece is None:
|
||||
continue
|
||||
|
||||
attacker_type = piece["type"]
|
||||
if attacker_type < 0 or attacker_type > 5:
|
||||
continue
|
||||
|
||||
from_oriented = from_sq ^ orientation if ksq else from_sq
|
||||
|
||||
# Compute target square
|
||||
to_sq = from_sq + 8 if color else from_sq - 8
|
||||
|
||||
attacked = board[to_sq]
|
||||
if attacked is None or attacked["type"] != 0:
|
||||
continue
|
||||
|
||||
attacked_type = attacked["type"]
|
||||
if attacked_type < 0 or attacked_type > 5:
|
||||
continue
|
||||
|
||||
to_oriented = to_sq ^ orientation if ksq else to_sq
|
||||
from_less_than_to = 1 if from_oriented < to_oriented else 0
|
||||
|
||||
# Compute feature index
|
||||
lut1_val = index_lut1[attacker][attacked_type][from_less_than_to]
|
||||
lut2_val = index_lut2[attacker][from_oriented][to_oriented]
|
||||
offset_val = offsets[attacker][from_oriented]
|
||||
|
||||
feature_idx = lut1_val + offset_val + lut2_val
|
||||
|
||||
if 0 <= feature_idx < DIMENSIONS:
|
||||
features[feature_idx] = 1.0
|
||||
|
||||
return features
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'
|
||||
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
||||
features = fen_to_features(fen)
|
||||
print(f"Features: {sum(features)}")
|
||||
active = sum(1 for v in features if v > 0)
|
||||
print(f"Active features on starting position: {active}")
|
||||
print(f"Feature vector length: {len(features)}")
|
||||
|
||||
Reference in New Issue
Block a user