Files
email-tracker/external/duckdb/scripts/plan_cost_runner.py
2025-10-24 19:21:19 -05:00

208 lines
6.4 KiB
Python

import argparse
import glob
import json
import os
import shlex
import subprocess
import sys

from tqdm import tqdm
# Scratch database files (relative to the current directory) that init_db
# creates with the old and the new CLI respectively.
OLD_DB_NAME = "old.duckdb"
NEW_DB_NAME = "new.duckdb"

# JSON profile file and the two pragmas that query_plan_cost places in front of
# every benchmark query so the CLI writes that profile.
PROFILE_FILENAME = "duckdb_profile.json"
ENABLE_PROFILING = "PRAGMA enable_profiling=json"
PROFILE_OUTPUT = "PRAGMA profile_output='{}'".format(PROFILE_FILENAME)

# Width, in characters, of the '=' rules printed by print_banner.
BANNER_SIZE = 52
def init_db(cli, dbname, benchmark_dir):
    """Create ``dbname`` by running the benchmark's schema and load scripts through ``cli``.

    ``<benchmark_dir>/init/schema.sql`` and then ``<benchmark_dir>/init/load.sql``
    are fed to ``cli dbname`` on stdin (the equivalent of ``cli dbname < script``),
    with the CLI's stdout discarded. No shell is involved, so paths containing
    spaces or shell metacharacters are passed through intact.

    Args:
        cli: CLI command; split with shell-like rules (``shlex.split``) so values
            that worked with the former ``shell=True`` invocation keep working.
        dbname: database file handed to the CLI as a single argument.
        benchmark_dir: directory containing the ``init`` scripts.

    Raises:
        subprocess.CalledProcessError: if the CLI exits non-zero on either script.
        OSError: if a script or the CLI executable cannot be opened/started.
    """
    print(f"INITIALIZING {dbname} ...")
    command = shlex.split(cli) + [dbname]
    # Order matters: the schema must exist before the load script runs.
    for script_name in ("schema.sql", "load.sql"):
        script_path = os.path.join(benchmark_dir, "init", script_name)
        with open(script_path, "rb") as sql_file:
            subprocess.run(command, stdin=sql_file, check=True, stdout=subprocess.DEVNULL)
    print("INITIALIZATION DONE")
class PlanCost:
    """Intermediate-cardinality cost of a profiled query plan.

    Filled in by ``op_inspect``:
    - ``total`` accumulates the output cardinality of each measured hash join.
    - ``probe_side`` / ``build_side`` accumulate the cardinality of each such
      join's first / second child.
    - ``time`` holds the ``operator_timing`` of the profile node that carries a
      ``'Query'`` key.

    Comparison semantics: ``a > b`` means plan ``a`` is judged worse than ``b``,
    and ``a < b`` is defined as ``b > a``. Equality compares only the three
    cardinality counters, never ``time``.
    """

    # Multiplicative slack applied to ``other.time`` before a plan counts as
    # slower; presumably chosen to absorb run-to-run timing noise.
    TIME_TOLERANCE = 1.03

    # Mutable and compared by value, so instances are deliberately unhashable.
    __hash__ = None

    def __init__(self):
        self.total = 0
        self.build_side = 0
        self.probe_side = 0
        self.time = 0

    def __repr__(self):
        return (
            f"{type(self).__name__}(total={self.total}, build_side={self.build_side}, "
            f"probe_side={self.probe_side}, time={self.time})"
        )

    def __add__(self, other):
        """Return a new PlanCost whose cardinality counters are the sums of both operands.

        ``time`` is taken from the left operand. Neither operand is modified.
        """
        if not isinstance(other, PlanCost):
            return NotImplemented
        result = PlanCost()
        result.total = self.total + other.total
        result.build_side = self.build_side + other.build_side
        result.probe_side = self.probe_side + other.probe_side
        result.time = self.time
        return result

    def __iadd__(self, other):
        """Add ``other``'s cardinality counters into this instance in place; ``time`` is kept."""
        if not isinstance(other, PlanCost):
            return NotImplemented
        self.total += other.total
        self.build_side += other.build_side
        self.probe_side += other.probe_side
        return self

    def __gt__(self, other):
        """Return True if this plan is judged more expensive than ``other``.

        - Equal cardinalities or a lower total: never worse.
        - Higher total AND higher build side: worse.
        - Otherwise (total same or higher, but not both total and build side
          higher), fall back to timing. Execution must be slower by more than
          ``TIME_TOLERANCE``, because a reordering can raise cardinalities
          without hurting runtime.
        """
        if not isinstance(other, PlanCost):
            return NotImplemented
        if self == other or self.total < other.total:
            return False
        total_card_increased = self.total > other.total
        build_card_increased = self.build_side > other.build_side
        if total_card_increased and build_card_increased:
            return True
        # Tuples on the probe side are streamed rather than materialized, so a
        # larger probe side can still execute faster; let timing decide.
        return self.time > other.time * self.TIME_TOLERANCE

    def __lt__(self, other):
        """Return True if this plan is judged cheaper than ``other`` (i.e. ``other > self``)."""
        if not isinstance(other, PlanCost):
            return NotImplemented
        return other > self

    def __eq__(self, other):
        """Compare the cardinality counters only; ``time`` is ignored."""
        if not isinstance(other, PlanCost):
            return NotImplemented
        return (
            self.total == other.total
            and self.build_side == other.build_side
            and self.probe_side == other.probe_side
        )
def is_measured_join(op) -> bool:
    """Tell whether profile node ``op`` is a hash join whose cardinalities are counted.

    Accepted nodes have ``name == 'HASH_JOIN'`` and an ``extra_info`` mapping
    with a ``'Join Type'`` entry that does not start with ``'MARK'``. Nodes
    without a ``name`` key are rejected.
    """
    if 'name' not in op or op['name'] != 'HASH_JOIN':
        return False
    details = op['extra_info']
    # MARK joins are excluded from the measurement.
    return 'Join Type' in details and not details['Join Type'].startswith('MARK')
def op_inspect(op) -> PlanCost:
    """Recursively compute the PlanCost of profile node ``op`` and its subtree.

    For a join accepted by ``is_measured_join``:
    - the node's own ``operator_cardinality`` seeds ``total``;
    - the cardinality of ``children[0]`` (when present) is added to ``probe_side``;
    - the cardinality of ``children[1]`` (when present) is added to ``build_side``;
    - the costs of exactly those two children are then accumulated.

    Any other node just accumulates the costs of all its children.
    """
    node_cost = PlanCost()
    # NOTE(review): timing is read from whichever node has a 'Query' key
    # (presumably only the profile root) -- confirm against the profiler schema.
    if 'Query' in op:
        node_cost.time = op['operator_timing']

    measured = is_measured_join(op)
    children = op['children']
    if measured:
        node_cost.total = op['operator_cardinality']
        probe_child, build_child = children[0], children[1]
        node_cost.probe_side += probe_child.get('operator_cardinality', 0)
        node_cost.build_side += build_child.get('operator_cardinality', 0)
        # Only the probe and build inputs contribute below a measured join.
        children = (probe_child, build_child)

    # Accumulate each subtree's cardinalities; ``+=`` never touches ``time``.
    for child in children:
        node_cost += op_inspect(child)
    return node_cost
def query_plan_cost(cli, dbname, query):
    """Profile ``query`` on ``dbname`` with ``cli`` and return the resulting PlanCost.

    The CLI is invoked as ``cli --readonly dbname -c "<pragmas>;<query>"``. The
    pragmas enable JSON profiling to PROFILE_FILENAME, which is then parsed with
    ``op_inspect``. No shell is involved and the SQL is a single argument, so
    queries containing double quotes (quoted identifiers), ``$``, backticks or
    backslashes reach DuckDB unchanged.

    Args:
        cli: CLI command; split with shell-like rules (``shlex.split``) so values
            that worked with the former ``shell=True`` invocation keep working.
        dbname: database file, opened read-only.
        query: SQL text to profile.

    Returns:
        The PlanCost of the profiled plan.

    Raises:
        subprocess.CalledProcessError: if the CLI exits non-zero (its stderr and
            stdout are printed before re-raising).
        FileNotFoundError: if the CLI succeeded but wrote no profile.
    """
    command = shlex.split(cli) + [
        "--readonly",
        dbname,
        "-c",
        f"{ENABLE_PROFILING};{PROFILE_OUTPUT};{query}",
    ]
    # Drop a profile left by a previous query so a run that writes none fails
    # loudly instead of silently reusing stale numbers.
    try:
        os.remove(PROFILE_FILENAME)
    except FileNotFoundError:
        pass
    try:
        subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # errors="replace": undecodable CLI output must not mask the real failure.
        print("-------------------------")
        print("--------Failure----------")
        print("-------------------------")
        print(e.stderr.decode('utf8', errors='replace'))
        print("-------------------------")
        print("--------Output----------")
        print("-------------------------")
        print(e.output.decode('utf8', errors='replace'))
        print("-------------------------")
        raise
    with open(PROFILE_FILENAME, 'r', encoding='utf-8') as profile_file:
        return op_inspect(json.load(profile_file))
def print_banner(text):
    """Print an empty line, then ``text`` framed by three lines of '=' fill.

    The middle line puts five spaces on each side of ``text`` and pads with '='
    to BANNER_SIZE columns; an odd amount of fill puts the extra '=' on the left.
    Text too long to fit makes the middle line wider than BANNER_SIZE.
    """
    fill = BANNER_SIZE - len(text) - 10
    # int(x / 2) truncates toward zero, unlike //, for negative fill.
    right_fill = int(fill / 2)
    left_fill = right_fill + 1 if fill % 2 else right_fill
    rule = "=" * BANNER_SIZE
    gap = " " * 5
    middle = "=" * left_fill + gap + text + gap + "=" * right_fill
    print("\n".join(("", rule, middle, rule)))
def print_diffs(diffs):
    """Print one report section per ``(query_name, old_cost, new_cost)`` triple.

    Each section starts with an empty line, then the query name, then the
    total/build/probe cardinalities of the old plan followed by the new plan.
    """
    for name, before, after in diffs:
        rows = (
            ("Query:", name),
            ("Old total cost:", before.total),
            ("Old build cost:", before.build_side),
            ("Old probe cost:", before.probe_side),
            ("New total cost:", after.total),
            ("New build cost:", after.build_side),
            ("New probe cost:", after.probe_side),
        )
        print("")
        for label, value in rows:
            print(label, value)
def _compare_plans(old, new, benchmark_dir):
    """Build both databases, profile every benchmark query on each, and classify the results.

    Returns:
        ``(improvements, regressions)`` -- lists of ``(query_name, old_cost, new_cost)``.
        A query is an improvement when ``old_cost > new_cost``, and a regression
        when ``new_cost > old_cost``.
    """
    init_db(old, OLD_DB_NAME, benchmark_dir)
    init_db(new, NEW_DB_NAME, benchmark_dir)
    improvements = []
    regressions = []
    # Escape the directory part so '[', '*' or '?' in it are not treated as glob syntax.
    files = sorted(glob.glob(os.path.join(glob.escape(benchmark_dir), "queries", "*.sql")))
    print("")
    print("RUNNING BENCHMARK QUERIES")
    for path in tqdm(files):
        query_name = os.path.splitext(os.path.basename(path))[0]
        with open(path, "r") as sql_file:
            query = sql_file.read()
        old_cost = query_plan_cost(old, OLD_DB_NAME, query)
        new_cost = query_plan_cost(new, NEW_DB_NAME, query)
        if old_cost > new_cost:
            improvements.append((query_name, old_cost, new_cost))
        elif new_cost > old_cost:
            regressions.append((query_name, old_cost, new_cost))
    return improvements, regressions


def _remove_artifacts():
    """Delete the scratch databases and profile file, skipping any that do not exist."""
    for path in (OLD_DB_NAME, NEW_DB_NAME, PROFILE_FILENAME):
        try:
            os.remove(path)
        except FileNotFoundError:
            # e.g. no queries ran (no profile) or initialization failed early.
            pass


def main():
    """Compare plan costs of the ``--old`` and ``--new`` CLIs on a benchmark directory.

    Reports improvements and regressions, removes the scratch files even when a
    step fails, and exits with status 1 if any regression was found (0 otherwise).
    """
    parser = argparse.ArgumentParser(description="Plan cost regression test script with old and new versions.")
    parser.add_argument("--old", type=str, help="Path to the old runner.", required=True)
    parser.add_argument("--new", type=str, help="Path to the new runner.", required=True)
    parser.add_argument("--dir", type=str, help="Path to the benchmark directory.", required=True)
    args = parser.parse_args()

    try:
        improvements, regressions = _compare_plans(args.old, args.new, args.dir)
    finally:
        # Leftover databases would make the next run's schema.sql fail, so clean
        # up even when initialization or a query fails.
        _remove_artifacts()

    if improvements:
        print_banner("IMPROVEMENTS DETECTED")
        print_diffs(improvements)
    if regressions:
        print_banner("REGRESSIONS DETECTED")
        print_diffs(regressions)
    if not improvements and not regressions:
        print_banner("NO DIFFERENCES DETECTED")
    # sys.exit rather than the site-provided exit() builtin (missing under `python -S`).
    sys.exit(1 if regressions else 0)
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()