should be it
external/duckdb/scripts/test_peg_parser.py (vendored, new file, 230 lines)
@@ -0,0 +1,230 @@
import argparse
import os
import re
import subprocess
import multiprocessing
import tempfile
from typing import Optional
from urllib.parse import urlparse

import requests

import sqllogictest
from sqllogictest import SQLParserException, SQLLogicParser, SQLLogicTest

parser = argparse.ArgumentParser(description="Test the PEG parser")
parser.add_argument("--shell", type=str, help="Shell binary to run", default=os.path.join('build', 'debug', 'duckdb'))
parser.add_argument("--offset", type=int, help="File offset", default=None)
parser.add_argument("--count", type=int, help="File count", default=None)
parser.add_argument('--no-exit', action='store_true', help='Do not exit after a test fails', default=False)
parser.add_argument('--print-failing-only', action='store_true', help='Print failing tests only', default=False)
parser.add_argument(
    '--include-extensions', action='store_true', help='Include test files of out-of-tree extensions', default=False
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--test-file", type=str, help="Path to the SQL logic file", default='')
group.add_argument(
    "--test-list", type=str, help="Path to the file that contains a newline-separated list of test files", default=''
)
group.add_argument("--all-tests", action='store_true', help="Run all tests", default=False)
args = parser.parse_args()
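# Example invocations (paths are illustrative; run from the repository root):
#   python scripts/test_peg_parser.py --test-file test/sql/select.test
#   python scripts/test_peg_parser.py --all-tests --no-exit --print-failing-only
#   python scripts/test_peg_parser.py --all-tests --offset 100 --count 50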


def extract_git_urls(script: str):
    pattern = r'GIT_URL\s+(https?://\S+)'
    return re.findall(pattern, script)
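

# For reference, a matching line in out_of_tree_extensions.cmake looks roughly like
# the following (extension name is hypothetical):
#   GIT_URL https://github.com/duckdb/duckdb_example_extension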


def download_directory_contents(api_url, local_path, headers):
    response = requests.get(api_url, headers=headers)
    if response.status_code != 200:
        print(f"⚠️ Could not access {api_url}: {response.status_code}")
        return

    os.makedirs(local_path, exist_ok=True)

    for item in response.json():
        item_type = item.get("type")
        item_name = item.get("name")
        if item_type == "file":
            download_url = item.get("download_url")
            if not download_url:
                continue
            file_path = os.path.join(local_path, item_name)
            file_resp = requests.get(download_url)
            if file_resp.status_code == 200:
                with open(file_path, "wb") as f:
                    f.write(file_resp.content)
                print(f" - Downloaded {file_path}")
            else:
                print(f" - Failed to download {file_path}")
        elif item_type == "dir":
            subdir_api_url = item.get("url")
            subdir_local_path = os.path.join(local_path, item_name)
            download_directory_contents(subdir_api_url, subdir_local_path, headers)
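

# A typical item returned by the GitHub contents API (values hypothetical) has the
# shape below; "type", "name", "download_url", and "url" are the only fields used:
#   {"type": "file", "name": "basic.test", "download_url": "https://raw.githubusercontent.com/...", "url": "https://api.github.com/..."}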


def download_test_sql_folder(repo_url, base_folder="extension-test-files"):
    repo_name = urlparse(repo_url).path.strip("/").split("/")[-1]
    target_folder = os.path.join(base_folder, repo_name)

    if os.path.exists(target_folder):
        print(f"✓ Skipping {repo_name}, already exists.")
        return

    print(f"⬇️ Downloading test/sql from {repo_name}...")

    api_url = f"https://api.github.com/repos/duckdb/{repo_name}/contents/test/sql?ref=main"
    GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
    headers = {"Accept": "application/vnd.github.v3+json", "Authorization": f"Bearer {GITHUB_TOKEN}"}

    download_directory_contents(api_url, target_folder, headers)


def batch_download_all_test_sql():
    filename = ".github/config/out_of_tree_extensions.cmake"
    if not os.path.isfile(filename):
        raise Exception(f"File {filename} not found")
    with open(filename, "r") as f:
        content = f.read()
    urls = extract_git_urls(content)
    if not urls:
        print("No URLs found.")
    for url in urls:
        download_test_sql_folder(url)


def find_tests_recursive(directory, excluded_paths):
    test_list = []
    for f in os.listdir(directory):
        path = os.path.join(directory, f)
        if path in excluded_paths:
            continue
        if os.path.isdir(path):
            test_list += find_tests_recursive(path, excluded_paths)
        elif path.endswith('.test') or path.endswith('.test_slow'):
            test_list.append(path)
    return test_list


def parse_test_file(filename):
    if not os.path.isfile(filename):
        raise Exception(f"File {filename} not found")
    parser = SQLLogicParser()
    try:
        out: Optional[SQLLogicTest] = parser.parse(filename)
        if not out:
            raise SQLParserException(f"Test {filename} could not be parsed")
    except Exception:
        # tests that cannot be parsed are skipped rather than failing the run
        return []
    loop_count = 0
    statements = []
    for stmt in out.statements:
        if type(stmt) is sqllogictest.statement.skip.Skip:
            # mode skip - just skip the entire test
            break
        if type(stmt) is sqllogictest.statement.loop.Loop or type(stmt) is sqllogictest.statement.foreach.Foreach:
            loop_count += 1
        if type(stmt) is sqllogictest.statement.endloop.Endloop:
            loop_count -= 1
        if loop_count > 0:
            # loops are ignored currently
            continue
        if not (
            type(stmt) is sqllogictest.statement.query.Query or type(stmt) is sqllogictest.statement.statement.Statement
        ):
            # only handle query and statement nodes for now
            continue
        if type(stmt) is sqllogictest.statement.statement.Statement:
            # skip statements that expect a parser or syntax error
            if stmt.expected_result.type == sqllogictest.ExpectedResult.Type.ERROR:
                if any(
                    "parser error" in line.lower() or "syntax error" in line.lower()
                    for line in stmt.expected_result.lines
                ):
                    continue
        query = ' '.join(stmt.lines)
        statements.append(query)
    return statements
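

# Extraction sketch, assuming the standard sqllogictest format: a block such as
#
#   statement ok
#   CREATE TABLE integers(i INTEGER)
#
# contributes the string 'CREATE TABLE integers(i INTEGER)' to the returned list,
# while skipped tests, loop bodies, and expected parser/syntax errors are dropped.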


def run_test_case(args_tuple):
    i, file, shell, print_failing_only = args_tuple
    results = []
    if not print_failing_only:
        print(f"Run test {i}: {file}")

    statements = parse_test_file(file)
    for statement in statements:
        with tempfile.TemporaryDirectory() as tmpdir:
            peg_sql_path = os.path.join(tmpdir, 'peg_test.sql')
            with open(peg_sql_path, 'w') as f:
                # dollar-quote the statement so embedded quotes do not break the CALL
                f.write(f'CALL check_peg_parser($TEST_PEG_PARSER${statement}$TEST_PEG_PARSER$);\n')

            proc = subprocess.run([shell, '-init', peg_sql_path, '-c', '.exit'], capture_output=True)
            stderr = proc.stderr.decode('utf8')

            if proc.returncode == 0 and ' Error:' not in stderr:
                continue

            if print_failing_only:
                print(f"Failed test {i}: {file}")
            else:
                print('Failed')
                print('-- STDOUT --')
                print(proc.stdout.decode('utf8'))
                print('-- STDERR --')
                print(stderr)

            results.append((file, statement))
            # stop at the first failing statement in this file
            break
    return results
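

# For each extracted statement, peg_test.sql holds a single dollar-quoted CALL,
# e.g. (statement hypothetical):
#   CALL check_peg_parser($TEST_PEG_PARSER$SELECT 42$TEST_PEG_PARSER$);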


if __name__ == "__main__":
    files = []
    excluded_tests = {
        'test/sql/peg_parser',  # Fails for some reason
        'test/sql/prepared/parameter_variants.test',  # PostgreSQL parser bug with ?1
        'test/sql/copy/s3/download_config.test',  # Unknown why this passes in SQLLogicTest
        'test/sql/function/list/lambdas/arrow/lambda_scope_deprecated.test',  # Error in the tokenization of *+*
        'test/sql/catalog/function/test_simple_macro.test',  # Bug when mixing named and unnamed parameters
    }
    if args.all_tests:
        # run all tests
        test_dir = os.path.join('test', 'sql')
        files = find_tests_recursive(test_dir, excluded_tests)
        if args.include_extensions:
            batch_download_all_test_sql()
            extension_files = find_tests_recursive('extension-test-files', set())
            files = files + extension_files
    elif len(args.test_list) > 0:
        with open(args.test_list, 'r') as f:
            files = [x.strip() for x in f.readlines() if x.strip() not in excluded_tests]
    else:
        # run a single test
        files.append(args.test_file)
    files.sort()

    start = args.offset if args.offset is not None else 0
    end = start + args.count if args.count is not None else len(files)
    # clamp so that --offset/--count cannot run past the end of the file list
    end = min(end, len(files))
    work_items = [(i, files[i], args.shell, args.print_failing_only) for i in range(start, end)]

    if not args.no_exit:
        # run sequentially (no multiprocessing) so we can exit as soon as a test fails
        failed_test_list = []
        for item in work_items:
            res = run_test_case(item)
            if res:
                failed_test_list.extend(res)
                exit(1)
    else:
        with multiprocessing.Pool() as pool:
            results = pool.map(run_test_case, work_items)
        failed_test_list = [item for sublist in results for item in sublist]

    failed_tests = len(failed_test_list)
    print("List of failed tests:")
    for test, statement in failed_test_list:
        print(f"{test}\n{statement}\n\n")
    print(f"Total of {failed_tests} out of {len(files)} failed ({round(failed_tests / len(files) * 100, 2)}%).")