231 lines
8.8 KiB
Python
231 lines
8.8 KiB
Python
import argparse
|
|
import os
|
|
import sqllogictest
|
|
from sqllogictest import SQLParserException, SQLLogicParser, SQLLogicTest
|
|
import subprocess
|
|
import multiprocessing
|
|
import tempfile
|
|
import re
|
|
|
|
# Command-line interface of the script.
# NOTE: the arguments are parsed at import time; the `__main__` section reads
# the resulting module-level `args`.
parser = argparse.ArgumentParser(description="Test serialization")

# Shell binary that is launched once per statement.
parser.add_argument(
    "--shell",
    help="Shell binary to run",
    default=os.path.join("build", "debug", "duckdb"),
)

# Optional window (first index / number of files) into the sorted file list.
parser.add_argument("--offset", type=int, help="File offset")
parser.add_argument("--count", type=int, help="File count")

# Boolean switches; all of them default to False.
parser.add_argument(
    "--no-exit",
    action="store_true",
    help="Do not exit after a test fails",
)
parser.add_argument(
    "--print-failing-only",
    action="store_true",
    help="Print failing tests only",
)
parser.add_argument(
    "--include-extensions",
    action="store_true",
    help="Include test files of out-of-tree extensions",
)

# Exactly one way of selecting the test files must be supplied.
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--test-file", help="Path to the SQL logic file", default="")
group.add_argument(
    "--test-list",
    help="Path to the file that contains a newline separated list of test files",
    default="",
)
group.add_argument("--all-tests", action="store_true", help="Run all tests")

args = parser.parse_args()
|
|
|
|
|
|
def extract_git_urls(script: str):
    """Return every http(s) URL that follows a ``GIT_URL`` keyword in *script*.

    The URLs are returned in order of appearance. Each URL runs from its
    ``http://`` or ``https://`` scheme up to the next whitespace character.
    An empty list is returned when nothing matches.
    """
    git_url_re = re.compile(r'GIT_URL\s+(https?://\S+)')
    return [match.group(1) for match in git_url_re.finditer(script)]
|
|
|
|
|
|
import os
|
|
import requests
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
def download_directory_contents(api_url, local_path, headers):
    """Recursively mirror a GitHub "contents" API listing into *local_path*.

    Args:
        api_url: URL of the contents-API endpoint to list.
        local_path: Local directory that receives the files; it is created
            once the listing has been fetched successfully.
        headers: HTTP headers sent with the listing request. File downloads
            go to each entry's ``download_url`` without these headers.

    A listing that does not answer with HTTP 200 is reported on stdout and
    skipped. A file that fails to download is reported and the remaining
    entries are still processed. Network errors, including timeouts, raised by
    ``requests`` propagate to the caller.
    """
    # Without a timeout, requests waits indefinitely on a stalled connection.
    request_timeout = 60  # seconds

    response = requests.get(api_url, headers=headers, timeout=request_timeout)
    if response.status_code != 200:
        print(f"⚠️ Could not access {api_url}: {response.status_code}")
        return

    os.makedirs(local_path, exist_ok=True)

    listing = response.json()
    if isinstance(listing, dict):
        # The endpoint returns a single object instead of a list when the
        # path is a file rather than a directory.
        listing = [listing]

    for item in listing:
        item_type = item.get("type")
        item_name = item.get("name")
        if not item_name:
            # No usable local name for this entry.
            continue

        if item_type == "file":
            download_url = item.get("download_url")
            if not download_url:
                continue
            file_path = os.path.join(local_path, item_name)
            file_resp = requests.get(download_url, timeout=request_timeout)
            if file_resp.status_code == 200:
                with open(file_path, "wb") as f:
                    f.write(file_resp.content)
                print(f" - Downloaded {file_path}")
            else:
                print(f" - Failed to download {file_path}")
        elif item_type == "dir":
            # Descend using the API URL of the subdirectory.
            subdir_api_url = item.get("url")
            subdir_local_path = os.path.join(local_path, item_name)
            download_directory_contents(subdir_api_url, subdir_local_path, headers)
|
|
|
|
|
|
def download_test_sql_folder(repo_url, base_folder="extension-test-files"):
    """Download the ``test/sql`` directory of one extension repository.

    Args:
        repo_url: Git URL of the repository; only its last path component
            (the repository name) is used.
        base_folder: Directory under which a folder named after the
            repository is created.

    Nothing is downloaded when ``<base_folder>/<repo name>`` already exists.
    When the ``GITHUB_TOKEN`` environment variable is set it is sent as a
    bearer token; otherwise the request is made without authentication.
    """
    repo_name = urlparse(repo_url).path.strip("/").split("/")[-1]
    target_folder = os.path.join(base_folder, repo_name)

    if os.path.exists(target_folder):
        print(f"✓ Skipping {repo_name}, already exists.")
        return

    print(f"⬇️ Downloading test/sql from {repo_name}...")

    # NOTE: the owner is fixed to "duckdb" and the ref to "main", regardless of
    # what repo_url contains.
    api_url = f"https://api.github.com/repos/duckdb/{repo_name}/contents/test/sql?ref=main"

    headers = {"Accept": "application/vnd.github.v3+json"}
    # A missing token must not abort the run with a KeyError.
    github_token = os.environ.get("GITHUB_TOKEN")
    if github_token:
        headers["Authorization"] = f"Bearer {github_token}"

    download_directory_contents(api_url, target_folder, headers)
|
|
|
|
|
|
def batch_download_all_test_sql():
    """Download the ``test/sql`` folder of every out-of-tree extension.

    The repository URLs are taken from the ``GIT_URL`` entries of
    ``.github/config/out_of_tree_extensions.cmake``, resolved relative to the
    current working directory.

    Raises:
        Exception: if that cmake file does not exist.
    """
    cmake_path = ".github/config/out_of_tree_extensions.cmake"
    if not os.path.isfile(cmake_path):
        raise Exception(f"File {cmake_path} not found")

    with open(cmake_path, "r") as cmake_file:
        repo_urls = extract_git_urls(cmake_file.read())

    if not repo_urls:
        print("No URLs found.")

    for repo_url in repo_urls:
        download_test_sql_folder(repo_url)
|
|
|
|
|
|
def find_tests_recursive(dir, excluded_paths):
    """Collect all ``.test`` / ``.test_slow`` files below *dir*.

    Args:
        dir: Directory to scan.
        excluded_paths: Container of paths to ignore. A path is compared in
            the form ``os.path.join(parent, entry)``; an excluded directory
            is not descended into.

    Returns:
        A list of matching file paths, in the order ``os.listdir`` yields them.
    """
    test_suffixes = ('.test', '.test_slow')
    found = []
    for entry in os.listdir(dir):
        candidate = os.path.join(dir, entry)
        if candidate in excluded_paths:
            continue
        if os.path.isdir(candidate):
            found.extend(find_tests_recursive(candidate, excluded_paths))
        elif candidate.endswith(test_suffixes):
            found.append(candidate)
    return found
|
|
|
|
|
|
def parse_test_file(filename):
    """Return the SQL text of the statements in a sqllogictest file.

    Only ``Query`` and ``Statement`` nodes are returned, each with its lines
    joined by single spaces. The following are left out:

    * everything from the first ``Skip`` node onwards,
    * everything inside ``Loop`` / ``Foreach`` blocks,
    * ``Statement`` nodes whose expected result is an error mentioning
      "parser error" or "syntax error" (case-insensitive).

    Args:
        filename: Path to the test file.

    Returns:
        A list of SQL strings. The list is empty when the file cannot be
        parsed.

    Raises:
        Exception: if *filename* is not an existing file.
    """
    if not os.path.isfile(filename):
        raise Exception(f"File {filename} not found")

    logic_parser = SQLLogicParser()
    try:
        out = logic_parser.parse(filename)
        if not out:
            raise SQLParserException(f"Test {filename} could not be parsed")
    except Exception:
        # Best effort: an unparsable test file contributes no statements.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit still propagate.
        return []

    # Exact type comparisons are kept on purpose; subclasses are not matched.
    skip_type = sqllogictest.statement.skip.Skip
    loop_open_types = (sqllogictest.statement.loop.Loop, sqllogictest.statement.foreach.Foreach)
    endloop_type = sqllogictest.statement.endloop.Endloop
    query_type = sqllogictest.statement.query.Query
    statement_type = sqllogictest.statement.statement.Statement

    loop_depth = 0
    statements = []
    for stmt in out.statements:
        stmt_type = type(stmt)
        if stmt_type is skip_type:
            # mode skip - just skip entire test
            break
        if stmt_type in loop_open_types:
            loop_depth += 1
        if stmt_type is endloop_type:
            loop_depth -= 1
        if loop_depth > 0:
            # loops are ignored currently
            continue
        if stmt_type is not query_type and stmt_type is not statement_type:
            # only handle query and statement nodes for now
            continue
        if stmt_type is statement_type:
            expected = stmt.expected_result
            # A statement that is expected to fail in the parser is not a
            # useful input for the parser check.
            if expected.type == sqllogictest.ExpectedResult.Type.ERROR and any(
                "parser error" in line.lower() or "syntax error" in line.lower() for line in expected.lines
            ):
                continue
        statements.append(' '.join(stmt.lines))
    return statements
|
|
|
|
|
|
def run_test_case(args_tuple):
    """Run the PEG parser check on every statement of one test file.

    Args:
        args_tuple: ``(i, file, shell, print_failing_only)``: the index used in
            log output, the test file path, the shell binary to launch and
            whether only failures are printed. A single tuple is used so the
            function can be passed to ``multiprocessing.Pool.map``.

    Returns:
        A list holding at most one ``(file, statement)`` pair: the first
        statement of the file that failed. Checking stops at that statement.
    """
    i, file, shell, print_failing_only = args_tuple
    results = []
    if not print_failing_only:
        print(f"Run test {i}: {file}")

    statements = parse_test_file(file)
    for statement in statements:
        with tempfile.TemporaryDirectory() as tmpdir:
            peg_sql_path = os.path.join(tmpdir, 'peg_test.sql')
            # Write UTF-8 explicitly: the locale default encoding may be
            # unable to represent non-ASCII characters in a statement.
            with open(peg_sql_path, 'w', encoding='utf8') as f:
                # The statement is wrapped in a dollar-quoted string literal.
                f.write(f'CALL check_peg_parser($TEST_PEG_PARSER${statement}$TEST_PEG_PARSER$);\n')

            proc = subprocess.run([shell, '-init', peg_sql_path, '-c', '.exit'], capture_output=True)

        # Undecodable bytes in the shell output must not crash the runner.
        stderr = proc.stderr.decode('utf8', errors='replace')

        # Success needs both a zero exit code and no error text on stderr.
        if proc.returncode == 0 and ' Error:' not in stderr:
            continue

        if print_failing_only:
            print(f"Failed test {i}: {file}")
        else:
            print('Failed')
            print('-- STDOUT --')
            print(proc.stdout.decode('utf8', errors='replace'))
            print('-- STDERR --')
            print(stderr)

        results.append((file, statement))
        break
    return results
|
|
|
|
|
|
if __name__ == "__main__":
    files = []
    # Paths (files or whole directories) that are never run.
    excluded_tests = {
        'test/sql/peg_parser',  # Fail for some reason
        'test/sql/prepared/parameter_variants.test',  # PostgreSQL parser bug with ?1
        'test/sql/copy/s3/download_config.test',  # Unknown why this passes in SQLLogicTest
        'test/sql/function/list/lambdas/arrow/lambda_scope_deprecated.test',  # Error in the tokenization of *+*
        'test/sql/catalog/function/test_simple_macro.test',  # Bug when mixing named parameters and non-named
    }
    if args.all_tests:
        # run all tests
        test_dir = os.path.join('test', 'sql')
        files = find_tests_recursive(test_dir, excluded_tests)
        if args.include_extensions:
            batch_download_all_test_sql()
            extension_dir = 'extension-test-files'
            # The directory is absent when nothing could be downloaded.
            if os.path.isdir(extension_dir):
                files = files + find_tests_recursive(extension_dir, set())
    elif len(args.test_list) > 0:
        with open(args.test_list, 'r') as f:
            stripped_lines = (line.strip() for line in f)
            # Blank lines are skipped so they do not become empty test paths.
            files = [path for path in stripped_lines if path and path not in excluded_tests]
    else:
        # run a single test
        files.append(args.test_file)
    files.sort()

    # Select the [start, end) window; clamp it so that an --offset/--count
    # reaching past the end of the list does not raise an IndexError.
    start = args.offset if args.offset is not None else 0
    end = start + args.count if args.count is not None else len(files)
    end = min(end, len(files))
    work_items = [(i, files[i], args.shell, args.print_failing_only) for i in range(start, end)]

    if not args.no_exit:
        # Run sequentially so the script can stop at the first failing test.
        failed_test_list = []
        for item in work_items:
            res = run_test_case(item)
            if res:
                failed_test_list.extend(res)
                raise SystemExit(1)
    else:
        # --no-exit: run everything in parallel and report all failures.
        with multiprocessing.Pool() as pool:
            results = pool.map(run_test_case, work_items)
        failed_test_list = [item for sublist in results for item in sublist]

    failed_tests = len(failed_test_list)
    print("List of failed tests: ")
    for test, statement in failed_test_list:
        print(f"{test}\n{statement}\n\n")
    # Guard against an empty file list, which would divide by zero.
    failure_pct = round(failed_tests / len(files) * 100, 2) if files else 0.0
    print(f"Total of {failed_tests} out of {len(files)} failed ({failure_pct}%). ")
|