208 lines
6.4 KiB
Python
208 lines
6.4 KiB
Python
import argparse
|
|
import glob
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from tqdm import tqdm
|
|
|
|
|
|
OLD_DB_NAME = "old.duckdb"
|
|
NEW_DB_NAME = "new.duckdb"
|
|
PROFILE_FILENAME = "duckdb_profile.json"
|
|
|
|
ENABLE_PROFILING = "PRAGMA enable_profiling=json"
|
|
PROFILE_OUTPUT = f"PRAGMA profile_output='{PROFILE_FILENAME}'"
|
|
|
|
BANNER_SIZE = 52
|
|
|
|
|
|
def init_db(cli, dbname, benchmark_dir):
|
|
print(f"INITIALIZING {dbname} ...")
|
|
subprocess.run(
|
|
f"{cli} {dbname} < {benchmark_dir}/init/schema.sql", shell=True, check=True, stdout=subprocess.DEVNULL
|
|
)
|
|
subprocess.run(f"{cli} {dbname} < {benchmark_dir}/init/load.sql", shell=True, check=True, stdout=subprocess.DEVNULL)
|
|
print("INITIALIZATION DONE")
|
|
|
|
|
|
class PlanCost:
|
|
def __init__(self):
|
|
self.total = 0
|
|
self.build_side = 0
|
|
self.probe_side = 0
|
|
self.time = 0
|
|
|
|
def __add__(self, other):
|
|
self.total += other.total
|
|
self.build_side += other.build_side
|
|
self.probe_side += other.probe_side
|
|
return self
|
|
|
|
def __gt__(self, other):
|
|
if self == other or self.total < other.total:
|
|
return False
|
|
# if the total intermediate cardinalities is greater, also inspect time.
|
|
# it's possible a plan reordering increased cardinalities, but overall execution time
|
|
# was not greatly affected
|
|
total_card_increased = self.total > other.total
|
|
build_card_increased = self.build_side > other.build_side
|
|
if total_card_increased and build_card_increased:
|
|
return True
|
|
# we know the total cardinality is either the same or higher and the build side has not increased
|
|
# in this case fall back to the timing. It's possible that even if the probe side is higher
|
|
# since the tuples are in flight, the plan executes faster
|
|
return self.time > other.time * 1.03
|
|
|
|
def __lt__(self, other):
|
|
if self == other:
|
|
return False
|
|
return not (self > other)
|
|
|
|
def __eq__(self, other):
|
|
return self.total == other.total and self.build_side == other.build_side and self.probe_side == other.probe_side
|
|
|
|
|
|
def is_measured_join(op) -> bool:
|
|
if 'name' not in op:
|
|
return False
|
|
if op['name'] != 'HASH_JOIN':
|
|
return False
|
|
if 'Join Type' not in op['extra_info']:
|
|
return False
|
|
if op['extra_info']['Join Type'].startswith('MARK'):
|
|
return False
|
|
return True
|
|
|
|
|
|
def op_inspect(op) -> PlanCost:
|
|
cost = PlanCost()
|
|
if 'Query' in op:
|
|
cost.time = op['operator_timing']
|
|
if is_measured_join(op):
|
|
cost.total = op['operator_cardinality']
|
|
if 'operator_cardinality' in op['children'][0]:
|
|
cost.probe_side += op['children'][0]['operator_cardinality']
|
|
if 'operator_cardinality' in op['children'][1]:
|
|
cost.build_side += op['children'][1]['operator_cardinality']
|
|
|
|
left_cost = op_inspect(op['children'][0])
|
|
right_cost = op_inspect(op['children'][1])
|
|
cost.probe_side += left_cost.probe_side + right_cost.probe_side
|
|
cost.build_side += left_cost.build_side + right_cost.build_side
|
|
cost.total += left_cost.total + right_cost.total
|
|
return cost
|
|
|
|
for child_op in op['children']:
|
|
cost += op_inspect(child_op)
|
|
|
|
return cost
|
|
|
|
|
|
def query_plan_cost(cli, dbname, query):
|
|
try:
|
|
subprocess.run(
|
|
f"{cli} --readonly {dbname} -c \"{ENABLE_PROFILING};{PROFILE_OUTPUT};{query}\"",
|
|
shell=True,
|
|
check=True,
|
|
capture_output=True,
|
|
)
|
|
except subprocess.CalledProcessError as e:
|
|
print("-------------------------")
|
|
print("--------Failure----------")
|
|
print("-------------------------")
|
|
print(e.stderr.decode('utf8'))
|
|
print("-------------------------")
|
|
print("--------Output----------")
|
|
print("-------------------------")
|
|
print(e.output.decode('utf8'))
|
|
print("-------------------------")
|
|
raise e
|
|
with open(PROFILE_FILENAME, 'r') as file:
|
|
return op_inspect(json.load(file))
|
|
|
|
|
|
def print_banner(text):
|
|
text_len = len(text)
|
|
rest = BANNER_SIZE - text_len - 10
|
|
l_width = int(rest / 2)
|
|
r_width = l_width
|
|
if rest % 2 != 0:
|
|
l_width += 1
|
|
print("")
|
|
print("=" * BANNER_SIZE)
|
|
print("=" * l_width + " " * 5 + text + " " * 5 + "=" * r_width)
|
|
print("=" * BANNER_SIZE)
|
|
|
|
|
|
def print_diffs(diffs):
|
|
for query_name, old_cost, new_cost in diffs:
|
|
print("")
|
|
print("Query:", query_name)
|
|
print("Old total cost:", old_cost.total)
|
|
print("Old build cost:", old_cost.build_side)
|
|
print("Old probe cost:", old_cost.probe_side)
|
|
print("New total cost:", new_cost.total)
|
|
print("New build cost:", new_cost.build_side)
|
|
print("New probe cost:", new_cost.probe_side)
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Plan cost regression test script with old and new versions.")
|
|
|
|
parser.add_argument("--old", type=str, help="Path to the old runner.", required=True)
|
|
parser.add_argument("--new", type=str, help="Path to the new runner.", required=True)
|
|
parser.add_argument("--dir", type=str, help="Path to the benchmark directory.", required=True)
|
|
|
|
args = parser.parse_args()
|
|
|
|
old = args.old
|
|
new = args.new
|
|
benchmark_dir = args.dir
|
|
|
|
init_db(old, OLD_DB_NAME, benchmark_dir)
|
|
init_db(new, NEW_DB_NAME, benchmark_dir)
|
|
|
|
improvements = []
|
|
regressions = []
|
|
|
|
files = glob.glob(f"{benchmark_dir}/queries/*.sql")
|
|
files.sort()
|
|
|
|
print("")
|
|
print("RUNNING BENCHMARK QUERIES")
|
|
for f in tqdm(files):
|
|
query_name = f.split("/")[-1].replace(".sql", "")
|
|
|
|
with open(f, "r") as file:
|
|
query = file.read()
|
|
|
|
old_cost = query_plan_cost(old, OLD_DB_NAME, query)
|
|
new_cost = query_plan_cost(new, NEW_DB_NAME, query)
|
|
|
|
if old_cost > new_cost:
|
|
improvements.append((query_name, old_cost, new_cost))
|
|
elif new_cost > old_cost:
|
|
regressions.append((query_name, old_cost, new_cost))
|
|
|
|
exit_code = 0
|
|
if improvements:
|
|
print_banner("IMPROVEMENTS DETECTED")
|
|
print_diffs(improvements)
|
|
if regressions:
|
|
exit_code = 1
|
|
print_banner("REGRESSIONS DETECTED")
|
|
print_diffs(regressions)
|
|
if not improvements and not regressions:
|
|
print_banner("NO DIFFERENCES DETECTED")
|
|
|
|
os.remove(OLD_DB_NAME)
|
|
os.remove(NEW_DB_NAME)
|
|
os.remove(PROFILE_FILENAME)
|
|
|
|
exit(exit_code)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|