195 lines
6.3 KiB
Python
195 lines
6.3 KiB
Python
#!/usr/bin/env python3
|
|
import os
|
|
import subprocess
|
|
import itertools
|
|
import sys
|
|
from enum import Enum
|
|
from typing import Callable, Collection, Generator, Iterable, cast
|
|
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
|
|
|
MAX_FAILED = 0
|
|
BINARY = './main'
|
|
TEST_DIR = './extras/tests'
|
|
test_to_args: Callable[[str], list[str]] = lambda test: test.split()
|
|
|
|
class FloatType(str, Enum):
    """Floating-point formats the binary is exercised with.

    The lowercased member name is the first directory level under TEST_DIR
    (see get_path); on the command line 'h' and 's' select the members.
    """

    HALF = 'half'      # "Half Precision" in PRETTY_NAMES
    SINGLE = 'single'  # "Single Precision" in PRETTY_NAMES
|
|
|
|
class TestType(str, Enum):
    """Operations the binary is asked to perform.

    The lowercased member name is the second directory level under TEST_DIR
    (see get_path); main() selects members from the characters
    'p', '+', '-', '/', '*', 'f' and 'm' of the second argument.
    """

    PRN = 'prn'  # "Print"
    ADD = 'add'  # "Addition"
    SUB = 'sub'  # "Subtraction"
    DIV = 'div'  # "Division"
    MUL = 'mul'  # "Multiplication"
    FMA = 'fma'  # "Fused Multiply-Add"
    MAD = 'mad'  # "Multiply-Add"
|
|
|
|
class RoundType(str, Enum):
    """Rounding modes each operation is tested under.

    The lowercased member name is the .tsv file stem under TEST_DIR
    (see get_path); main() maps the digits '0'-'3' of the third argument
    to these members in declaration order.
    """

    TO_ZERO = 'to_zero'
    TO_NEAREST_EVEN = 'to_nearest_even'
    TO_POSITIVE_INFINITY = 'to_positive_infinity'
    TO_NEGATIVE_INFINITY = 'to_negative_infinity'
|
|
|
|
# Human-readable label for every enum member, used when announcing a group of
# tests and as the row labels of the final summary.
PRETTY_NAMES: dict[TestType | FloatType | RoundType, str] = (
    # Floating-point formats.
    {
        FloatType.HALF: 'Half Precision',
        FloatType.SINGLE: 'Single Precision',
    }
    # Operations under test.
    | {
        TestType.ADD: 'Addition',
        TestType.SUB: 'Subtraction',
        TestType.DIV: 'Division',
        TestType.MUL: 'Multiplication',
        TestType.FMA: 'Fused Multiply-Add',
        TestType.MAD: 'Multiply-Add',
        TestType.PRN: 'Print',
    }
    # Rounding modes.
    | {
        RoundType.TO_ZERO: 'Round Toward Zero',
        RoundType.TO_NEAREST_EVEN: 'Round to Nearest Even',
        RoundType.TO_POSITIVE_INFINITY: 'Round Toward +∞',
        RoundType.TO_NEGATIVE_INFINITY: 'Round Toward −∞',
    }
)
|
|
|
|
|
|
# Width every progress/report line is padded to, so a shorter line printed
# after '\r' fully overwrites a longer one. os.get_terminal_size() raises
# OSError when stdout is not a terminal (piped to a file, CI logs, captured
# output); fall back to the conventional 80 columns instead of crashing.
try:
    FILL_SIZE = os.get_terminal_size().columns
except OSError:
    FILL_SIZE = 80

# Number of test subprocesses kept in flight at once.
CONCURRENT_WORKERS = 32

# Failures counted across every test file of this run; compared to MAX_FAILED.
total_bad = 0


def filledn(x, n: int) -> str:
    """Return ``x`` formatted as text and right-padded with spaces to ``n`` visible characters.

    '\\r' and '\\n' are not counted toward the width, so a leading carriage
    return (used to redraw a line in place) does not shrink the padding.
    Text already ``n`` or more characters wide is returned unpadded.
    """
    text = f"{x}"
    visible = len(text.replace('\r', '').replace('\n', ''))
    return text + ' ' * (n - visible)


def filled(x) -> str:
    """Pad ``x`` to the full terminal width (FILL_SIZE); see ``filledn``."""
    return filledn(x, FILL_SIZE)
|
|
|
|
|
|
def get_path(float_type: FloatType, test_type: TestType, round_type: RoundType) -> str:
    """Return the .tsv file holding the cases for one format/operation/rounding combination.

    Layout: TEST_DIR/<float>/<operation>/<rounding>.tsv, where each component
    is the lowercased enum member name.
    """
    components = (member.name.lower() for member in (float_type, test_type, round_type))
    return '/'.join((TEST_DIR, *components)) + '.tsv'
|
|
|
|
|
|
def read_tests_from_file(path: str):
    """Load the test cases stored in a tab-separated file.

    Every non-blank line becomes a tuple of its tab-separated fields, after
    surrounding whitespace is stripped. run_single_test uses field 0 as the
    argument string and joins the remaining fields into the expected output.

    Raises:
        FileNotFoundError: if ``path`` does not exist (main() skips such files).
    """
    with open(path, encoding='utf-8') as f:
        # Skip blank lines: an empty line (e.g. an extra one at end of file)
        # would otherwise turn into a bogus ('',) test that runs the binary
        # with no arguments and is counted as a failure.
        return tuple(
            cast(tuple[str, str], tuple(stripped.split('\t')))
            for line in f
            if (stripped := line.strip())
        )
|
|
|
|
|
|
def run_single_test(i: int, test: tuple[str, str], timeout: float = 0.5) -> tuple[int, str, str, str]:
    """Run BINARY on one test case and capture the first line it prints.

    Args:
        i: index of the test within its file; passed through unchanged so that
           results arriving out of order can still be reported by number.
        test: one row of a test file; ``test[0]`` is the argument string for
              the binary and the remaining fields form the expected output.
        timeout: seconds to wait before the run is reported as 'TIMEOUT'.

    Returns:
        ``(i, argument string, expected output, actual output)``, where the
        actual output is the first line of stdout (stripped), or 'TIMEOUT'.
    """
    try:
        ps = subprocess.run(
            [BINARY, *test_to_args(test[0])],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
        # errors='replace': a binary that emits invalid UTF-8 should produce a
        # failed comparison, not an exception that aborts the whole run.
        out = ps.stdout.decode(errors='replace').strip().split('\n')[0]
    except subprocess.TimeoutExpired:
        out = 'TIMEOUT'

    expected = ' '.join(test[1:])
    return (i, test[0], expected, out)
|
|
|
|
|
|
def batch_submit[T](
|
|
executor: ThreadPoolExecutor,
|
|
fn: Callable[..., T],
|
|
iterable: Iterable[Collection]
|
|
) -> Generator[Future[T]]:
|
|
for batch in itertools.batched(iterable, CONCURRENT_WORKERS):
|
|
for future in as_completed(
|
|
executor.submit(fn, *it)
|
|
for it in batch
|
|
):
|
|
yield future
|
|
|
|
def run_tests(tests: tuple[tuple[str, str], ...]) -> tuple[int, int]:
    """Run every case in ``tests`` against BINARY, reporting failures and live progress.

    Each mismatch prints the test number, input, actual and expected output.
    After every result a single progress line is redrawn in place.

    Side effects:
        Increments the module-level ``total_bad`` for each failure. When
        MAX_FAILED is non-zero and ``total_bad`` reaches it, the process is
        terminated immediately with ``os._exit(1)``.

    Returns:
        ``(good, bad)``: passed and failed counts for this set of tests.
    """
    global total_bad
    good = bad = 0

    with ThreadPoolExecutor(max_workers=CONCURRENT_WORKERS) as executor:
        for future in batch_submit(executor, run_single_test, enumerate(tests)):
            i, input_data, expected, out = future.result()

            if out == expected:
                good += 1
            else:
                total_bad += 1
                bad += 1
                timeout_note = ' (TIMEOUT)' if out == 'TIMEOUT' else ''
                print('\n'.join((
                    filled(f'\rTEST #{i + 1}{timeout_note}'),
                    filled(f'INPUT: {input_data}'),
                    filled(f'OUTPUT: {out}'),
                    filled(f'EXPECTED: {expected}'),
                )), end='\n\n')

            # The progress line has no trailing newline, so without flush=True a
            # line-buffered stdout would not show it until the next failure.
            print(filled(f'\rGOOD: {good}, BAD: {bad}, TOTAL: {len(tests)}'), end='', flush=True)

            if MAX_FAILED and total_bad >= MAX_FAILED:
                # os._exit skips interpreter cleanup, including flushing stdio
                # buffers, so flush explicitly or the message may be lost.
                print("\nToo many failed tests!", flush=True)
                os._exit(1)  # immediate exit from all threads

    print(end='\r')
    return good, bad
|
|
|
|
|
|
def parse_args[T](args: str, names: dict[str, T]) -> tuple[T, ...]:
|
|
return tuple(names[name] for name in names if name in args)
|
|
|
|
|
|
def _select_from_argv():
    """Return the (float types, test types, round types) chosen by sys.argv[1:4].

    Each positional argument is a string of selector characters; an omitted
    argument selects every member of that category.
    """
    selectors = (
        (FloatType, {
            'h': FloatType.HALF,
            's': FloatType.SINGLE,
        }),
        (TestType, {
            'p': TestType.PRN,
            '+': TestType.ADD,
            '-': TestType.SUB,
            '/': TestType.DIV,
            '*': TestType.MUL,
            'f': TestType.FMA,
            'm': TestType.MAD,
        }),
        (RoundType, {
            '0': RoundType.TO_ZERO,
            '1': RoundType.TO_NEAREST_EVEN,
            '2': RoundType.TO_POSITIVE_INFINITY,
            '3': RoundType.TO_NEGATIVE_INFINITY,
        }),
    )

    chosen = []
    for position, (enum_cls, flags) in enumerate(selectors, start=1):
        if len(sys.argv) > position:
            chosen.append(parse_args(sys.argv[position], flags))
        else:
            chosen.append(tuple(enum_cls))
    return tuple(chosen)


def _print_summary(summary: dict[str, tuple[int, int]]) -> None:
    """Print one aligned GOOD/BAD row per test file, followed by overall totals."""
    print(filled(''))
    print(filled('==== TEST SUMMARY ===='))

    # default=0: an empty summary (every file missing, or selectors that matched
    # nothing) must not make max() raise ValueError.
    width_name = max((len(name) for name in summary), default=0)
    width_good = max((len(str(good)) for good, _ in summary.values()), default=0)
    width_bad = max((len(str(bad)) for _, bad in summary.values()), default=0)

    overall_good = overall_bad = 0
    for category, (good, bad) in summary.items():
        print(f"{filledn(category, width_name)} | GOOD: {filledn(good, width_good)} | BAD: {filledn(bad, width_bad)}")
        overall_good += good
        overall_bad += bad

    print('\n==== TOTAL ====')
    print(f"GOOD: {filledn(overall_good, width_good)} | BAD: {filledn(overall_bad, width_bad)}")


def main():
    """Run the selected test files and print a per-file and overall summary.

    Usage: script [FLOATS [OPS [ROUNDINGS]]]
        FLOATS     any of 'h', 's'
        OPS        any of 'p', '+', '-', '/', '*', 'f', 'm'
        ROUNDINGS  any of '0', '1', '2', '3'
    An omitted argument selects everything in that category. Missing test
    files are reported and skipped.
    """
    summary: dict[str, tuple[int, int]] = {}
    float_types, test_types, round_types = _select_from_argv()

    # Same nesting order as three nested loops: float outermost, rounding innermost.
    for float_type, test_type, round_type in itertools.product(float_types, test_types, round_types):
        test_path = get_path(float_type, test_type, round_type)

        try:
            tests = read_tests_from_file(test_path)
        except FileNotFoundError:
            print(filled(f"\nSkipping missing test file: {test_path}"))
            continue

        label = f"{PRETTY_NAMES[float_type]} / {PRETTY_NAMES[test_type]} / {PRETTY_NAMES[round_type]}"
        print(f"Running tests: {label}")
        summary[label] = run_tests(tests)

    _print_summary(summary)
|
|
|
|
|
|
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Move past the partially drawn progress line before leaving.
        print()
        # sys.exit, not the site module's interactive `exit` helper, which is
        # absent under `python -S` and in frozen/embedded interpreters.
        sys.exit(0)
|