1
0
Fork 0
ca-tools/float/test.py
2025-07-17 20:14:55 +03:00

166 lines
5.1 KiB
Python

#!/usr/bin/env python3
import os
import shutil
import subprocess
import sys
from enum import Enum
from typing import Callable, cast
MAX_FAILED = 1
BINARY = './main'
TEST_DIR = './extras/tests'
test_to_args: Callable[[str], list[str]] = lambda test: test.split()
class FloatType(str, Enum):
    """Floating-point formats that can be tested.

    The lower-cased member name is used as a directory component by
    get_path(); on the command line the formats are selected with 'h'/'s'.
    """
    HALF = "half"
    SINGLE = "single"
class TestType(str, Enum):
    """Operations that can be tested, one test sub-directory per member.

    Command-line selectors (see main): p + - / * f m.
    """
    PRN = "prn"  # "Print"
    ADD = "add"  # "Addition"
    SUB = "sub"  # "Subtraction"
    DIV = "div"  # "Division"
    MUL = "mul"  # "Multiplication"
    FMA = "fma"  # "Fused Multiply-Add"
    MAD = "mad"  # "Multiply-Add" (listed separately from FMA)
class RoundType(str, Enum):
    """Rounding modes that can be tested, one .tsv file per mode.

    Command-line selectors (see main): '0'..'3' in declaration order.
    """
    TO_ZERO = "to_zero"
    TO_NEAREST_EVEN = "to_nearest_even"
    TO_POSITIVE_INFINITY = "to_positive_infinity"
    TO_NEGATIVE_INFINITY = "to_negative_infinity"
# Human-readable label for every enum member, used in the "Running tests"
# headers and the summary table. Assembled from one sub-table per enum,
# merged left to right so the key order is formats, operations, roundings.
PRETTY_NAMES: dict[TestType | FloatType | RoundType, str] = (
    {
        FloatType.HALF: "Half Precision",
        FloatType.SINGLE: "Single Precision",
    }
    | {
        TestType.ADD: "Addition",
        TestType.SUB: "Subtraction",
        TestType.DIV: "Division",
        TestType.MUL: "Multiplication",
        TestType.FMA: "Fused Multiply-Add",
        TestType.MAD: "Multiply-Add",
        TestType.PRN: "Print",
    }
    | {
        RoundType.TO_ZERO: "Round Toward Zero",
        RoundType.TO_NEAREST_EVEN: "Round to Nearest Even",
        RoundType.TO_POSITIVE_INFINITY: "Round Toward +∞",
        RoundType.TO_NEGATIVE_INFINITY: "Round Toward −∞",
    }
)
# Width to which status lines are padded, so a shorter line fully overwrites
# a longer one redrawn after a carriage return. shutil.get_terminal_size()
# honours $COLUMNS and falls back to 80 columns when stdout is not a terminal
# (pipe, CI log); os.get_terminal_size() raised OSError there, so the script
# could not even be imported.
FILL_SIZE = shutil.get_terminal_size().columns
# Failures across every test file of this run; incremented by run_tests().
total_bad = 0


def filled(x: str) -> str:
    """Right-pad *x* with spaces to FILL_SIZE visible columns.

    Carriage returns and newlines in *x* do not count toward the width.
    Text already at least FILL_SIZE columns wide is returned unchanged
    (multiplying a string by a negative count yields '').
    """
    # Computed outside any f-string: backslashes inside an f-string
    # replacement field are a SyntaxError before Python 3.12.
    visible = len(x.replace('\r', '').replace('\n', ''))
    return x + ' ' * (FILL_SIZE - visible)
def get_path(float_type: FloatType, test_type: TestType, round_type: RoundType) -> str:
    """Build the .tsv path holding the tests for one format/operation/rounding.

    Layout is ``TEST_DIR/<float>/<op>/<rounding>.tsv`` where each component is
    the lower-cased enum member name, e.g. ``./extras/tests/half/add/to_zero.tsv``.
    """
    float_dir, op_dir, round_stem = (
        member.name.lower() for member in (float_type, test_type, round_type)
    )
    return '/'.join((TEST_DIR, float_dir, op_dir, round_stem + '.tsv'))
def read_tests_from_file(path: str) -> tuple[tuple[str, str], ...]:
    """Load test cases from a tab-separated file.

    Every non-blank line becomes one tuple of its tab-separated fields
    (after stripping surrounding whitespace). The first field is the
    program input; the remaining field(s) form the expected output, which
    run_tests() joins with single spaces.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    with open(path, encoding='utf-8') as f:
        rows = (line.strip() for line in f)
        # Skip blank lines: they would otherwise become a bogus ('',) test
        # that runs the binary with no arguments and expects empty output.
        # The cast keeps the declared type used by run_tests(); rows may in
        # fact carry more than two fields.
        return tuple(
            cast(tuple[str, str], tuple(row.split('\t')))
            for row in rows
            if row
        )
def run_tests(executable: str, tests: tuple[tuple[str, str], ...],
              *, timeout: float = 0.1) -> tuple[int, int]:
    """Run each test case against *executable* and report mismatches.

    For every test, the first field is split into command-line arguments
    with test_to_args(). The remaining fields, joined with single spaces,
    are compared to the first line of the program's stdout (the whole
    output is stripped first). The exit status is not checked and stderr
    is captured and discarded. Failures are printed in detail, and a
    GOOD/BAD/TOTAL progress line is redrawn after every test.

    Args:
        executable: path of the program under test.
        tests: rows as returned by read_tests_from_file().
        timeout: seconds allowed per invocation; a slower run counts as
            output 'TIMEOUT'.

    Returns:
        (good, bad) counts for this batch.

    Side effects:
        Increments the module-level total_bad. Terminates the process with
        status 1 once total_bad reaches MAX_FAILED (when MAX_FAILED is set).
    """
    global total_bad
    good = bad = 0
    total = len(tests)
    for i, test in enumerate(tests):
        timed_out = False
        try:
            ps = subprocess.run([executable, *test_to_args(test[0])],
                                stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, timeout=timeout)
            # errors='replace': garbage bytes from a buggy binary are a
            # reportable mismatch, not a UnicodeDecodeError ending the run.
            out = ps.stdout.decode(errors='replace').strip().split('\n')[0]
        except subprocess.TimeoutExpired:
            # Tracked separately so a program that literally prints
            # "TIMEOUT" is not labelled as having timed out.
            timed_out = True
            out = 'TIMEOUT'
        expected = ' '.join(test[1:])
        if out == expected:
            good += 1
        else:
            bad += 1
            total_bad += 1
            print('\n'.join((
                filled(f'\rTEST #{i + 1}' + (' (TIMEOUT)' if timed_out else '')),
                filled(f'INPUT: {test[0]}'),
                filled(f'OUTPUT: {out}'),
                filled(f'EXPECTED: {expected}'),
            )), end='\n\n')
        print(filled(f'\rGOOD: {good}, BAD: {bad}, TOTAL: {total}'), end='')
        if MAX_FAILED and total_bad >= MAX_FAILED:
            print("\nToo many failed tests!")
            # sys.exit, not the site-provided exit() builtin, which is
            # missing under `python -S` and in frozen/embedded interpreters.
            sys.exit(1)
    # Return the cursor to column 0 so the next header overwrites the
    # progress line.
    print(end='\r')
    return good, bad
def parse_args[T](args: str, names: dict[str, T]) -> tuple[T, ...]:
return tuple(names[name] for name in names if name in args)
def _select_types(argv: list[str]) -> tuple[tuple[FloatType, ...],
                                            tuple[TestType, ...],
                                            tuple[RoundType, ...]]:
    """Choose the float formats, operations and rounding modes from argv.

    argv[1] selects formats (h s), argv[2] operations (p + - / * f m) and
    argv[3] rounding modes (0 1 2 3). Each argument is a string of selector
    characters; an omitted argument means "all members" of that kind.
    """
    float_types = tuple(FloatType)
    test_types = tuple(TestType)
    round_types = tuple(RoundType)
    if len(argv) > 1:
        float_types = parse_args(argv[1], {
            'h': FloatType.HALF,
            's': FloatType.SINGLE,
        })
    if len(argv) > 2:
        test_types = parse_args(argv[2], {
            'p': TestType.PRN,
            '+': TestType.ADD,
            '-': TestType.SUB,
            '/': TestType.DIV,
            '*': TestType.MUL,
            'f': TestType.FMA,
            'm': TestType.MAD,
        })
    if len(argv) > 3:
        round_types = parse_args(argv[3], {
            '0': RoundType.TO_ZERO,
            '1': RoundType.TO_NEAREST_EVEN,
            '2': RoundType.TO_POSITIVE_INFINITY,
            '3': RoundType.TO_NEGATIVE_INFINITY,
        })
    return float_types, test_types, round_types


def main():
    """Run every selected test file against BINARY and print a summary.

    Usage: test.py [FLOATS [OPS [ROUNDINGS]]] (see _select_types). Missing
    test files are reported and skipped. Exits with status 1 if any test
    failed; run_tests() may already have exited early when MAX_FAILED is set.
    """
    float_types, test_types, round_types = _select_types(sys.argv)
    summary: dict[str, tuple[int, int]] = {}
    for float_type in float_types:
        for test_type in test_types:
            for round_type in round_types:
                test_path = get_path(float_type, test_type, round_type)
                try:
                    tests = read_tests_from_file(test_path)
                except FileNotFoundError:
                    print(filled(f"\nSkipping missing test file: {test_path}"))
                    continue
                key = (f"{PRETTY_NAMES[float_type]} / {PRETTY_NAMES[test_type]}"
                       f" / {PRETTY_NAMES[round_type]}")
                print(f"Running tests: {key}")
                summary[key] = run_tests(BINARY, tests)

    print(filled(''))
    print(filled('==== TEST SUMMARY ===='))
    # Distinct names so these locals are not confused with the module-level
    # total_bad counter maintained by run_tests().
    sum_good = sum_bad = 0
    for category, (good, bad) in summary.items():
        print(f"{category:<70} GOOD: {good:>4} | BAD: {bad:>4}")
        sum_good += good
        sum_bad += bad
    print('\n==== TOTAL ====')
    print(f"GOOD: {sum_good:>4} | BAD: {sum_bad:>4}")
    # Non-zero status so CI notices failures even when MAX_FAILED is 0 and
    # run_tests() did not abort the run itself.
    if sum_bad:
        sys.exit(1)
# Run the suite only when executed as a script, not when imported.
if __name__ == '__main__':
    main()