196 lines
4.1 KiB
Go
196 lines
4.1 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/token"
|
|
"go/types"
|
|
"io"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"golang.org/x/tools/go/packages"
|
|
"golang.org/x/tools/imports"
|
|
)
|
|
|
|
// Code templates rendered with fmt.Fprintf by writeMerge and writeValidators.
//
// Positional verbs: %[1]s is the Go field name, %[2]s is the zero-value
// literal of the field's type, and %[3]s is the human-readable field name
// used in error messages. A doubled percent sign (%%) survives rendering as a
// single percent sign, i.e. it becomes a verb in the generated code.
const (
	// defaultMerge emits one guarded assignment inside the generated Merge
	// method.
	// NOTE(review): the guard tests the receiver's field, so a receiver field
	// that still holds its zero value is never filled in from other. Confirm
	// this is the intended merge direction.
	defaultMerge = `
if s.%[1]s != %[2]s {
	s.%[1]s = other.%[1]s
}
`
	// defaultValidator rejects a field that still holds its zero value.
	defaultValidator = `
if s.%[1]s == %[2]s {
	return errors.New("missing %[3]s")
}
`
	// ipValidator rejects a field that net.ParseIP cannot parse.
	ipValidator = `
if net.ParseIP(s.%[1]s) == nil {
	return fmt.Errorf("invalid %[3]s: %%q", s.%[1]s)
}
`
	// positiveValidator rejects a field whose value is less than 1. The
	// value is printed with %v: the field is numeric, and %q would render an
	// integer as a quoted character literal rather than as a number.
	positiveValidator = `
if s.%[1]s < 1 {
	return fmt.Errorf("invalid %[3]s: %%v", s.%[1]s)
}
`
)
|
|
|
|
// Generator settings. All three are assigned once in init and read by main.
var (
	// structName is the value of the -name flag: the struct type to
	// generate methods for.
	structName string

	// packageName is taken from the GOPACKAGE environment variable and is
	// used in the generated file's package clause.
	packageName string

	// fileName is the output path, derived from the GOFILE environment
	// variable by replacing a trailing ".go" with "_gen.go".
	fileName string
)
|
|
|
|
// Validator wraps a single validator code template.
// It is not referenced by any other code in this file.
type Validator struct {
	// code is the template text.
	code string
}
|
|
|
|
// Field describes one struct field for which Merge and Validate code is
// generated.
type Field struct {
	// name is the Go identifier of the field.
	name string
	// prettyName is the name shown in generated error messages.
	prettyName string
	// defaultValue is the Go source literal the field is compared against.
	defaultValue string
	// validators holds the fmt templates rendered for this field, in order.
	validators []string
}
|
|
|
|
func NewField(name, prettyName, defaultValue, tag string) *Field {
|
|
validators := make([]string, 0, 2)
|
|
|
|
switch tag {
|
|
case "":
|
|
validators = append(validators, defaultValidator)
|
|
case "ip":
|
|
validators = append(validators, defaultValidator, ipValidator)
|
|
case "positive":
|
|
validators = append(validators, positiveValidator)
|
|
case "no":
|
|
break
|
|
default:
|
|
panic(fmt.Sprintf("invalid tag %v", tag))
|
|
}
|
|
|
|
return &Field{
|
|
name: name,
|
|
prettyName: prettyName,
|
|
defaultValue: defaultValue,
|
|
validators: validators,
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
namePtr := flag.String("name", "", "name of the struct")
|
|
flag.Parse()
|
|
if namePtr == nil || *namePtr == "" {
|
|
fmt.Fprintln(os.Stderr, "Invalid name")
|
|
os.Exit(1)
|
|
}
|
|
structName = *namePtr
|
|
|
|
packageName = os.Getenv("GOPACKAGE")
|
|
|
|
fileName = fmt.Sprintf("%s_gen.go", strings.TrimSuffix(os.Getenv("GOFILE"), ".go"))
|
|
}
|
|
|
|
// defaultValue returns the Go source literal for the zero value of t, chosen
// by t's underlying type:
//
//   - any numeric basic type (integers, uintptr, floats, complex) -> "0"
//   - string                                                      -> `""`
//   - bool                                                        -> "false"
//   - everything else, including non-basic types                  -> "nil"
//
// Classification uses types.Basic.Info so that every numeric kind is covered;
// comparing a float or complex field against nil would not compile in the
// generated code.
func defaultValue(t types.Type) string {
	basic, ok := t.Underlying().(*types.Basic)
	if !ok {
		return "nil"
	}

	switch info := basic.Info(); {
	case info&types.IsNumeric != 0:
		return "0"
	case info&types.IsString != 0:
		return `""`
	case info&types.IsBoolean != 0:
		return "false"
	default:
		return "nil"
	}
}
|
|
|
|
type Package struct {
|
|
fields []*Field
|
|
imports map[string]struct{}
|
|
}
|
|
|
|
func enumerateFields(root *types.Struct) []*Field {
|
|
fields := make([]*Field, 0, 16)
|
|
|
|
for i := range root.NumFields() {
|
|
field := root.Field(i)
|
|
tag := reflect.StructTag(root.Tag(i))
|
|
|
|
extraValidators := tag.Get("gen")
|
|
prettyName := tag.Get("toml")
|
|
if embed, ok := field.Type().Underlying().(*types.Struct); ok {
|
|
fields = append(fields, enumerateFields(embed)...)
|
|
} else {
|
|
fields = append(fields, NewField(field.Name(), prettyName, defaultValue(field.Type()), extraValidators))
|
|
}
|
|
}
|
|
|
|
return fields
|
|
}
|
|
|
|
func writeMerge(w io.Writer, name string, fields []*Field) {
|
|
fmt.Fprintf(w, "func (s *%[1]s) Merge(other %[1]s) {", name)
|
|
for _, f := range fields {
|
|
fmt.Fprintf(w, defaultMerge, f.name, f.defaultValue)
|
|
}
|
|
fmt.Fprintf(w, "}\n\n")
|
|
}
|
|
|
|
func writeValidators(w io.Writer, name string, fields []*Field) {
|
|
fmt.Fprintf(w, "func (s %s) Validate() error {", name)
|
|
for _, f := range fields {
|
|
for _, v := range f.validators {
|
|
fmt.Fprintf(w, v, f.name, f.defaultValue, f.prettyName)
|
|
}
|
|
}
|
|
fmt.Fprintf(w, "\nreturn nil\n}\n\n")
|
|
}
|
|
|
|
func main() {
|
|
cfg := &packages.Config{
|
|
Mode: packages.NeedName |
|
|
packages.NeedTypes |
|
|
packages.NeedTypesInfo |
|
|
packages.NeedSyntax |
|
|
packages.NeedFiles,
|
|
}
|
|
pkgs, _ := packages.Load(cfg, ".")
|
|
config := *pkgs[0]
|
|
for _, v := range config.Syntax {
|
|
ast.Inspect(v, func(node ast.Node) bool {
|
|
decl, ok := node.(*ast.GenDecl)
|
|
if !ok || decl.Tok != token.TYPE {
|
|
return true
|
|
}
|
|
|
|
for _, spec := range decl.Specs {
|
|
ts := spec.(*ast.TypeSpec)
|
|
if ts.Name.String() != structName {
|
|
continue
|
|
}
|
|
|
|
t := config.TypesInfo.TypeOf(ts.Type.(*ast.StructType)).(*types.Struct)
|
|
fields := enumerateFields(t)
|
|
|
|
buf := bytes.NewBuffer(make([]byte, 0))
|
|
fmt.Fprintf(buf, "// Code generated by roleconfig; DO NOT EDIT.\n\npackage %s\n", packageName)
|
|
writeValidators(buf, structName, fields)
|
|
writeMerge(buf, structName, fields)
|
|
x, _ := imports.Process(fileName, buf.Bytes(), nil)
|
|
os.WriteFile(fileName, x, 0644)
|
|
|
|
break
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
}
|