"""Tests for NNUE feature extraction""" import pytest import torch import numpy as np from python.model.feature_extractor import fen_to_features from python.constants import HALF_KA_V2_HM, TOTAL_FEATURES class TestFeatureExtraction: """Tests for HalfKAv2_hm feature extraction""" def test_feature_count(self): """Test total feature vector length""" fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" features = fen_to_features(fen) assert len(features) == TOTAL_FEATURES def test_half_ka_v2_hm_features(self): """Test HalfKAv2_hm + FullThreats produces correct number of features""" fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" features = fen_to_features(fen) active = sum(1 for v in features if v > 0) # HalfKAv2_hm: 24 pieces + 1 king bucket = 25 features # FullThreats: ~79 features (piece-pair attack relationships) # Total: ~103 features assert 100 <= active <= 110 # Allow for slight variations def test_feature_range(self): """Test all features are in valid range""" fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" features = fen_to_features(fen) assert all(0 <= f <= 1 for f in features) def test_black_perspective(self): """Test feature extraction from black's perspective""" fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1" features = fen_to_features(fen) active = sum(features) assert active > 20 # Multiple pieces from black's perspective def test_mixed_colors(self): """Test feature extraction with both colors on board""" fen = "r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1" # King and queen missing features = fen_to_features(fen) active = sum(1 for v in features if v > 0) assert active < 100 # Fewer pieces than full board (~103) def test_zero_features_empty_board(self): """Test empty board produces zero features""" fen = "8/8/8/8/8/8/8/8 w KQkq - 0 1" features = fen_to_features(fen) assert sum(features) == 0 def test_tensor_conversion(self): """Test conversion to torch tensor""" fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" features = fen_to_features(fen) tensor = torch.tensor(features, dtype=torch.float32) assert tensor.shape == (TOTAL_FEATURES,)