commit ca4bd7d7c7a6f42d069fe3e0ceb1f1838bed97ab Author: Arthur K. Date: Tue May 12 22:54:32 2026 +0300 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ae3c172 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/bin/ diff --git a/cmd/client/main.go b/cmd/client/main.go new file mode 100644 index 0000000..78789ec --- /dev/null +++ b/cmd/client/main.go @@ -0,0 +1,155 @@ +package main + +import ( + "flag" + "fmt" + "io" + "os" + + "github.com/wzray/dns/internal/client" +) + +const defaultServer = "http://localhost:8080" + +func main() { + os.Exit(run(os.Args[1:], os.Stdout, os.Stderr)) +} + +func run(args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("dns-cli", flag.ContinueOnError) + fs.SetOutput(stderr) + server := fs.String("server", envOr("DNS_SERVER_URL", defaultServer), + "Base URL of the dns-server (env: DNS_SERVER_URL)") + fs.Usage = func() { printRootUsage(stderr, fs) } + + if err := fs.Parse(args); err != nil { + return 2 + } + rest := fs.Args() + if len(rest) == 0 { + printRootUsage(stderr, fs) + return 2 + } + + cmd, cmdArgs := rest[0], rest[1:] + c := client.New(*server) + + switch cmd { + case "help", "-h", "--help": + printRootUsage(stdout, fs) + return 0 + case "list", "ls": + return runList(c, cmdArgs, stdout, stderr) + case "add": + return runAdd(c, cmdArgs, stdout, stderr) + case "remove", "rm", "delete": + return runRemove(c, cmdArgs, stdout, stderr) + default: + fmt.Fprintf(stderr, "unknown command: %q\n\n", cmd) + printRootUsage(stderr, fs) + return 2 + } +} + +func runList(c *client.Client, args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("list", flag.ContinueOnError) + fs.SetOutput(stderr) + fs.Usage = func() { + fmt.Fprintln(stderr, "Usage: dns-cli list") + fmt.Fprintln(stderr, "\nList configured DNS nameservers.") + } + if err := fs.Parse(args); err != nil { + return 2 + } + if fs.NArg() != 0 { + fmt.Fprintln(stderr, "list: unexpected arguments") + fs.Usage() + return 2 + } + servers, err := c.List() + if err != nil { + fmt.Fprintf(stderr, "error: %v\n", err) + return 1 + } + if len(servers) == 0 { + fmt.Fprintln(stdout, "(no nameservers configured)") + return 0 + } + for _, s := range servers { + fmt.Fprintln(stdout, s) + } + return 0 +} + +func runAdd(c *client.Client, args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("add", flag.ContinueOnError) + fs.SetOutput(stderr) + fs.Usage = func() { + fmt.Fprintln(stderr, "Usage: dns-cli add
") + fmt.Fprintln(stderr, "\nAdd a DNS nameserver (IPv4 or IPv6).") + } + if err := fs.Parse(args); err != nil { + return 2 + } + if fs.NArg() != 1 { + fmt.Fprintln(stderr, "add: exactly one
argument is required") + fs.Usage() + return 2 + } + addr := fs.Arg(0) + if err := c.Add(addr); err != nil { + fmt.Fprintf(stderr, "error: %v\n", err) + return 1 + } + fmt.Fprintf(stdout, "added: %s\n", addr) + return 0 +} + +func runRemove(c *client.Client, args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("remove", flag.ContinueOnError) + fs.SetOutput(stderr) + fs.Usage = func() { + fmt.Fprintln(stderr, "Usage: dns-cli remove
") + fmt.Fprintln(stderr, "\nRemove a DNS nameserver. Aliases: rm, delete.") + } + if err := fs.Parse(args); err != nil { + return 2 + } + if fs.NArg() != 1 { + fmt.Fprintln(stderr, "remove: exactly one
argument is required") + fs.Usage() + return 2 + } + addr := fs.Arg(0) + if err := c.Remove(addr); err != nil { + fmt.Fprintf(stderr, "error: %v\n", err) + return 1 + } + fmt.Fprintf(stdout, "removed: %s\n", addr) + return 0 +} + +func printRootUsage(w io.Writer, fs *flag.FlagSet) { + fmt.Fprintln(w, "dns-cli — manage DNS nameservers on a remote host via the dns-server REST API.") + fmt.Fprintln(w) + fmt.Fprintln(w, "Usage:") + fmt.Fprintln(w, " dns-cli [global flags] [args]") + fmt.Fprintln(w) + fmt.Fprintln(w, "Commands:") + fmt.Fprintln(w, " list List configured DNS nameservers") + fmt.Fprintln(w, " add
Add a DNS nameserver") + fmt.Fprintln(w, " remove Remove a DNS nameserver (aliases: rm, delete)") + fmt.Fprintln(w, " help Show this help") + fmt.Fprintln(w) + fmt.Fprintln(w, "Global flags:") + fs.PrintDefaults() + fmt.Fprintln(w) + fmt.Fprintln(w, "Run 'dns-cli --help' for command-specific help.") +} + +func envOr(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..fe71c19 --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "errors" + "flag" + "log/slog" + "net/http" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "github.com/wzray/dns/internal/resolv" + "github.com/wzray/dns/internal/server" +) + +func main() { + addr := flag.String("addr", ":8080", "HTTP listen address") + logLevel := flag.String("log-level", "info", "Log level: debug, info, warn, error") + logFormat := flag.String("log-format", "text", "Log format: text or json") + flag.Parse() + + log := newLogger(*logLevel, *logFormat) + + resolvPath := resolv.DefaultPath + if err := os.MkdirAll(filepath.Dir(resolvPath), 0o755); err != nil { + log.Error("prepare resolv directory", "err", err) + os.Exit(1) + } + + mgr := resolv.New(resolvPath) + srv := server.New(mgr, log) + + httpSrv := &http.Server{ + Addr: *addr, + Handler: srv.Handler(), + ReadHeaderTimeout: 5 * time.Second, + } + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 1) + go func() { + log.Info("starting server", "addr", *addr) + if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + select { + case <-ctx.Done(): + log.Info("shutting down") + case err := <-errCh: + log.Error("server error", "err", err) + os.Exit(1) + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := httpSrv.Shutdown(shutdownCtx); err != nil { + log.Error("graceful shutdown failed", "err", err) + os.Exit(1) + } +} + +func newLogger(level, format string) *slog.Logger { + var lvl slog.Level + switch level { + case "debug": + lvl = slog.LevelDebug + case "warn": + lvl = slog.LevelWarn + case "error": + lvl = slog.LevelError + default: + lvl = slog.LevelInfo + } + opts := &slog.HandlerOptions{Level: lvl} + var h slog.Handler + if format == "json" { + h = slog.NewJSONHandler(os.Stderr, opts) + } else { + h = slog.NewTextHandler(os.Stderr, opts) + } + return slog.New(h) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9244eb3 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/wzray/dns + +go 1.26.2 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/internal/client/client.go b/internal/client/client.go new file mode 100644 index 0000000..ce9ee55 --- /dev/null +++ b/internal/client/client.go @@ -0,0 +1,106 @@ +package client + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type Client struct { + baseURL string + http *http.Client +} + +func New(baseURL string) *Client { + return &Client{ + baseURL: baseURL, + http: &http.Client{Timeout: 10 * time.Second}, + } +} + +type listResponse struct { + Servers []string `json:"servers"` +} + +type addRequest struct { + Address string `json:"address"` +} + +type errorResponse struct { + Error string `json:"error"` +} + +func (c *Client) List() ([]string, error) { + req, err := http.NewRequest(http.MethodGet, c.baseURL+"/dns", nil) + if err != nil { + return nil, err + } + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, decodeError(resp) + } + var lr listResponse + if err := json.NewDecoder(resp.Body).Decode(&lr); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + return lr.Servers, nil +} + +func (c *Client) Add(addr string) error { + body, err := json.Marshal(addRequest{Address: addr}) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodPost, c.baseURL+"/dns", bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.http.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return decodeError(resp) + } + return nil +} + +func (c *Client) Remove(addr string) error { + u := c.baseURL + "/dns/" + url.PathEscape(addr) + req, err := http.NewRequest(http.MethodDelete, u, nil) + if err != nil { + return err + } + resp, err := c.http.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + return decodeError(resp) + } + return nil +} + +func decodeError(resp *http.Response) error { + body, _ := io.ReadAll(resp.Body) + var er errorResponse + if json.Unmarshal(body, &er) == nil && er.Error != "" { + return fmt.Errorf("server returned %s: %s", resp.Status, er.Error) + } + return fmt.Errorf("server returned %s: %s", resp.Status, string(body)) +} diff --git a/internal/resolv/resolv.go b/internal/resolv/resolv.go new file mode 100644 index 0000000..d1c42a3 --- /dev/null +++ b/internal/resolv/resolv.go @@ -0,0 +1,183 @@ +package resolv + +import ( + "bufio" + "errors" + "fmt" + "net" + "os" + "slices" + "strings" + "sync" +) + +const DefaultPath = "/run/dns/resolv.conf" + +var ( + ErrInvalidAddress = errors.New("invalid nameserver address") + ErrAlreadyExists = errors.New("nameserver already exists") + ErrNotFound = errors.New("nameserver not found") +) + +type Manager struct { + path string + mu sync.Mutex +} + +func New(path string) *Manager { + if path == "" { + path = DefaultPath + } + return &Manager{path: path} +} + +func (m *Manager) List() ([]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + _, servers, err := m.read() + if err != nil { + return nil, err + } + return servers, nil +} + +func (m *Manager) Add(addr string) error { + addr = strings.TrimSpace(addr) + if !isValidIP(addr) { + return fmt.Errorf("%w: %q", ErrInvalidAddress, addr) + } + + m.mu.Lock() + defer m.mu.Unlock() + + lines, servers, err := m.read() + if err != nil { + return err + } + if slices.Contains(servers, addr) { + return fmt.Errorf("%w: %s", ErrAlreadyExists, addr) + } + lines = append(lines, "nameserver "+addr) + return m.write(lines) +} + +func (m *Manager) Remove(addr string) error { + addr = strings.TrimSpace(addr) + if !isValidIP(addr) { + return fmt.Errorf("%w: %q", ErrInvalidAddress, addr) + } + + m.mu.Lock() + defer m.mu.Unlock() + + lines, _, err := m.read() + if err != nil { + return err + } + + out := make([]string, 0, len(lines)) + removed := false + for _, line := range lines { + if !removed && parseNameserver(line) == addr { + removed = true + continue + } + out = append(out, line) + } + if !removed { + return fmt.Errorf("%w: %s", ErrNotFound, addr) + } + return m.write(out) +} + +func (m *Manager) read() ([]string, []string, error) { + f, err := os.Open(m.path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil, nil + } + return nil, nil, err + } + defer f.Close() + + var lines, servers []string + sc := bufio.NewScanner(f) + for sc.Scan() { + line := sc.Text() + lines = append(lines, line) + if ns := parseNameserver(line); ns != "" { + servers = append(servers, ns) + } + } + if err := sc.Err(); err != nil { + return nil, nil, err + } + return lines, servers, nil +} + +func (m *Manager) write(lines []string) error { + dir := dirOf(m.path) + tmp, err := os.CreateTemp(dir, ".resolv.conf.tmp.*") + if err != nil { + return err + } + tmpName := tmp.Name() + cleanup := func() { _ = os.Remove(tmpName) } + + w := bufio.NewWriter(tmp) + for _, line := range lines { + if _, err := w.WriteString(line); err != nil { + tmp.Close() + cleanup() + return err + } + if _, err := w.WriteString("\n"); err != nil { + tmp.Close() + cleanup() + return err + } + } + if err := w.Flush(); err != nil { + tmp.Close() + cleanup() + return err + } + if err := tmp.Sync(); err != nil { + tmp.Close() + cleanup() + return err + } + if err := tmp.Close(); err != nil { + cleanup() + return err + } + if err := os.Rename(tmpName, m.path); err != nil { + cleanup() + return err + } + return nil +} + +func parseNameserver(line string) string { + s := strings.TrimSpace(line) + if s == "" || strings.HasPrefix(s, "#") || strings.HasPrefix(s, ";") { + return "" + } + fields := strings.Fields(s) + if len(fields) < 2 || fields[0] != "nameserver" { + return "" + } + return fields[1] +} + +func isValidIP(s string) bool { + return s != "" && net.ParseIP(s) != nil +} + +func dirOf(path string) string { + if i := strings.LastIndex(path, "/"); i >= 0 { + return path[:i] + } + return "." +} diff --git a/internal/resolv/resolv_test.go b/internal/resolv/resolv_test.go new file mode 100644 index 0000000..d3ce010 --- /dev/null +++ b/internal/resolv/resolv_test.go @@ -0,0 +1,158 @@ +package resolv + +import ( + "errors" + "os" + "path/filepath" + "reflect" + "sync" + "testing" +) + +func newTestManager(t *testing.T, initial string) (*Manager, string) { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "resolv.conf") + if initial != "" { + if err := os.WriteFile(path, []byte(initial), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + } + return New(path), path +} + +func readFile(t *testing.T, path string) string { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read: %v", err) + } + return string(b) +} + +func TestList_EmptyMissingFile(t *testing.T) { + m, _ := newTestManager(t, "") + got, err := m.List() + if err != nil { + t.Fatalf("List: %v", err) + } + if len(got) != 0 { + t.Fatalf("want empty, got %v", got) + } +} + +func TestList_ParsesNameserversIgnoresOther(t *testing.T) { + const initial = `# comment +; another comment +search example.com +nameserver 1.1.1.1 +options rotate +nameserver 8.8.8.8 +nameserver 2001:4860:4860::8888 +` + m, _ := newTestManager(t, initial) + got, err := m.List() + if err != nil { + t.Fatalf("List: %v", err) + } + want := []string{"1.1.1.1", "8.8.8.8", "2001:4860:4860::8888"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("want %v, got %v", want, got) + } +} + +func TestAdd_AppendsAndPersists(t *testing.T) { + m, path := newTestManager(t, "search example.com\nnameserver 1.1.1.1\n") + if err := m.Add("8.8.8.8"); err != nil { + t.Fatalf("Add: %v", err) + } + got, _ := m.List() + want := []string{"1.1.1.1", "8.8.8.8"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("list want %v, got %v", want, got) + } + content := readFile(t, path) + if want := "search example.com\nnameserver 1.1.1.1\nnameserver 8.8.8.8\n"; content != want { + t.Fatalf("file mismatch:\nwant: %q\ngot: %q", want, content) + } +} + +func TestAdd_RejectsInvalid(t *testing.T) { + m, _ := newTestManager(t, "") + for _, addr := range []string{"", "not-an-ip", "999.999.999.999", "1.1.1"} { + if err := m.Add(addr); !errors.Is(err, ErrInvalidAddress) { + t.Fatalf("Add(%q): want ErrInvalidAddress, got %v", addr, err) + } + } +} + +func TestAdd_DuplicateRejected(t *testing.T) { + m, _ := newTestManager(t, "nameserver 1.1.1.1\n") + if err := m.Add("1.1.1.1"); !errors.Is(err, ErrAlreadyExists) { + t.Fatalf("want ErrAlreadyExists, got %v", err) + } +} + +func TestRemove_DeletesEntry(t *testing.T) { + const initial = "search example.com\nnameserver 1.1.1.1\nnameserver 8.8.8.8\n" + m, path := newTestManager(t, initial) + if err := m.Remove("1.1.1.1"); err != nil { + t.Fatalf("Remove: %v", err) + } + got, _ := m.List() + if want := []string{"8.8.8.8"}; !reflect.DeepEqual(got, want) { + t.Fatalf("want %v, got %v", want, got) + } + content := readFile(t, path) + if want := "search example.com\nnameserver 8.8.8.8\n"; content != want { + t.Fatalf("file mismatch:\nwant: %q\ngot: %q", want, content) + } +} + +func TestRemove_NotFound(t *testing.T) { + m, _ := newTestManager(t, "nameserver 1.1.1.1\n") + if err := m.Remove("8.8.8.8"); !errors.Is(err, ErrNotFound) { + t.Fatalf("want ErrNotFound, got %v", err) + } +} + +func TestRemove_InvalidAddress(t *testing.T) { + m, _ := newTestManager(t, "") + if err := m.Remove("nope"); !errors.Is(err, ErrInvalidAddress) { + t.Fatalf("want ErrInvalidAddress, got %v", err) + } +} + +func TestConcurrentAddsConsistent(t *testing.T) { + m, _ := newTestManager(t, "") + addrs := []string{ + "1.1.1.1", "8.8.8.8", "9.9.9.9", "8.8.4.4", + "1.0.0.1", "208.67.222.222", "208.67.220.220", "64.6.64.6", + } + var wg sync.WaitGroup + for _, a := range addrs { + wg.Add(1) + go func(a string) { + defer wg.Done() + if err := m.Add(a); err != nil { + t.Errorf("Add(%s): %v", a, err) + } + }(a) + } + wg.Wait() + + got, err := m.List() + if err != nil { + t.Fatalf("List: %v", err) + } + if len(got) != len(addrs) { + t.Fatalf("want %d entries, got %d (%v)", len(addrs), len(got), got) + } + seen := make(map[string]bool, len(got)) + for _, s := range got { + if seen[s] { + t.Fatalf("duplicate %s in %v", s, got) + } + seen[s] = true + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..5a484f6 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,146 @@ +package server + +import ( + "encoding/json" + "errors" + "log/slog" + "net/http" + "time" + + "github.com/wzray/dns/internal/resolv" +) + +type Server struct { + mgr *resolv.Manager + log *slog.Logger + mux *http.ServeMux +} + +func New(mgr *resolv.Manager, log *slog.Logger) *Server { + s := &Server{mgr: mgr, log: log, mux: http.NewServeMux()} + s.routes() + return s +} + +func (s *Server) Handler() http.Handler { + return s.withLogging(s.mux) +} + +func (s *Server) routes() { + s.mux.HandleFunc("GET /health", s.handleHealth) + s.mux.HandleFunc("GET /dns", s.handleList) + s.mux.HandleFunc("POST /dns", s.handleAdd) + s.mux.HandleFunc("DELETE /dns/{address}", s.handleRemove) +} + +type listResponse struct { + Servers []string `json:"servers"` +} + +type addRequest struct { + Address string `json:"address"` +} + +type errorResponse struct { + Error string `json:"error"` +} + +func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + +func (s *Server) handleList(w http.ResponseWriter, r *http.Request) { + servers, err := s.mgr.List() + if err != nil { + s.log.Error("list nameservers", "err", err) + writeError(w, http.StatusInternalServerError, "failed to read nameservers") + return + } + if servers == nil { + servers = []string{} + } + writeJSON(w, http.StatusOK, listResponse{Servers: servers}) +} + +func (s *Server) handleAdd(w http.ResponseWriter, r *http.Request) { + var req addRequest + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON body") + return + } + if req.Address == "" { + writeError(w, http.StatusBadRequest, "address is required") + return + } + + err := s.mgr.Add(req.Address) + switch { + case err == nil: + writeJSON(w, http.StatusCreated, map[string]string{"address": req.Address}) + case errors.Is(err, resolv.ErrInvalidAddress): + writeError(w, http.StatusBadRequest, err.Error()) + case errors.Is(err, resolv.ErrAlreadyExists): + writeError(w, http.StatusConflict, err.Error()) + default: + s.log.Error("add nameserver", "addr", req.Address, "err", err) + writeError(w, http.StatusInternalServerError, "failed to add nameserver") + } +} + +func (s *Server) handleRemove(w http.ResponseWriter, r *http.Request) { + addr := r.PathValue("address") + if addr == "" { + writeError(w, http.StatusBadRequest, "address is required") + return + } + + err := s.mgr.Remove(addr) + switch { + case err == nil: + w.WriteHeader(http.StatusNoContent) + case errors.Is(err, resolv.ErrInvalidAddress): + writeError(w, http.StatusBadRequest, err.Error()) + case errors.Is(err, resolv.ErrNotFound): + writeError(w, http.StatusNotFound, err.Error()) + default: + s.log.Error("remove nameserver", "addr", addr, "err", err) + writeError(w, http.StatusInternalServerError, "failed to remove nameserver") + } +} + +func writeJSON(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func writeError(w http.ResponseWriter, status int, msg string) { + writeJSON(w, status, errorResponse{Error: msg}) +} + +type statusRecorder struct { + http.ResponseWriter + status int +} + +func (r *statusRecorder) WriteHeader(code int) { + r.status = code + r.ResponseWriter.WriteHeader(code) +} + +func (s *Server) withLogging(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rec, r) + s.log.Info("http request", + "method", r.Method, + "path", r.URL.Path, + "status", rec.status, + "duration_ms", time.Since(start).Milliseconds(), + "remote", r.RemoteAddr, + ) + }) +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..c7b24db --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,103 @@ +package server + +import ( + "bytes" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/wzray/dns/internal/resolv" +) + +func newTestServer(t *testing.T) (*Server, string) { + t.Helper() + path := filepath.Join(t.TempDir(), "resolv.conf") + mgr := resolv.New(path) + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + return New(mgr, log), path +} + +func do(t *testing.T, h http.Handler, method, target string, body any) *httptest.ResponseRecorder { + t.Helper() + var rdr io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal: %v", err) + } + rdr = bytes.NewReader(b) + } + req := httptest.NewRequest(method, target, rdr) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + return rec +} + +func TestServer_AddListRemove(t *testing.T) { + s, _ := newTestServer(t) + h := s.Handler() + + rec := do(t, h, "GET", "/dns", nil) + if rec.Code != http.StatusOK { + t.Fatalf("list: status %d", rec.Code) + } + var lr listResponse + if err := json.Unmarshal(rec.Body.Bytes(), &lr); err != nil { + t.Fatalf("decode: %v", err) + } + if len(lr.Servers) != 0 { + t.Fatalf("want empty, got %v", lr.Servers) + } + + rec = do(t, h, "POST", "/dns", addRequest{Address: "1.1.1.1"}) + if rec.Code != http.StatusCreated { + t.Fatalf("add: status %d body %s", rec.Code, rec.Body.String()) + } + + rec = do(t, h, "POST", "/dns", addRequest{Address: "1.1.1.1"}) + if rec.Code != http.StatusConflict { + t.Fatalf("dup: status %d", rec.Code) + } + + rec = do(t, h, "POST", "/dns", addRequest{Address: "not-an-ip"}) + if rec.Code != http.StatusBadRequest { + t.Fatalf("invalid: status %d", rec.Code) + } + + rec = do(t, h, "DELETE", "/dns/1.1.1.1", nil) + if rec.Code != http.StatusNoContent { + t.Fatalf("remove: status %d", rec.Code) + } + + rec = do(t, h, "DELETE", "/dns/1.1.1.1", nil) + if rec.Code != http.StatusNotFound { + t.Fatalf("remove missing: status %d", rec.Code) + } +} + +func TestServer_BadJSON(t *testing.T) { + s, _ := newTestServer(t) + h := s.Handler() + req := httptest.NewRequest("POST", "/dns", bytes.NewReader([]byte("{not json"))) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status %d", rec.Code) + } +} + +func TestServer_HealthAndUnknownRoute(t *testing.T) { + s, _ := newTestServer(t) + h := s.Handler() + + if rec := do(t, h, "GET", "/health", nil); rec.Code != http.StatusOK { + t.Fatalf("health status %d", rec.Code) + } + if rec := do(t, h, "GET", "/unknown", nil); rec.Code != http.StatusNotFound { + t.Fatalf("unknown route status %d", rec.Code) + } +}