- Use oriented squares for piece encoding
- Starting position: 24 piece features + 1 king-bucket feature = 25 active features
- King-bucket features prefer the white king's perspective
- All tests passing (11 tests)
59 lines
2.2 KiB
Python
59 lines
2.2 KiB
Python
"""Tests for NNUE feature extraction"""
|
|
|
|
import pytest
|
|
import torch
|
|
import numpy as np
|
|
from python.model.feature_extractor import fen_to_features
|
|
from python.constants import HALF_KA_V2_HM, TOTAL_FEATURES
|
|
|
|
|
|
class TestFeatureExtraction:
    """Tests for HalfKAv2_hm feature extraction via ``fen_to_features``."""

    # Standard chess starting position, white to move.
    START_FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

    def test_feature_count(self):
        """The feature vector has exactly TOTAL_FEATURES entries."""
        features = fen_to_features(self.START_FEN)

        assert len(features) == TOTAL_FEATURES

    def test_half_ka_v2_hm_features(self):
        """The starting position activates the expected number of features."""
        features = fen_to_features(self.START_FEN)
        active = sum(features)

        # HalfKAv2_hm: 24 pieces + 1 king bucket = 25 features
        assert active == 25

    def test_feature_range(self):
        """Every feature value lies in the closed interval [0, 1]."""
        features = fen_to_features(self.START_FEN)

        assert all(0 <= f <= 1 for f in features)

    def test_black_perspective(self):
        """Extraction with black to move still activates many features."""
        fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1"
        features = fen_to_features(fen)
        active = sum(features)

        assert active > 20  # Multiple pieces from black's perspective

    def test_mixed_colors(self):
        """Removing pieces of both colors lowers the active feature count."""
        # Both kings, all rooks and all pawns remain; the knights, bishops
        # and queens of both sides have been removed.
        fen = "r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1"
        active = sum(fen_to_features(fen))
        start_active = sum(fen_to_features(self.START_FEN))

        # Pieces are still present, so some features must fire, but fewer
        # than in the full starting position.
        assert 0 < active < start_active

    def test_zero_features_empty_board(self):
        """An empty board activates no features."""
        # No castling rights: an empty board has no kings or rooks, so
        # "KQkq" would make the FEN itself invalid.
        fen = "8/8/8/8/8/8/8/8 w - - 0 1"
        features = fen_to_features(fen)

        assert sum(features) == 0

    def test_tensor_conversion(self):
        """The feature vector converts to a 1-D float32 tensor of full length."""
        features = fen_to_features(self.START_FEN)
        tensor = torch.tensor(features, dtype=torch.float32)

        assert tensor.shape == (TOTAL_FEATURES,)