"""Tests for NNUE implementation"""

import pytest
import torch
import numpy as np

from python.model.nnue_linear import LinearEval
from python.constants import TOTAL_FEATURES


class TestLinearEval:
    """Tests for the linear NNUE model (``LinearEval``).

    Each test builds a fresh model, so tests do not share parameter state.
    """

    def test_model_initialization(self):
        """Test model creates a single-output linear layer over all features."""
        model = LinearEval()
        assert model.linear.in_features == TOTAL_FEATURES
        assert model.linear.out_features == 1
        # The weight tensor must agree with the declared feature sizes.
        assert tuple(model.linear.weight.shape) == (1, TOTAL_FEATURES)

    def test_model_output_shape(self):
        """Test model maps a (batch, TOTAL_FEATURES) input to (batch, 1)."""
        model = LinearEval()
        x = torch.randn(10, TOTAL_FEATURES)
        y = model(x)
        assert y.shape == (10, 1)

    def test_model_zero_output(self):
        """Test model with zero input produces exactly zero output."""
        model = LinearEval()
        x = torch.zeros(1, TOTAL_FEATURES)
        with torch.no_grad():
            y = model(x)
        # NOTE(review): exact 0.0 only holds if ``model.linear`` has no bias
        # (or a zero-initialised one) -- confirm against LinearEval.
        assert y.item() == 0.0

    def test_gradient_flow(self):
        """Test gradients reach both the input and the trainable weights."""
        model = LinearEval()
        x = torch.randn(10, TOTAL_FEATURES, requires_grad=True)
        y = model(x)
        loss = y.sum()
        loss.backward()

        # Gradient w.r.t. the input.
        assert x.grad is not None
        assert x.grad.shape == x.shape

        # A gradient on the input alone does not prove the model can be
        # trained; the layer's own weights must receive a usable gradient.
        weight_grad = model.linear.weight.grad
        assert weight_grad is not None
        assert weight_grad.shape == model.linear.weight.shape
        assert torch.isfinite(weight_grad).all()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])