package resolv import ( "bufio" "errors" "fmt" "net" "os" "slices" "strings" "sync" ) const DefaultPath = "/run/dns/resolv.conf" var ( ErrInvalidAddress = errors.New("invalid nameserver address") ErrAlreadyExists = errors.New("nameserver already exists") ErrNotFound = errors.New("nameserver not found") ) type Manager struct { path string mu sync.Mutex } func New(path string) *Manager { if path == "" { path = DefaultPath } return &Manager{path: path} } func (m *Manager) List() ([]string, error) { m.mu.Lock() defer m.mu.Unlock() _, servers, err := m.read() if err != nil { return nil, err } return servers, nil } func (m *Manager) Add(addr string) error { addr = strings.TrimSpace(addr) if !isValidIP(addr) { return fmt.Errorf("%w: %q", ErrInvalidAddress, addr) } m.mu.Lock() defer m.mu.Unlock() lines, servers, err := m.read() if err != nil { return err } if slices.Contains(servers, addr) { return fmt.Errorf("%w: %s", ErrAlreadyExists, addr) } lines = append(lines, "nameserver "+addr) return m.write(lines) } func (m *Manager) Remove(addr string) error { addr = strings.TrimSpace(addr) if !isValidIP(addr) { return fmt.Errorf("%w: %q", ErrInvalidAddress, addr) } m.mu.Lock() defer m.mu.Unlock() lines, _, err := m.read() if err != nil { return err } out := make([]string, 0, len(lines)) removed := false for _, line := range lines { if !removed && parseNameserver(line) == addr { removed = true continue } out = append(out, line) } if !removed { return fmt.Errorf("%w: %s", ErrNotFound, addr) } return m.write(out) } func (m *Manager) read() ([]string, []string, error) { f, err := os.Open(m.path) if err != nil { if os.IsNotExist(err) { return nil, nil, nil } return nil, nil, err } defer f.Close() var lines, servers []string sc := bufio.NewScanner(f) for sc.Scan() { line := sc.Text() lines = append(lines, line) if ns := parseNameserver(line); ns != "" { servers = append(servers, ns) } } if err := sc.Err(); err != nil { return nil, nil, err } return lines, servers, nil } func (m *Manager) write(lines []string) error { dir := dirOf(m.path) tmp, err := os.CreateTemp(dir, ".resolv.conf.tmp.*") if err != nil { return err } tmpName := tmp.Name() cleanup := func() { _ = os.Remove(tmpName) } w := bufio.NewWriter(tmp) for _, line := range lines { if _, err := w.WriteString(line); err != nil { tmp.Close() cleanup() return err } if _, err := w.WriteString("\n"); err != nil { tmp.Close() cleanup() return err } } if err := w.Flush(); err != nil { tmp.Close() cleanup() return err } if err := tmp.Sync(); err != nil { tmp.Close() cleanup() return err } if err := tmp.Close(); err != nil { cleanup() return err } if err := os.Rename(tmpName, m.path); err != nil { cleanup() return err } return nil } func parseNameserver(line string) string { s := strings.TrimSpace(line) if s == "" || strings.HasPrefix(s, "#") || strings.HasPrefix(s, ";") { return "" } fields := strings.Fields(s) if len(fields) < 2 || fields[0] != "nameserver" { return "" } return fields[1] } func isValidIP(s string) bool { return s != "" && net.ParseIP(s) != nil } func dirOf(path string) string { if i := strings.LastIndex(path, "/"); i >= 0 { return path[:i] } return "." }