#!/usr/bin/env python3 import os import subprocess import itertools import sys from enum import Enum from typing import Callable, Collection, Generator, Iterable, cast from concurrent.futures import Future, ThreadPoolExecutor, as_completed MAX_FAILED = 0 BINARY = './main' TEST_DIR = './extras/tests' test_to_args: Callable[[str], list[str]] = lambda test: test.split() class FloatType(str, Enum): HALF = "half" SINGLE = "single" class TestType(str, Enum): PRN = "prn" ADD = "add" SUB = "sub" DIV = "div" MUL = "mul" FMA = "fma" MAD = "mad" class RoundType(str, Enum): TO_ZERO = "to_zero" TO_NEAREST_EVEN = "to_nearest_even" TO_POSITIVE_INFINITY = "to_positive_infinity" TO_NEGATIVE_INFINITY = "to_negative_infinity" PRETTY_NAMES: dict[TestType | FloatType | RoundType, str] = { FloatType.HALF: "Half Precision", FloatType.SINGLE: "Single Precision", TestType.ADD: "Addition", TestType.SUB: "Subtraction", TestType.DIV: "Division", TestType.MUL: "Multiplication", TestType.FMA: "Fused Multiply-Add", TestType.MAD: "Multiply-Add", TestType.PRN: "Print", RoundType.TO_ZERO: "Round Toward Zero", RoundType.TO_NEAREST_EVEN: "Round to Nearest Even", RoundType.TO_POSITIVE_INFINITY: "Round Toward +∞", RoundType.TO_NEGATIVE_INFINITY: "Round Toward −∞", } FILL_SIZE = os.get_terminal_size().columns CONCURRENT_WORKERS = 32 total_bad = 0 filled = lambda x: f"{x}{' ' * (FILL_SIZE - len(str(x).replace('\r', '').replace('\n', '')))}" filledn = lambda x, n: f"{x}{' ' * (n - len(str(x).replace('\r', '').replace('\n', '')))}" def get_path(float_type: FloatType, test_type: TestType, round_type: RoundType) -> str: return f'{TEST_DIR}/{float_type.name.lower()}/{test_type.name.lower()}/{round_type.name.lower()}.tsv' def read_tests_from_file(path: str): with open(path) as f: return tuple(cast(tuple[str, str], tuple(line.strip().split('\t'))) for line in f) def run_single_test(i: int, test: tuple[str, str]) -> tuple[int, str, str, str]: try: ps = subprocess.run([BINARY, *test_to_args(test[0])], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=0.5) out = ps.stdout.decode().strip().split('\n')[0] except subprocess.TimeoutExpired: out = 'TIMEOUT' expected = ' '.join(test[1:]) return (i, test[0], expected, out) def batch_submit[T]( executor: ThreadPoolExecutor, fn: Callable[..., T], iterable: Iterable[Collection] ) -> Generator[Future[T]]: for batch in itertools.batched(iterable, CONCURRENT_WORKERS): for future in as_completed( executor.submit(fn, *it) for it in batch ): yield future def run_tests(tests: tuple[tuple[str, str], ...]) -> tuple[int, int]: global total_bad good = bad = 0 with ThreadPoolExecutor(max_workers=CONCURRENT_WORKERS) as executor: for future in batch_submit(executor, run_single_test, enumerate(tests)): i, input_data, expected, out = future.result() if out != expected: total_bad += 1 bad += 1 print('\n'.join(( filled(f'\rTEST #{i + 1}' + (' (TIMEOUT)' if out == 'TIMEOUT' else '')), filled(f'INPUT: {input_data}'), filled(f'OUTPUT: {out}'), filled(f'EXPECTED: {expected}'), )), end='\n\n') else: good += 1 print(filled(f'\rGOOD: {good}, BAD: {bad}, TOTAL: {len(tests)}'), end='') if MAX_FAILED and total_bad >= MAX_FAILED: print("\nToo many failed tests!") os._exit(1) # immediate exit from all threads print(end='\r') return good, bad def parse_args[T](args: str, names: dict[str, T]) -> tuple[T, ...]: return tuple(names[name] for name in names if name in args) def main(): summary: dict[str, tuple[int, int]] = {} float_types = tuple(iter(FloatType)) test_types = tuple(iter(TestType)) round_types = tuple(iter(RoundType)) for _ in ((),): if len(sys.argv) < 2: break float_types = parse_args(sys.argv[1], { 'h': FloatType.HALF, 's': FloatType.SINGLE }) if len(sys.argv) < 3: break test_types = parse_args(sys.argv[2], { 'p': TestType.PRN, '+': TestType.ADD, '-': TestType.SUB, '/': TestType.DIV, '*': TestType.MUL, 'f': TestType.FMA, 'm': TestType.MAD }) if len(sys.argv) < 4: break round_types = parse_args(sys.argv[3], { '0': RoundType.TO_ZERO, '1': RoundType.TO_NEAREST_EVEN, '2': RoundType.TO_POSITIVE_INFINITY, '3': RoundType.TO_NEGATIVE_INFINITY }) for float_type in float_types: for test_type in test_types: for round_type in round_types: test_path = get_path(float_type, test_type, round_type) try: tests = read_tests_from_file(test_path) except FileNotFoundError: print(filled(f"\nSkipping missing test file: {test_path}")) continue print(f"Running tests: {PRETTY_NAMES[float_type]} / {PRETTY_NAMES[test_type]} / {PRETTY_NAMES[round_type]}") good, bad = run_tests(tests) key = f"{PRETTY_NAMES[float_type]} / {PRETTY_NAMES[test_type]} / {PRETTY_NAMES[round_type]}" summary[key] = (good, bad) print(filled('')) print(filled('==== TEST SUMMARY ====')) total_good = total_bad = 0 maxa = max((len(x[0]) for x in summary.items())) maxb = max((len(str(x[1][0])) for x in summary.items())) maxc = max((len(str(x[1][1])) for x in summary.items())) for category, (good, bad) in summary.items(): print(f"{filledn(category, maxa)} | GOOD: {filledn(good, maxb)} | BAD: {filledn(bad, maxc)}") total_good += good total_bad += bad print('\n==== TOTAL ====') print(f"GOOD: {filledn(total_good, maxb)} | BAD: {filledn(total_bad, maxc)}") if __name__ == '__main__': try: main() except KeyboardInterrupt: print() exit(0)