diff --git a/float/main.go b/float/main.go new file mode 100644 index 0000000..4a111e7 --- /dev/null +++ b/float/main.go @@ -0,0 +1,351 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +const ( + MAX_FAILED = 0 + BINARY = "./softfloat/softfloat" + TEST_DIR = "./extras/tests" + BATCH_SIZE = 16 +) + +type FloatType int +type RoundType int +type TestType int + +const ( + Half FloatType = iota + Single +) + +const ( + Print TestType = iota + Add + Subtract + Multiply + Divide + FMA + MAD +) + +const ( + ToZero RoundType = iota + ToNearestEven + ToPositiveInfinity + ToNegativeInfinity +) + +var ( + floatTypeName = map[FloatType]string{ + Half: "half", + Single: "single", + } + + floatTypePrettyName = map[FloatType]string{ + Half: "Half Precision", + Single: "Single Precision", + } + + testTypeName = map[TestType]string{ + Print: "prn", + Add: "add", + Subtract: "sub", + Multiply: "mul", + Divide: "div", + FMA: "fma", + MAD: "mad", + } + + testTypePrettyName = map[TestType]string{ + Print: "Print", + Add: "Addition", + Subtract: "Subtraction", + Multiply: "Division", + Divide: "Multiplication", + FMA: "Fused Multiply-Add", + MAD: "Multiply-Add", + } + + roundTypeName = map[RoundType]string{ + ToZero: "to_zero", + ToNearestEven: "to_nearest_even", + ToPositiveInfinity: "to_positive_infinity", + ToNegativeInfinity: "to_negative_infinity", + } + + roundTypePrettyName = map[RoundType]string{ + ToZero: "Round Toward Zero", + ToNearestEven: "Round to Nearest Even", + ToPositiveInfinity: "Round Toward +∞", + ToNegativeInfinity: "Round Toward −∞", + } +) + +type Test struct { + Index uint64 + Input string + Expected string +} + +type TestResult struct { + Output string + Test +} + +func (f FloatType) String() string { + return floatTypeName[f] +} + +func (f FloatType) PrettyName() string { + return floatTypePrettyName[f] +} + +func (t TestType) String() string { + return testTypeName[t] +} + +func (t TestType) PrettyName() string { + return testTypePrettyName[t] +} + +func (r RoundType) String() string { + return roundTypeName[r] +} + +func (r RoundType) PrettyName() string { + return roundTypePrettyName[r] +} + +func batched[T any](data []T, n int) func(func(int, []T) bool) { + return func(yield func(int, []T) bool) { + l := len(data) + for i := 0; i < l/n+1; i++ { + if !yield(i, data[i*n:min(i*n+n, l)]) { + return + } + } + } +} + +func zip[T any, U any](first []T, second []U) func(func(int, T, U) bool) { + return func(yield func(int, T, U) bool) { + for i := 0; i < min(len(first), len(second)); i++ { + if !yield(i, first[i], second[i]) { + return + } + } + } +} + +var termWidth = func() int{ + cmd := exec.Command("tput", "cols") + cmd.Stdin = os.Stdin + stdout, err := cmd.Output() + if err != nil { + return 80 + } + out, err := strconv.Atoi(strings.Replace(string(stdout), "\n", "", 1)) + if err != nil { + return 80 + } + return out +}() + +func filled(v string) string { + return v + strings.Repeat(" ", max(termWidth - utf8.RuneCountInString(v), 0)) +} + +func filledn(v string, n int) string { + return v + strings.Repeat(" ", max(n - utf8.RuneCountInString(v), 0)) +} + +func getPath(floatType FloatType, testType TestType, roundType RoundType) string { + return fmt.Sprintf("%v/%v/%v/%v.tsv", TEST_DIR, floatType.String(), testType.String(), roundType.String()) +} + +func runTest(ctx context.Context, test Test) TestResult { + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + tmp_out, err := exec.CommandContext(ctx, BINARY, strings.Split(test.Input, " ")...).Output() + + var output string + if err != nil { + output = err.Error() + } else if err := ctx.Err(); err != nil { + output = fmt.Sprintf("context error: %v", err.Error()) + } else { + output = strings.Split(strings.TrimSpace(string(tmp_out)), "\n")[0] + } + + return TestResult{ + output, + test, + } +} + +func runTests(ctx context.Context, tests []Test) (good int, bad int) { + good, bad = 0, 0 + + for _, batch := range batched(tests, BATCH_SIZE) { + testResults := make(chan TestResult) + + for _, v := range batch { + go func(t Test) { + testResults <- runTest(ctx, t) + }(v) + } + + for range len(batch) { + result, ok := <-testResults + if !ok { + break + } + if result.Output != result.Expected { + bad++ + fmt.Print(filled(fmt.Sprintf("\r%v\nINPUT: %v\nOUTPUT: %v\nEXPECTED: %v\n\n", + filled(fmt.Sprintf("TEST #%v", result.Index)), + result.Input, + result.Output, + result.Expected, + ))) + } else { + good++ + } + + fmt.Print(filled(fmt.Sprintf("\rGOOD: %v, BAD: %v, TOTAL: %v", good, bad, len(tests)))) + } + } + + return +} + +func readTests(path string) ([]Test, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("error reading test file: %v", err) + } + defer file.Close() + + s := bufio.NewScanner(file) + var tests []Test + var idx uint64 = 0 + for s.Scan() { + parts := strings.Split(s.Text(), "\t") + if len(parts) < 2 { + continue + } + + tests = append(tests, Test{ + Index: idx, + Input: parts[0], + Expected: strings.Join(parts[1:], " "), + }) + + idx++ + } + + if err := s.Err(); err != nil { + return nil, fmt.Errorf("error reading test file: %v", err) + } + + return tests, nil +} + +func parseArgs[T FloatType | TestType | RoundType](arg string, mapping map[rune]T) []T { + var res []T + for _, ch := range arg { + if v, ok := mapping[ch]; ok { + res = append(res, v) + } + } + return res +} + +func main() { + totals := make(map[string][2]int) + + floatTypes := []FloatType{Half, Single} + testTypes := []TestType{Print, Add, Subtract, Divide, Multiply, FMA, MAD} + roundTypes := []RoundType{ToZero, ToNearestEven, ToPositiveInfinity, ToNegativeInfinity} + + if len(os.Args) > 1 { + floatTypes = parseArgs(os.Args[1], map[rune]FloatType{ + 'h': Half, + 's': Single, + }) + } + + if len(os.Args) > 2 { + testTypes = parseArgs(os.Args[2], map[rune]TestType{ + 'p': Print, + '+': Add, + '-': Subtract, + '/': Divide, + '*': Multiply, + 'f': FMA, + 'm': MAD, + }) + } + + if len(os.Args) > 3 { + roundTypes = parseArgs(os.Args[3], map[rune]RoundType{ + '0': ToZero, + '1': ToNearestEven, + '2': ToPositiveInfinity, + '3': ToNegativeInfinity, + }) + } + + ctx := context.Background() + + for _, floatType := range floatTypes { + for _, testType := range testTypes { + for _, roundType := range roundTypes { + testPath := getPath(floatType, testType, roundType) + tests, err := readTests(testPath) + if err != nil { + fmt.Printf("error reading tests from %v: %v", testPath, err.Error()) + continue + } + name := fmt.Sprintf( + "%v / %v / %v", + floatType.PrettyName(), + testType.PrettyName(), + roundType.PrettyName(), + ) + fmt.Println("\rRunning tests:", name) + good, bad := runTests(ctx, tests) + totals[name] = [2]int{good, bad} + } + } + } + + maxa, maxb, maxc := 0, 0, 0 + total_good, total_bad := 0, 0 + + for k, v := range totals { + maxa = max(maxa, utf8.RuneCountInString(k)) + maxb = max(maxb, len(fmt.Sprint(v[0]))) + maxc = max(maxc, len(fmt.Sprint(v[1]))) + total_bad += v[1] + total_good += v[0] + } + fmt.Println(filled("\r==== TEST RESULTS ====")) + + for k, v := range totals { + fmt.Printf("%v | GOOD: %v | BAD: %v\n", filledn(k, maxa), filledn(fmt.Sprint(v[0]), maxb), filledn(fmt.Sprint(v[1]), maxb)) + } + + fmt.Printf("\n\n==== TOTAL ====\nGOOD: %v\nBAD: %v\n", total_good, total_bad) +}