1
0
Fork 0
hivemind/internal/codegen/roleconfig/main.go

196 lines
4.1 KiB
Go

package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/token"
"go/types"
"io"
"os"
"reflect"
"strings"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)
// Code templates expanded with fmt.Fprintf once per generated field. Every
// template receives the same operands and selects them by explicit index:
//
//	%[1]s  Go field name (accessed as s.<name> / other.<name>)
//	%[2]s  zero-value literal for the field's type
//	%[3]s  human-readable key used in error messages
//
// Percent signs meant for the generated code are escaped as %%.
const (
	// defaultMerge copies a field from other into s whenever other carries a
	// non-zero value, so values set in other override those in s.
	// NOTE(review): the original tested s.<field> instead of other.<field>,
	// which overwrote already-set fields of s (possibly with zero) and never
	// filled unset ones; confirm "other wins" is the intended semantics.
	defaultMerge = `
if other.%[1]s != %[2]s {
s.%[1]s = other.%[1]s
}
`
	// defaultValidator rejects a field still holding its zero value.
	defaultValidator = `
if s.%[1]s == %[2]s {
return errors.New("missing %[3]s")
}
`
	// ipValidator rejects a string field that net.ParseIP cannot parse.
	ipValidator = `
if net.ParseIP(s.%[1]s) == nil {
return fmt.Errorf("invalid %[3]s: %%q", s.%[1]s)
}
`
	// positiveValidator rejects values below 1. The value is formatted with
	// %v: the field is numeric, and %q would print it as a quoted rune
	// literal (e.g. '\x00') instead of the number.
	positiveValidator = `
if s.%[1]s < 1 {
return fmt.Errorf("invalid %[3]s: %%v", s.%[1]s)
}
`
)
// Generator settings, assigned by init from the -name flag and from the
// GOPACKAGE / GOFILE environment variables.

// structName is the struct whose Validate and Merge methods are generated.
var structName string

// packageName is written into the package clause of the generated file.
var packageName string

// fileName is the path of the generated output file.
var fileName string
// Validator wraps a snippet of generated validation code.
//
// NOTE(review): nothing in this file constructs a Validator; field validators
// are carried as plain template strings in Field.validators instead.
type Validator struct{ code string }
// Field describes one struct member the generator emits code for.
type Field struct {
	// name is the Go identifier, prettyName the key shown in generated error
	// messages, and defaultValue the Go literal compared against to detect an
	// unset value.
	name, prettyName, defaultValue string
	// validators holds fmt templates expanded into the Validate method.
	validators []string
}
// NewField returns a Field whose validator templates are chosen by tag, the
// value of the field's `gen` struct tag:
//
//	""         non-zero check
//	"ip"       non-zero check plus net.ParseIP check
//	"positive" value must be at least 1
//	"no"       no validation
//
// Any other tag value is a programming error in the input struct and panics.
func NewField(name, prettyName, defaultValue, tag string) *Field {
	f := &Field{
		name:         name,
		prettyName:   prettyName,
		defaultValue: defaultValue,
		validators:   make([]string, 0, 2),
	}
	switch tag {
	case "no":
		// Explicit opt-out: leave the validator list empty.
	case "":
		f.validators = append(f.validators, defaultValidator)
	case "ip":
		f.validators = append(f.validators, defaultValidator, ipValidator)
	case "positive":
		f.validators = append(f.validators, positiveValidator)
	default:
		panic(fmt.Sprintf("invalid tag %v", tag))
	}
	return f
}
// init parses the command line and reads the go:generate environment.
//
// -name selects the struct to generate methods for. GOPACKAGE supplies the
// package clause of the output, and GOFILE (with ".go" trimmed and "_gen.go"
// appended) supplies its file name. Missing values are reported on stderr
// and the process exits with status 1. Previously an unset GOPACKAGE/GOFILE
// silently produced a file named "_gen.go" with an empty package clause.
func init() {
	namePtr := flag.String("name", "", "name of the struct")
	flag.Parse()
	// flag.String never returns nil; only an empty value needs rejecting.
	if *namePtr == "" {
		fmt.Fprintln(os.Stderr, "Invalid name")
		os.Exit(1)
	}
	structName = *namePtr

	packageName = os.Getenv("GOPACKAGE")
	goFile := os.Getenv("GOFILE")
	if packageName == "" || goFile == "" {
		fmt.Fprintln(os.Stderr, "GOPACKAGE and GOFILE must be set; run via go generate")
		os.Exit(1)
	}
	fileName = fmt.Sprintf("%s_gen.go", strings.TrimSuffix(goFile, ".go"))
}
// defaultValue returns the Go source literal that a value of type t equals
// when it is unset. The literal is compared against in generated Merge and
// Validate code.
//
// The decision is made on t's underlying type:
//   - any numeric basic type (integers incl. uintptr, floats, complex) → "0"
//   - string → `""`
//   - bool → "false"
//   - everything else (pointers, slices, maps, interfaces, unsafe.Pointer…) → "nil"
//
// Floats, complex numbers and uintptr previously fell through to "nil", which
// yields generated code that does not compile (e.g. `s.X != nil` on float64).
func defaultValue(t types.Type) string {
	b, ok := t.Underlying().(*types.Basic)
	if !ok {
		return "nil"
	}
	info := b.Info()
	switch {
	case info&types.IsNumeric != 0:
		return "0"
	case info&types.IsString != 0:
		return `""`
	case info&types.IsBoolean != 0:
		return "false"
	default:
		return "nil"
	}
}
// Package pairs a list of fields with a string-keyed set (map to empty
// struct), presumably of import paths for the generated file.
//
// NOTE(review): nothing in this file constructs or reads a Package; confirm
// whether it is still needed.
type Package struct {
	fields []*Field

	// imports is used as a set; only the keys carry information.
	imports map[string]struct{}
}
// enumerateFields flattens root into the list of Fields to generate code for.
//
// A member whose underlying type is itself a struct is not emitted. Its own
// fields are collected recursively instead, as if they belonged to root.
// NOTE(review): the generated `s.<inner>` accesses only resolve when such a
// struct is embedded (promoted fields); a named, non-embedded struct member
// would produce code that does not compile.
//
// For every other member:
//   - the `gen` tag selects validators (see NewField);
//   - the `toml` key (options such as ",omitempty" stripped) becomes the
//     human-readable name used in error messages;
//   - the Go field name is used when there is no toml key.
func enumerateFields(root *types.Struct) []*Field {
	fields := make([]*Field, 0, root.NumFields())
	for i := range root.NumFields() {
		field := root.Field(i)
		if embed, ok := field.Type().Underlying().(*types.Struct); ok {
			fields = append(fields, enumerateFields(embed)...)
			continue
		}

		tag := reflect.StructTag(root.Tag(i))
		// `toml:"port,omitempty"` must yield "port", not "port,omitempty".
		prettyName, _, _ := strings.Cut(tag.Get("toml"), ",")
		if prettyName == "" {
			prettyName = field.Name()
		}
		fields = append(fields, NewField(field.Name(), prettyName, defaultValue(field.Type()), tag.Get("gen")))
	}
	return fields
}
// writeMerge emits `func (s *<name>) Merge(other <name>)` to w. The method
// body is one defaultMerge snippet per field, in field order. Write errors
// from w are not checked.
func writeMerge(w io.Writer, name string, fields []*Field) {
	fmt.Fprintf(w, "func (s *%[1]s) Merge(other %[1]s) {", name)
	for i := range fields {
		fmt.Fprintf(w, defaultMerge, fields[i].name, fields[i].defaultValue)
	}
	io.WriteString(w, "}\n\n")
}
// writeValidators emits `func (s <name>) Validate() error` to w. Each field's
// validator templates are expanded in order, and the method ends in
// `return nil`. Write errors from w are not checked.
func writeValidators(w io.Writer, name string, fields []*Field) {
	fmt.Fprintf(w, "func (s %s) Validate() error {", name)
	for _, field := range fields {
		for _, tmpl := range field.validators {
			// Every template gets the same three operands and picks the ones
			// it needs via explicit argument indexes.
			fmt.Fprintf(w, tmpl, field.name, field.defaultValue, field.prettyName)
		}
	}
	io.WriteString(w, "\nreturn nil\n}\n\n")
}
// main loads the package in the current directory and finds the top-level
// struct named by -name. It then writes that struct's Validate and Merge
// methods, formatted and with imports fixed up, to fileName.
//
// Every failure is reported on stderr and exits with status 1: package load,
// struct not found, generated code that cannot be processed, or write error.
// Previously these were ignored, and an imports.Process failure even wrote an
// empty file.
func main() {
	fail := func(format string, args ...any) {
		fmt.Fprintf(os.Stderr, "roleconfig: %s\n", fmt.Sprintf(format, args...))
		os.Exit(1)
	}

	cfg := &packages.Config{
		Mode: packages.NeedName |
			packages.NeedTypes |
			packages.NeedTypesInfo |
			packages.NeedSyntax |
			packages.NeedFiles,
	}
	pkgs, err := packages.Load(cfg, ".")
	if err != nil {
		fail("loading package: %v", err)
	}
	if len(pkgs) == 0 {
		fail("no package found in the current directory")
	}
	// pkgs[0].Errors is deliberately not treated as fatal. Presumably the
	// package may not type-check until the methods generated here exist.
	pkg := pkgs[0]

	st := findStructType(pkg, structName)
	if st == nil {
		fail("struct %q not found in package %s", structName, pkg.PkgPath)
	}

	fields := enumerateFields(st)
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "// Code generated by roleconfig; DO NOT EDIT.\n\npackage %s\n", packageName)
	writeValidators(&buf, structName, fields)
	writeMerge(&buf, structName, fields)

	// imports.Process formats the source and adjusts its imports, so the
	// templates can reference errors/fmt/net without an import block.
	src, err := imports.Process(fileName, buf.Bytes(), nil)
	if err != nil {
		fail("formatting generated code: %v", err)
	}
	if err := os.WriteFile(fileName, src, 0o644); err != nil {
		fail("writing %s: %v", fileName, err)
	}
}

// findStructType returns the struct type underlying the top-level type
// declaration called name in pkg's syntax trees. It returns nil when no such
// declaration exists, when its type is unknown, or when it is not a struct.
// Only file-level declarations are searched; types declared inside function
// bodies cannot carry generated methods anyway.
func findStructType(pkg *packages.Package, name string) *types.Struct {
	if pkg.TypesInfo == nil {
		return nil
	}
	for _, file := range pkg.Syntax {
		for _, decl := range file.Decls {
			gen, ok := decl.(*ast.GenDecl)
			if !ok || gen.Tok != token.TYPE {
				continue
			}
			for _, spec := range gen.Specs {
				ts, ok := spec.(*ast.TypeSpec)
				if !ok || ts.Name.Name != name {
					continue
				}
				t := pkg.TypesInfo.TypeOf(ts.Type)
				if t == nil {
					return nil
				}
				st, _ := t.Underlying().(*types.Struct)
				return st
			}
		}
	}
	return nil
}