package main import ( "bufio" "context" "fmt" "os" "os/exec" "strconv" "strings" "time" "unicode/utf8" ) const ( MAX_FAILED = 0 BINARY = "./softfloat/softfloat" TEST_DIR = "./extras/tests" BATCH_SIZE = 16 ) type FloatType int type RoundType int type TestType int const ( Half FloatType = iota Single ) const ( Print TestType = iota Add Subtract Multiply Divide FMA MAD ) const ( ToZero RoundType = iota ToNearestEven ToPositiveInfinity ToNegativeInfinity ) var ( floatTypeName = map[FloatType]string{ Half: "half", Single: "single", } floatTypePrettyName = map[FloatType]string{ Half: "Half Precision", Single: "Single Precision", } testTypeName = map[TestType]string{ Print: "prn", Add: "add", Subtract: "sub", Multiply: "mul", Divide: "div", FMA: "fma", MAD: "mad", } testTypePrettyName = map[TestType]string{ Print: "Print", Add: "Addition", Subtract: "Subtraction", Multiply: "Division", Divide: "Multiplication", FMA: "Fused Multiply-Add", MAD: "Multiply-Add", } roundTypeName = map[RoundType]string{ ToZero: "to_zero", ToNearestEven: "to_nearest_even", ToPositiveInfinity: "to_positive_infinity", ToNegativeInfinity: "to_negative_infinity", } roundTypePrettyName = map[RoundType]string{ ToZero: "Round Toward Zero", ToNearestEven: "Round to Nearest Even", ToPositiveInfinity: "Round Toward +∞", ToNegativeInfinity: "Round Toward −∞", } ) type Test struct { Index uint64 Input string Expected string } type TestResult struct { Output string Test } func (f FloatType) String() string { return floatTypeName[f] } func (f FloatType) PrettyName() string { return floatTypePrettyName[f] } func (t TestType) String() string { return testTypeName[t] } func (t TestType) PrettyName() string { return testTypePrettyName[t] } func (r RoundType) String() string { return roundTypeName[r] } func (r RoundType) PrettyName() string { return roundTypePrettyName[r] } func batched[T any](data []T, n int) func(func(int, []T) bool) { return func(yield func(int, []T) bool) { l := len(data) for i := 0; i < l/n+1; i++ { if !yield(i, data[i*n:min(i*n+n, l)]) { return } } } } func zip[T any, U any](first []T, second []U) func(func(int, T, U) bool) { return func(yield func(int, T, U) bool) { for i := 0; i < min(len(first), len(second)); i++ { if !yield(i, first[i], second[i]) { return } } } } var termWidth = func() int{ cmd := exec.Command("tput", "cols") cmd.Stdin = os.Stdin stdout, err := cmd.Output() if err != nil { return 80 } out, err := strconv.Atoi(strings.Replace(string(stdout), "\n", "", 1)) if err != nil { return 80 } return out }() func filled(v string) string { return v + strings.Repeat(" ", max(termWidth - utf8.RuneCountInString(v), 0)) } func filledn(v string, n int) string { return v + strings.Repeat(" ", max(n - utf8.RuneCountInString(v), 0)) } func getPath(floatType FloatType, testType TestType, roundType RoundType) string { return fmt.Sprintf("%v/%v/%v/%v.tsv", TEST_DIR, floatType.String(), testType.String(), roundType.String()) } func runTest(ctx context.Context, test Test) TestResult { ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) defer cancel() tmp_out, err := exec.CommandContext(ctx, BINARY, strings.Split(test.Input, " ")...).Output() var output string if err != nil { output = err.Error() } else if err := ctx.Err(); err != nil { output = fmt.Sprintf("context error: %v", err.Error()) } else { output = strings.Split(strings.TrimSpace(string(tmp_out)), "\n")[0] } return TestResult{ output, test, } } func runTests(ctx context.Context, tests []Test) (good int, bad int) { good, bad = 0, 0 for _, batch := range batched(tests, BATCH_SIZE) { testResults := make(chan TestResult) for _, v := range batch { go func(t Test) { testResults <- runTest(ctx, t) }(v) } for range len(batch) { result, ok := <-testResults if !ok { break } if result.Output != result.Expected { bad++ fmt.Print(filled(fmt.Sprintf("\r%v\nINPUT: %v\nOUTPUT: %v\nEXPECTED: %v\n\n", filled(fmt.Sprintf("TEST #%v", result.Index)), result.Input, result.Output, result.Expected, ))) } else { good++ } fmt.Print(filled(fmt.Sprintf("\rGOOD: %v, BAD: %v, TOTAL: %v", good, bad, len(tests)))) } } return } func readTests(path string) ([]Test, error) { file, err := os.Open(path) if err != nil { return nil, fmt.Errorf("error reading test file: %v", err) } defer file.Close() s := bufio.NewScanner(file) var tests []Test var idx uint64 = 0 for s.Scan() { parts := strings.Split(s.Text(), "\t") if len(parts) < 2 { continue } tests = append(tests, Test{ Index: idx, Input: parts[0], Expected: strings.Join(parts[1:], " "), }) idx++ } if err := s.Err(); err != nil { return nil, fmt.Errorf("error reading test file: %v", err) } return tests, nil } func parseArgs[T FloatType | TestType | RoundType](arg string, mapping map[rune]T) []T { var res []T for _, ch := range arg { if v, ok := mapping[ch]; ok { res = append(res, v) } } return res } func main() { totals := make(map[string][2]int) floatTypes := []FloatType{Half, Single} testTypes := []TestType{Print, Add, Subtract, Divide, Multiply, FMA, MAD} roundTypes := []RoundType{ToZero, ToNearestEven, ToPositiveInfinity, ToNegativeInfinity} if len(os.Args) > 1 { floatTypes = parseArgs(os.Args[1], map[rune]FloatType{ 'h': Half, 's': Single, }) } if len(os.Args) > 2 { testTypes = parseArgs(os.Args[2], map[rune]TestType{ 'p': Print, '+': Add, '-': Subtract, '/': Divide, '*': Multiply, 'f': FMA, 'm': MAD, }) } if len(os.Args) > 3 { roundTypes = parseArgs(os.Args[3], map[rune]RoundType{ '0': ToZero, '1': ToNearestEven, '2': ToPositiveInfinity, '3': ToNegativeInfinity, }) } ctx := context.Background() for _, floatType := range floatTypes { for _, testType := range testTypes { for _, roundType := range roundTypes { testPath := getPath(floatType, testType, roundType) tests, err := readTests(testPath) if err != nil { fmt.Printf("error reading tests from %v: %v", testPath, err.Error()) continue } name := fmt.Sprintf( "%v / %v / %v", floatType.PrettyName(), testType.PrettyName(), roundType.PrettyName(), ) fmt.Println("\rRunning tests:", name) good, bad := runTests(ctx, tests) totals[name] = [2]int{good, bad} } } } maxa, maxb, maxc := 0, 0, 0 total_good, total_bad := 0, 0 for k, v := range totals { maxa = max(maxa, utf8.RuneCountInString(k)) maxb = max(maxb, len(fmt.Sprint(v[0]))) maxc = max(maxc, len(fmt.Sprint(v[1]))) total_bad += v[1] total_good += v[0] } fmt.Println(filled("\r==== TEST RESULTS ====")) for k, v := range totals { fmt.Printf("%v | GOOD: %v | BAD: %v\n", filledn(k, maxa), filledn(fmt.Sprint(v[0]), maxb), filledn(fmt.Sprint(v[1]), maxb)) } fmt.Printf("\n\n==== TOTAL ====\nGOOD: %v\nBAD: %v\n", total_good, total_bad) }