package main import ( "bytes" "flag" "fmt" "go/ast" "go/token" "go/types" "io" "os" "reflect" "strings" "golang.org/x/tools/go/packages" "golang.org/x/tools/imports" ) const ( defaultMerge = ` if s.%[1]s != %[2]s { s.%[1]s = other.%[1]s } ` defaultValidator = ` if s.%[1]s == %[2]s { return errors.New("missing %[3]s") } ` ipValidator = ` if net.ParseIP(s.%[1]s) == nil { return fmt.Errorf("invalid %[3]s: %%q", s.%[1]s) } ` positiveValidator = ` if s.%[1]s < 1 { return fmt.Errorf("invalid %[3]s: %%q", s.%[1]s) } ` ) var ( structName string packageName string fileName string ) type Validator struct { code string } type Field struct { name string prettyName string defaultValue string validators []string } func NewField(name, prettyName, defaultValue, tag string) *Field { validators := make([]string, 0, 2) switch tag { case "": validators = append(validators, defaultValidator) case "ip": validators = append(validators, defaultValidator, ipValidator) case "positive": validators = append(validators, positiveValidator) case "no": break default: panic(fmt.Sprintf("invalid tag %v", tag)) } return &Field{ name: name, prettyName: prettyName, defaultValue: defaultValue, validators: validators, } } func init() { namePtr := flag.String("name", "", "name of the struct") flag.Parse() if namePtr == nil || *namePtr == "" { fmt.Fprintln(os.Stderr, "Invalid name") os.Exit(1) } structName = *namePtr packageName = os.Getenv("GOPACKAGE") fileName = fmt.Sprintf("%s_gen.go", strings.TrimSuffix(os.Getenv("GOFILE"), ".go")) } func defaultValue(t types.Type) string { b, ok := t.Underlying().(*types.Basic) if !ok { return "nil" } switch b.Kind() { case types.Int, types.Int8, types.Int16, types.Int32, types.Int64, types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64: return "0" case types.String: return `""` case types.Bool: return "false" default: return "nil" } } type Package struct { fields []*Field imports map[string]struct{} } func enumerateFields(root *types.Struct) []*Field { fields := make([]*Field, 0, 16) for i := range root.NumFields() { field := root.Field(i) tag := reflect.StructTag(root.Tag(i)) extraValidators := tag.Get("gen") prettyName := tag.Get("toml") if embed, ok := field.Type().Underlying().(*types.Struct); ok { fields = append(fields, enumerateFields(embed)...) } else { fields = append(fields, NewField(field.Name(), prettyName, defaultValue(field.Type()), extraValidators)) } } return fields } func writeMerge(w io.Writer, name string, fields []*Field) { fmt.Fprintf(w, "func (s *%[1]s) Merge(other %[1]s) {", name) for _, f := range fields { fmt.Fprintf(w, defaultMerge, f.name, f.defaultValue) } fmt.Fprintf(w, "}\n\n") } func writeValidators(w io.Writer, name string, fields []*Field) { fmt.Fprintf(w, "func (s %s) Validate() error {", name) for _, f := range fields { for _, v := range f.validators { fmt.Fprintf(w, v, f.name, f.defaultValue, f.prettyName) } } fmt.Fprintf(w, "\nreturn nil\n}\n\n") } func main() { cfg := &packages.Config{ Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedFiles, } pkgs, _ := packages.Load(cfg, ".") config := *pkgs[0] for _, v := range config.Syntax { ast.Inspect(v, func(node ast.Node) bool { decl, ok := node.(*ast.GenDecl) if !ok || decl.Tok != token.TYPE { return true } for _, spec := range decl.Specs { ts := spec.(*ast.TypeSpec) if ts.Name.String() != structName { continue } t := config.TypesInfo.TypeOf(ts.Type.(*ast.StructType)).(*types.Struct) fields := enumerateFields(t) buf := bytes.NewBuffer(make([]byte, 0)) fmt.Fprintf(buf, "// Code generated by roleconfig; DO NOT EDIT.\n\npackage %s\n", packageName) writeValidators(buf, structName, fields) writeMerge(buf, structName, fields) x, _ := imports.Process(fileName, buf.Bytes(), nil) os.WriteFile(fileName, x, 0644) break } return false }) } }