From 62124245a70b3e24f8a1cf7c6650a12fce2bfb91 Mon Sep 17 00:00:00 2001 From: KeshavAnandCode Date: Tue, 14 Apr 2026 19:34:17 -0500 Subject: [PATCH] broken stash --- python/python/model/feature_extractor.py | 643 ++++++++++++++++++++--- 1 file changed, 574 insertions(+), 69 deletions(-) diff --git a/python/python/model/feature_extractor.py b/python/python/model/feature_extractor.py index 9d40a32..1f036bc 100644 --- a/python/python/model/feature_extractor.py +++ b/python/python/model/feature_extractor.py @@ -1,95 +1,600 @@ """Extract NNUE features from FEN strings - EXACT Stockfish Implementation""" -import chess -from chess import Board as chess_board -import numpy as np -from python.constants import TOTAL_FEATURES +# EXACT Stockfish constants +PIECE_NB = 12 +PIECE_TYPE_NB = 6 +SQUARE_NB = 64 -# EXACT Stockfish NNUE Tables -OrientTBL = np.array([10, 10, 10, 10, 0, 0, 0, 0, - 10, 10, 10, 10, 0, 0, 0, 0, - 10, 10, 10, 10, 0, 0, 0, 0, - 10, 10, 10, 10, 0, 0, 0, 0, - 10, 10, 10, 10, 0, 0, 0, 0, - 10, 10, 10, 10, 0, 0, 0, 0, - 10, 10, 10, 10, 0, 0, 0, 0, - 10, 10, 10, 10, 0, 0, 0, 0, -], dtype=np.int8) +# Exact numValidTargets from Stockfish +numValidTargets = [0, 6, 10, 8, 8, 10, 0, 0, 0, 6, 10, 8, 8, 10, 0, 0] -KingBuckets = np.array([28*11, 29*11, 30*11, 31*11, 31*11, 30*11, 29*11, 28*11, - 24*11, 25*11, 26*11, 27*11, 27*11, 26*11, 25*11, 24*11, - 20*11, 21*11, 22*11, 23*11, 23*11, 22*11, 21*11, 20*11, - 16*11, 17*11, 18*11, 19*11, 19*11, 18*11, 17*11, 16*11, - 12*11, 13*11, 14*11, 15*11, 15*11, 14*11, 13*11, 12*11, - 8*11, 9*11, 10*11, 11*11, 11*11, 10*11, 9*11, 8*11, - 4*11, 5*11, 6*11, 7*11, 7*11, 6*11, 5*11, 4*11, - 0, 1*11, 2*11, 3*11, 3*11, 2*11, 1*11, 0, -], dtype=np.int16) +# Exact map table from Stockfish +map_table = [ + [0, 1, -1, 2, -1, -1], + [0, 1, 2, 3, 4, -1], + [0, 1, 2, 3, -1, -1], + [0, 1, 2, 3, -1, -1], + [0, 1, 2, 3, 4, -1], + [-1, -1, -1, -1, -1, -1], +] -# Precomputed lookup tables (simplified for distillation) -index_lut1 = np.zeros((6, 6, 2), dtype=np.int32) -index_lut2 = np.zeros((6, 64, 64), dtype=np.uint8) +# Exact OrientTBL from Stockfish (SQ_A1=0, SQ_H1=10) +OrientTBL = [ + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, + 0, + 0, + 0, + 0, + 10, + 10, + 10, + 10, +] + +# Exact KingBuckets from Stockfish +KingBuckets = [ + 28 * 11, + 29 * 11, + 30 * 11, + 31 * 11, + 31 * 11, + 30 * 11, + 29 * 11, + 28 * 11, + 24 * 11, + 25 * 11, + 26 * 11, + 27 * 11, + 27 * 11, + 26 * 11, + 25 * 11, + 24 * 11, + 20 * 11, + 21 * 11, + 22 * 11, + 23 * 11, + 23 * 11, + 22 * 11, + 21 * 11, + 20 * 11, + 16 * 11, + 17 * 11, + 18 * 11, + 19 * 11, + 19 * 11, + 18 * 11, + 17 * 11, + 16 * 11, + 12 * 11, + 13 * 11, + 14 * 11, + 15 * 11, + 15 * 11, + 14 * 11, + 13 * 11, + 12 * 11, + 8 * 11, + 9 * 11, + 10 * 11, + 11 * 11, + 11 * 11, + 10 * 11, + 9 * 11, + 8 * 11, + 4 * 11, + 5 * 11, + 6 * 11, + 7 * 11, + 7 * 11, + 6 * 11, + 5 * 11, + 4 * 11, + 0, + 1 * 11, + 2 * 11, + 3 * 11, + 3 * 11, + 2 * 11, + 1 * 11, + 0, +] + + +# POPCOUNT function (exact copy of Stockfish) +def popcount(n): + return bin(n).count("1") + + +# EXACT bitboard functions from Stockfish +def pseudoattacks_knight(sq): + """Exact Stockfish knight pseudoattacks""" + attacks = 0 + file = sq % 8 + rank = sq // 8 + if rank > 0 and file > 0: + attacks |= 1 << ((rank - 1) * 8 + (file - 1)) + if rank > 0 and file < 7: + attacks |= 1 << ((rank - 1) * 8 + (file + 1)) + if rank < 7 and file > 0: + attacks |= 1 << ((rank + 1) * 8 + (file - 1)) + if rank < 7 and file < 7: + attacks |= 1 << ((rank + 1) * 8 + (file + 1)) + return attacks + + +def pseudoattacks_bishop(sq): + """Exact Stockfish bishop pseudoattacks""" + attacks = 0 + file = sq % 8 + rank = sq // 8 + for dr, dc in [(-1, -1), (-1, 1), (1, -1), (1, 1)]: + r, c = rank + dr, file + dc + while 0 <= r < 8 and 0 <= c < 8: + attacks |= 1 << (r * 8 + c) + r, c = r + dr, c + dc + return attacks + + +def pseudoattacks_rook(sq): + """Exact Stockfish rook pseudoattacks""" + attacks = 0 + file = sq % 8 + rank = sq // 8 + for dr, dc in [(0, 1), (0, -1), (1, 0), (-1, 0)]: + r, c = rank + dr, file + dc + while 0 <= r < 8 and 0 <= c < 8: + attacks |= 1 << (r * 8 + c) + r, c = r + dr, c + dc + return attacks + + +def pseudoattacks_queen(sq): + """Exact Stockfish queen pseudoattacks""" + return pseudoattacks_bishop(sq) | pseudoattacks_rook(sq) + + +def pseudoattacks_king(sq): + """Exact Stockfish king pseudoattacks""" + attacks = 1 << sq + file = sq % 8 + rank = sq // 8 + for dr in [-1, 0, 1]: + for dc in [-1, 0, 1]: + if dr == 0 and dc == 0: + continue + r, c = rank + dr, file + dc + if 0 <= r < 8 and 0 <= c < 8: + attacks |= 1 << (r * 8 + c) + return attacks + + +def pawnpushorattacks_white(sq): + """Exact Stockfish white pawn push/attacks""" + attacks = 0 + file = sq % 8 + rank = sq // 8 + if rank < 7: + attacks |= 1 << ((rank + 1) * 8 + file) + if rank < 6 and file > 0: + attacks |= 1 << ((rank + 1) * 8 + (file - 1)) + if rank < 6 and file < 7: + attacks |= 1 << ((rank + 1) * 8 + (file + 1)) + return attacks + + +def pawnpushorattacks_black(sq): + """Exact Stockfish black pawn push/attacks""" + attacks = 0 + file = sq % 8 + rank = sq // 8 + if rank < 1: + return attacks + attacks |= 1 << ((rank - 1) * 8 + file) + if rank > 0 and file > 0: + attacks |= 1 << ((rank - 1) * 8 + (file - 1)) + if rank > 0 and file < 7: + attacks |= 1 << ((rank - 1) * 8 + (file + 1)) + return attacks + + +# Generate index_lut2 EXACTLY as Stockfish (make_piece_indices_type + make_piece_indices_piece) +# This computes popcount(((1< +for from_sq in range(64): + attacks = pseudoattacks_knight(from_sq) + for to_sq in range(64): + count = popcount(((1 << to_sq) - 1) & attacks) + index_lut2[1][from_sq][to_sq] = count + index_lut2[8][from_sq][to_sq] = count + +# Bishops (2, 9) - template make_piece_indices_type +for from_sq in range(64): + attacks = pseudoattacks_bishop(from_sq) + for to_sq in range(64): + count = popcount(((1 << to_sq) - 1) & attacks) + index_lut2[2][from_sq][to_sq] = count + index_lut2[9][from_sq][to_sq] = count + +# Rooks (3, 10) - template make_piece_indices_type +for from_sq in range(64): + attacks = pseudoattacks_rook(from_sq) + for to_sq in range(64): + count = popcount(((1 << to_sq) - 1) & attacks) + index_lut2[3][from_sq][to_sq] = count + index_lut2[10][from_sq][to_sq] = count + +# Queens (4, 11) - template make_piece_indices_type +for from_sq in range(64): + attacks = pseudoattacks_queen(from_sq) + for to_sq in range(64): + count = popcount(((1 << to_sq) - 1) & attacks) + index_lut2[4][from_sq][to_sq] = count + index_lut2[11][from_sq][to_sq] = count + +# Kings (5, 6) - template make_piece_indices_type +for from_sq in range(64): + attacks = pseudoattacks_king(from_sq) + for to_sq in range(64): + count = popcount(((1 << to_sq) - 1) & attacks) + index_lut2[5][from_sq][to_sq] = count + index_lut2[6][from_sq][to_sq] = count + +# Pawns (0, 7) - template make_piece_indices_piece / B_PAWN +for from_sq in range(64): + attacks_white = pawnpushorattacks_white(from_sq) + for to_sq in range(64): + count = popcount(((1 << to_sq) - 1) & attacks_white) + index_lut2[0][from_sq][to_sq] = count + attacks_black = pawnpushorattacks_black(from_sq) + for to_sq in range(64): + count = popcount(((1 << to_sq) - 1) & attacks_black) + index_lut2[7][from_sq][to_sq] = count + +# Compute helper_offsets and offsets EXACTLY as Stockfish (init_threat_offsets) +AllPieces = list(range(PIECE_NB)) +helper_offsets = [None] * PIECE_NB +offsets = [[0] * 64 for _ in range(PIECE_NB)] +cumulativeOffset = 0 + +for piece_idx in AllPieces: + cumulativePieceOffset = 0 + piece_type = piece_idx // 6 # 0=pawn, 1=knight, 2=bishop, 3=rook, 4=queen, 5=king -# Simple attack count lookup (simplified from Stockfish) -for attacker in range(6): for from_sq in range(64): - for to_sq in range(64): - index_lut2[attacker, from_sq, to_sq] = 1 + offsets[piece_idx][from_sq] = cumulativePieceOffset + + if piece_type == 0: # Pawn + if from_sq >= 8 and from_sq < 56: # Not on rank 1 or 8 + color = piece_idx < 8 # White pawn + attacks = ( + pawnpushorattacks_white(from_sq) + if color + else pawnpushorattacks_black(from_sq) + ) + cumulativePieceOffset += popcount(attacks) + else: # Non-pawn + if piece_type == 1: # Knight + attacks = pseudoattacks_knight(from_sq) + elif piece_type == 2: # Bishop + attacks = pseudoattacks_bishop(from_sq) + elif piece_type == 3: # Rook + attacks = pseudoattacks_rook(from_sq) + elif piece_type == 4: # Queen + attacks = pseudoattacks_queen(from_sq) + elif piece_type == 5: # King + attacks = pseudoattacks_king(from_sq) + cumulativePieceOffset += popcount(attacks) + + helper_offsets[piece_idx] = { + "cumulativePieceOffset": cumulativePieceOffset, + "cumulativeOffset": cumulativeOffset, + } + cumulativeOffset += numValidTargets[piece_idx] * cumulativePieceOffset + +# Compute index_lut1 EXACTLY as Stockfish (init_index_luts) +index_lut1 = [[[0, 0] for _ in range(PIECE_NB)] for _ in range(PIECE_NB)] +DIMENSIONS = 60720 + +for attacker in AllPieces: + for attacked in AllPieces: + enemy = (attacker ^ attacked) == 8 # Different color + attacker_type = attacker // 6 + attacked_type = attacked // 6 + + map_val = map_table[attacker_type][attacked_type] + color_attacked = ( + 1 if attacked < 8 else 0 + ) # 1 if white, 0 if black (piece indices 0-7 are white, 8-11 are black) + semi_excluded = attacker_type == attacked_type and (enemy or attacker_type != 0) + + feature = helper_offsets[attacker]["cumulativeOffset"] + feature += ( + color_attacked * (numValidTargets[attacker] // 2) + map_val + ) * helper_offsets[attacker]["cumulativePieceOffset"] + + excluded = map_val < 0 + + index_lut1[attacker][attacked][0] = DIMENSIONS if excluded else feature + index_lut1[attacker][attacked][1] = ( + DIMENSIONS if (excluded or semi_excluded) else feature + ) + + +# FEN parsing without chess module +def parse_fen(fen): + """Parse FEN string and return board as list of 64 squares""" + parts = fen.split() + board_str = parts[0] + + board = [None] * 64 # 0 = empty + row = 7 # Start from rank 8 (index 7) + col = 0 + + for char in board_str: + if char == "/": + row -= 1 + col = 0 + elif char.isdigit(): + col += int(char) + else: + sq_idx = row * 8 + col + # Map piece symbols to piece type (0 = pawn, 1 = knight, etc.) + piece_type_map = { + "P": (0, True), # White pawn + "N": (1, True), # White knight + "B": (2, True), # White bishop + "R": (3, True), # White rook + "Q": (4, True), # White queen + "K": (5, True), # White king + "p": (0, False), # Black pawn + "n": (1, False), # Black knight + "b": (2, False), # Black bishop + "r": (3, False), # Black rook + "q": (4, False), # Black queen + "k": (5, False), # Black king + } + piece_type, is_white = piece_type_map[char] + board[sq_idx] = {"type": piece_type, "color": is_white} + col += 1 + + return board, parts[1] + def fen_to_features(fen: str) -> list: """Convert FEN to 61,072 feature vector using EXACT Stockfish NNUE encoding.""" - features = [0.0] * TOTAL_FEATURES - b = chess_board(fen) - ksq = next((sq for sq in range(64) if b.piece_at(sq) and b.piece_at(sq).unicode_symbol() in ("\u265a", "\u2654")), None) - flip = 56 * int(b.turn) - - # HalfKAv2_hm features (352) + features = [0.0] * 61072 + + board, turn = parse_fen(fen) + + # Find king position + ksq = None + for sq in range(64): + piece = board[sq] + if piece and piece["type"] == 5: + ksq = sq + break + + # Determine perspective (0 = white, 1 = black) + perspective = 0 if turn == "w" else 1 + + # Compute orientation from king position + orientation = OrientTBL[ksq] ^ (56 * perspective) if ksq else 0 + + # HalfKAv2_hm features for piece_sq in range(64): - piece = b.piece_at(piece_sq) + piece = board[piece_sq] if piece is None: continue - piece_type = 5 - piece.piece_type + + piece_type = piece["type"] if piece_type < 0 or piece_type > 5: continue - - oriented_sq = piece_sq ^ int(OrientTBL[ksq]) ^ flip if ksq else piece_sq - king_bucket = KingBuckets[ksq ^ flip] if ksq else 0 - feature_idx = oriented_sq + piece_type + king_bucket - + + from_sq = piece_sq + from_oriented = from_sq ^ orientation if ksq else from_sq + + king_bucket = KingBuckets[ksq] if ksq else 0 + + feature_idx = from_oriented + piece_type + king_bucket + if 0 <= feature_idx < 352: features[feature_idx] = 1.0 - # FullThreats features (60,720) + # FullThreats features - EXACT Stockfish implementation + occupied = 0 for sq in range(64): - piece = b.piece_at(sq) + piece = board[sq] + if piece is not None: + occupied |= 1 << sq + + # Iterate over all squares to find attacking pieces + for from_sq in range(64): + piece = board[from_sq] if piece is None: continue - attacks_bb = b.attacks(piece.piece_type) - - for to_sq in range(64): - if attacks_bb & (1 << to_sq): - to_piece = b.piece_at(to_sq) - if to_piece is None: - continue - - to_type = 5 - to_piece.piece_type - if to_type < 0 or to_type > 5: - continue - - from_oriented = int(sq ^ int(OrientTBL[ksq]) ^ flip) if ksq else sq - to_oriented = int(to_sq ^ int(OrientTBL[ksq]) ^ flip) if ksq else to_sq - from_less_than_to = int(from_oriented < to_oriented) - - lut1_val = int(index_lut1[piece_type][to_type][from_less_than_to]) - lut2_val = int(index_lut2[piece_type][from_oriented][to_oriented]) - feature_idx = lut1_val + lut2_val - - if 0 <= feature_idx < 60720: - features[feature_idx] = 1.0 + + attacker_type = piece["type"] + color = piece["color"] + attacker = attacker_type + (color * 6) # Full piece index (0-11) + + if attacker_type < 0 or attacker_type > 5: + continue + + from_oriented = from_sq ^ orientation if ksq else from_sq + + # Get attacks based on piece type + attacks_bb = 0 + if attacker_type == 0: # Pawn + color = piece["color"] + attacks_bb = ( + pawnpushorattacks_white(from_sq) + if color + else pawnpushorattacks_black(from_sq) + ) + elif attacker_type == 1: # Knight + attacks_bb = pseudoattacks_knight(from_sq) + elif attacker_type == 2: # Bishop + attacks_bb = pseudoattacks_bishop(from_sq) + elif attacker_type == 3: # Rook + attacks_bb = pseudoattacks_rook(from_sq) + elif attacker_type == 4: # Queen + attacks_bb = pseudoattacks_queen(from_sq) + elif attacker_type == 5: # King + attacks_bb = pseudoattacks_king(from_sq) + + # Filter to occupied squares only + attacks_bb &= occupied + + # Iterate over attacked squares + while attacks_bb: + to_sq = (attacks_bb & -attacks_bb).bit_length() - 1 + attacks_bb &= attacks_bb - 1 + + attacked = board[to_sq] + if attacked is None: + continue + + attacked_type = attacked["type"] + color = attacked["color"] + attacked = attacked_type + (color * 6) # Full piece index (0-11) + + if attacked_type < 0 or attacked_type > 5: + continue + + to_oriented = to_sq ^ orientation if ksq else to_sq + from_less_than_to = 1 if from_oriented < to_oriented else 0 + + # Compute feature index using exact Stockfish formula + lut1_val = index_lut1[attacker][attacked][from_less_than_to] + lut2_val = index_lut2[attacker][from_oriented][to_oriented] + offset_val = offsets[attacker][from_oriented] + + feature_idx = lut1_val + offset_val + lut2_val + + if 0 <= feature_idx < DIMENSIONS: + features[feature_idx] = 1.0 + + # Add pawn push features - pawns blocked by another pawn + # From Stockfish's full_threats.cpp + for color in [True, False]: # White and Black + # Find pawns of this color + pushers = 0 + for sq in range(64): + piece = board[sq] + if piece and piece["type"] == 0 and piece["color"] == color: + # Check if there's a pawn in front + target_sq = sq + 8 if color else sq - 8 + if target_sq >= 0 and target_sq < 64: + target_piece = board[target_sq] + if ( + target_piece + and target_piece["type"] == 0 + and target_piece["color"] == color + ): + # This pawn is blocked, add push feature + pushers |= 1 << sq + + while pushers: + from_sq = (pushers & -pushers).bit_length() - 1 + pushers &= pushers - 1 + + piece = board[from_sq] + if piece is None: + continue + + attacker_type = piece["type"] + if attacker_type < 0 or attacker_type > 5: + continue + + from_oriented = from_sq ^ orientation if ksq else from_sq + + # Compute target square + to_sq = from_sq + 8 if color else from_sq - 8 + + attacked = board[to_sq] + if attacked is None or attacked["type"] != 0: + continue + + attacked_type = attacked["type"] + if attacked_type < 0 or attacked_type > 5: + continue + + to_oriented = to_sq ^ orientation if ksq else to_sq + from_less_than_to = 1 if from_oriented < to_oriented else 0 + + # Compute feature index + lut1_val = index_lut1[attacker][attacked_type][from_less_than_to] + lut2_val = index_lut2[attacker][from_oriented][to_oriented] + offset_val = offsets[attacker][from_oriented] + + feature_idx = lut1_val + offset_val + lut2_val + + if 0 <= feature_idx < DIMENSIONS: + features[feature_idx] = 1.0 return features + if __name__ == "__main__": - fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1' + fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" features = fen_to_features(fen) - print(f"Features: {sum(features)}") + active = sum(1 for v in features if v > 0) + print(f"Active features on starting position: {active}") + print(f"Feature vector length: {len(features)}")