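# Backwards compatibility test for plan serialization. Overall flow (as implemented below):
#  1. parse a sqllogictest file and split it into database load statements and SELECT queries
#  2. in the old source tree, run the 'Generate serialized plans file' unittest with GEN_PLAN_STORAGE=1
#     to serialize the query plans into test/api/serialized_plans/serialized_plans.binary
#  3. copy db_load.sql, queries.sql and serialized_plans.binary over to the new source tree
#  4. in the new source tree, run the 'Test deserialized plans from file' unittest to check that the
#     plans generated by the old version can still be deserialized
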
import sqllogictest
from sqllogictest import SQLParserException, SQLLogicParser, SQLLogicTest
import duckdb
from typing import Optional
import argparse
import shutil
import os
import subprocess

# example usage: python3 scripts/test_serialization_bwc.py --old-source ../duckdb-bugfix --test-file test/sql/aggregate/aggregates/test_median.test
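# other modes (see the argument parser in main() below), e.g.:
#   python3 scripts/test_serialization_bwc.py --old-source ../duckdb-bugfix --all-tests --no-exit
#   python3 scripts/test_serialization_bwc.py --old-source ../duckdb-bugfix --test-list broken_tests.list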

serialized_path = os.path.join('test', 'api', 'serialized_plans')
db_load_path = os.path.join(serialized_path, 'db_load.sql')
queries_path = os.path.join(serialized_path, 'queries.sql')
result_binary = os.path.join(serialized_path, 'serialized_plans.binary')
unittest_binary = os.path.join('build', 'debug', 'test', 'unittest')


def complete_query(q):
    # make sure every statement ends with a semicolon before it is written to a .sql file
    q = q.strip()
    if q.endswith(';'):
        return q
    return q + ';'


def parse_test_file(filename):
    # split a sqllogictest file into database load statements and SELECT queries
    parser = SQLLogicParser()
    try:
        out: Optional[SQLLogicTest] = parser.parse(filename)
        if not out:
            raise SQLParserException(f"Test {filename} could not be parsed")
    except Exception:
        return {'load': [], 'query': []}
    loop_count = 0
    load_statements = []
    query_statements = []
    for stmt in out.statements:
        if type(stmt) is sqllogictest.statement.skip.Skip:
            # mode skip - just skip the entire test
            break
        if type(stmt) is sqllogictest.statement.loop.Loop or type(stmt) is sqllogictest.statement.foreach.Foreach:
            loop_count += 1
        if type(stmt) is sqllogictest.statement.endloop.Endloop:
            loop_count -= 1
        if loop_count > 0:
            # loops are ignored currently
            continue
        if not (
            type(stmt) is sqllogictest.statement.query.Query
            or type(stmt) is sqllogictest.statement.statement.Statement
        ):
            # only handle query and statement nodes for now
            continue
        if type(stmt) is sqllogictest.statement.statement.Statement:
            # skip statements that are expected to fail
            if stmt.expected_result.type == sqllogictest.ExpectedResult.Type.ERROR:
                continue
        query = ' '.join(stmt.lines)
        try:
            sql_stmt_list = duckdb.extract_statements(query)
        except KeyboardInterrupt:
            raise
        except Exception:
            continue
        for sql_stmt in sql_stmt_list:
            if sql_stmt.type == duckdb.StatementType.SELECT:
                query_statements.append(query)
            elif sql_stmt.type == duckdb.StatementType.PRAGMA:
                # PRAGMA statements are skipped entirely
                continue
            else:
                load_statements.append(query)
    return {'load': load_statements, 'query': query_statements}


def build_sources(old_source, new_source):
    # build the old and new source trees (debug) if they have not been built yet
    current_path = os.getcwd()
    os.chdir(old_source)
    # build if not yet built
    if not os.path.isfile(unittest_binary):
        res = subprocess.run(['make', 'debug']).returncode
        if res != 0:
            raise Exception("Failed to build old sources")

    os.chdir(current_path)
    os.chdir(new_source)

    # build if not yet built
    if not os.path.isfile(unittest_binary):
        res = subprocess.run(['make', 'debug']).returncode
        if res != 0:
            raise Exception("Failed to build new sources")
    os.chdir(current_path)


def run_test(filename, old_source, new_source, no_exit):
    statements = parse_test_file(filename)

    current_path = os.getcwd()
    os.chdir(old_source)
    # write the extracted load/query statements to the serialization input files
    with open(os.path.join(old_source, db_load_path), 'w+') as f:
        for stmt in statements['load']:
            f.write(complete_query(stmt) + '\n')

    with open(os.path.join(old_source, queries_path), 'w+') as f:
        for stmt in statements['query']:
            f.write(complete_query(stmt) + '\n')

    # generate the serialized plans using the old sources
    my_env = os.environ.copy()
    my_env['GEN_PLAN_STORAGE'] = '1'
    res = subprocess.run(['build/debug/test/unittest', 'Generate serialized plans file'], env=my_env).returncode
    if res != 0:
        print(f"SKIPPING TEST {filename}")
        return True

    os.chdir(current_path)

    # copy the generated files over to the new sources
    for path in [db_load_path, queries_path, result_binary]:
        shutil.copy(os.path.join(old_source, path), os.path.join(new_source, path))

    # run the verification using the new sources
    os.chdir(new_source)

    res = subprocess.run(['build/debug/test/unittest', "Test deserialized plans from file"]).returncode
    if res != 0:
        if no_exit:
            print("BROKEN TEST")
            with open('broken_tests.list', 'a') as f:
                f.write(filename + '\n')
            return False
        raise Exception("Deserialization failure")
    os.chdir(current_path)
    return True


def parse_excluded_tests(path):
    # read the list of excluded test paths; empty lines and lines starting with '#' are ignored
    exclusion_list = {}
    with open(path) as f:
        for line in f:
            if len(line.strip()) == 0 or line[0] == '#':
                continue
            exclusion_list[line.strip()] = True
    return exclusion_list


def find_tests_recursive(directory, excluded_paths):
    # recursively collect all .test files under the given directory, skipping excluded paths
    test_list = []
    for f in os.listdir(directory):
        path = os.path.join(directory, f)
        if path in excluded_paths:
            continue
        if os.path.isdir(path):
            test_list += find_tests_recursive(path, excluded_paths)
        elif path.endswith('.test'):
            test_list.append(path)
    return test_list


def main():
    parser = argparse.ArgumentParser(description="Test serialization")
    parser.add_argument("--new-source", type=str, help="Path to the new source", default='.')
    parser.add_argument("--old-source", type=str, help="Path to the old source", required=True)
    parser.add_argument("--start-at", type=str, help="Start running tests at this specific test", default=None)
    parser.add_argument("--no-exit", action="store_true", help="Keep running even if a test fails", default=False)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--test-file", type=str, help="Path to the SQL logic file", default='')
    group.add_argument("--all-tests", action='store_true', help="Run all tests", default=False)
    group.add_argument("--test-list", type=str, help="Load tests to run from a file list", default=None)
    args = parser.parse_args()

    old_source = args.old_source
    new_source = args.new_source
    files = []
    if args.all_tests:
        # run all tests under test/sql, except the ones on the exclusion list
        excluded_tests = parse_excluded_tests(
            os.path.join(new_source, 'test', 'api', 'serialized_plans', 'excluded_tests.list')
        )
        test_dir = os.path.join('test', 'sql')
        if new_source != '.':
            test_dir = os.path.join(new_source, test_dir)
        files = find_tests_recursive(test_dir, excluded_tests)
    elif args.test_list is not None:
        # run the tests listed in the given file (one test path per line)
        with open(args.test_list, 'r') as f:
            for line in f:
                if len(line.strip()) == 0:
                    continue
                files.append(line.strip())
    else:
        # run a single test
        files.append(args.test_file)
    files.sort()

    current_path = os.getcwd()
    try:
        build_sources(old_source, new_source)

        all_succeeded = True
        started = args.start_at is None
        for filename in files:
            if not started:
                if filename == args.start_at:
                    started = True
                else:
                    continue

            print(f"Run test {filename}")
            os.chdir(current_path)
            if not run_test(filename, old_source, new_source, args.no_exit):
                all_succeeded = False
        if not all_succeeded:
            exit(1)
    finally:
        os.chdir(current_path)
if __name__ == "__main__":
|
|
main()
|