diff --git a/Gopkg.lock b/Gopkg.lock index c2a3bbbc8..f6718f283 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -1266,7 +1266,7 @@ "roundrobin", "utils" ] - revision = "f0cbb9d6b797d92d168b95b5c443a31dfa67ccd0" + revision = "a3ed5f65204f4ffccbb56d58cec466cdb7ab730b" [[projects]] name = "github.com/vulcand/predicate" diff --git a/provider/docker/config.go b/provider/docker/config.go index 002d1f115..546e34c1e 100644 --- a/provider/docker/config.go +++ b/provider/docker/config.go @@ -33,7 +33,7 @@ func (p *Provider) buildConfigurationV2(containersInspected []dockerData) *types "getDomain": label.GetFuncString(label.TraefikDomain, p.Domain), // Backend functions - "getIPAddress": p.getIPAddress, + "getIPAddress": p.getDeprecatedIPAddress, // TODO: Should we expose getIPPort instead? "getServers": p.getServers, "getMaxConn": label.GetMaxConn, "getHealthCheck": label.GetHealthCheck, @@ -235,17 +235,6 @@ func (p Provider) getIPAddress(container dockerData) string { return p.getIPAddress(parseContainer(containerInspected)) } - if p.UseBindPortIP { - port := getPortV1(container) - for netPort, portBindings := range container.NetworkSettings.Ports { - if string(netPort) == port+"/TCP" || string(netPort) == port+"/UDP" { - for _, p := range portBindings { - return p.HostIP - } - } - } - } - for _, network := range container.NetworkSettings.Networks { return network.Addr } @@ -254,6 +243,16 @@ func (p Provider) getIPAddress(container dockerData) string { return "" } +// Deprecated: Please use getIPPort instead +func (p *Provider) getDeprecatedIPAddress(container dockerData) string { + ip, _, err := p.getIPPort(container) + if err != nil { + log.Warn(err) + return "" + } + return ip +} + // Escape beginning slash "/", convert all others to dash "-", and convert underscores "_" to dash "-" func getSubDomain(name string) string { return strings.Replace(strings.Replace(strings.TrimPrefix(name, "/"), "/", "-", -1), "_", "-", -1) @@ -322,13 +321,53 @@ func getPort(container dockerData) string { return "" } +func (p *Provider) getPortBinding(container dockerData) (*nat.PortBinding, error) { + port := getPort(container) + for netPort, portBindings := range container.NetworkSettings.Ports { + if strings.EqualFold(string(netPort), port+"/TCP") || strings.EqualFold(string(netPort), port+"/UDP") { + for _, p := range portBindings { + return &p, nil + } + } + } + + return nil, fmt.Errorf("unable to find the external IP:Port for the container %q", container.Name) +} + +func (p *Provider) getIPPort(container dockerData) (string, string, error) { + var ip, port string + + if p.UseBindPortIP { + portBinding, err := p.getPortBinding(container) + if err != nil { + return "", "", fmt.Errorf("unable to find a binding for the container %q: ignoring server", container.Name) + } + + if portBinding.HostIP == "0.0.0.0" { + return "", "", fmt.Errorf("cannot determine the IP address (got 0.0.0.0) for the container %q: ignoring server", container.Name) + } + + ip = portBinding.HostIP + port = portBinding.HostPort + + } else { + ip = p.getIPAddress(container) + port = getPort(container) + } + + if len(ip) == 0 { + return "", "", fmt.Errorf("unable to find the IP address for the container %q: the server is ignored", container.Name) + } + return ip, port, nil +} + func (p *Provider) getServers(containers []dockerData) map[string]types.Server { var servers map[string]types.Server for _, container := range containers { - ip := p.getIPAddress(container) - if len(ip) == 0 { - log.Warnf("Unable to find the IP address for the container %q: the server is ignored.", container.Name) + ip, port, err := p.getIPPort(container) + if err != nil { + log.Warn(err) continue } @@ -337,7 +376,6 @@ func (p *Provider) getServers(containers []dockerData) map[string]types.Server { } protocol := label.GetStringValue(container.SegmentLabels, label.TraefikProtocol, label.DefaultProtocol) - port := getPort(container) serverURL := fmt.Sprintf("%s://%s", protocol, net.JoinHostPort(ip, port)) diff --git a/provider/docker/config_container_docker_test.go b/provider/docker/config_container_docker_test.go index 457a195ce..6e340c74a 100644 --- a/provider/docker/config_container_docker_test.go +++ b/provider/docker/config_container_docker_test.go @@ -1287,12 +1287,173 @@ func TestDockerGetIPAddress(t *testing.T) { Network: "webnet", } - actual := provider.getIPAddress(dData) + actual := provider.getDeprecatedIPAddress(dData) assert.Equal(t, test.expected, actual) }) } } +func TestDockerGetIPPort(t *testing.T) { + testCases := []struct { + desc string + container docker.ContainerJSON + ip, port string + expectsError bool + }{ + { + desc: "label traefik.port not set, binding with ip:port should create a route to the bound ip:port", + container: containerJSON( + ports(nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: "1.2.3.4", + HostPort: "8081", + }, + }, + }), + withNetwork("testnet", ipv4("10.11.12.13"))), + ip: "1.2.3.4", + port: "8081", + }, + { + desc: "label traefik.port set, multiple bindings on different ports, uses the label to select the correct (first) binding", + container: containerJSON( + labels(map[string]string{ + label.TraefikPort: "80", + }), + ports(nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: "1.2.3.4", + HostPort: "8081", + }, + }, + "443/tcp": []nat.PortBinding{ + { + HostIP: "5.6.7.8", + HostPort: "8082", + }, + }, + }), + withNetwork("testnet", ipv4("10.11.12.13"))), + ip: "1.2.3.4", + port: "8081", + }, + { + desc: "label traefik.port set, multiple bindings on different ports, uses the label to select the correct (second) binding", + container: containerJSON( + labels(map[string]string{ + label.TraefikPort: "443", + }), + ports(nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: "1.2.3.4", + HostPort: "8081", + }, + }, + "443/tcp": []nat.PortBinding{ + { + HostIP: "5.6.7.8", + HostPort: "8082", + }, + }, + }), + withNetwork("testnet", ipv4("10.11.12.13"))), + ip: "5.6.7.8", + port: "8082", + }, + { + desc: "label traefik.port set, single binding with ip:port for the label, creates the route", + container: containerJSON( + labels(map[string]string{ + label.TraefikPort: "443", + }), + ports(nat.PortMap{ + "443/tcp": []nat.PortBinding{ + { + HostIP: "5.6.7.8", + HostPort: "8082", + }, + }, + }), + withNetwork("testnet", ipv4("10.11.12.13"))), + ip: "5.6.7.8", + port: "8082", + }, + { + desc: "label traefik.port not set, single binding with port only, server ignored", + container: containerJSON( + ports(nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostPort: "8082", + }, + }, + }), + withNetwork("testnet", ipv4("10.11.12.13"))), + expectsError: true, + }, + { + desc: "label traefik.port not set, no binding, server ignored", + container: containerJSON( + withNetwork("testnet", ipv4("10.11.12.13"))), + expectsError: true, + }, + { + desc: "label traefik.port set, no binding on the corresponding port, server ignored", + container: containerJSON( + labels(map[string]string{ + label.TraefikPort: "80", + }), + ports(nat.PortMap{ + "443/tcp": []nat.PortBinding{ + { + HostIP: "5.6.7.8", + HostPort: "8082", + }, + }, + }), + withNetwork("testnet", ipv4("10.11.12.13"))), + expectsError: true, + }, + { + desc: "label traefik.port set, no binding, server ignored", + container: containerJSON( + labels(map[string]string{ + label.TraefikPort: "80", + }), + withNetwork("testnet", ipv4("10.11.12.13"))), + expectsError: true, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + dData := parseContainer(test.container) + segmentProperties := label.ExtractTraefikLabels(dData.Labels) + dData.SegmentLabels = segmentProperties[""] + + provider := &Provider{ + Network: "webnet", + UseBindPortIP: true, + } + + actualIP, actualPort, actualError := provider.getIPPort(dData) + if test.expectsError { + require.Error(t, actualError) + } else { + require.NoError(t, actualError) + } + assert.Equal(t, test.ip, actualIP) + assert.Equal(t, test.port, actualPort) + }) + } +} + func TestDockerGetPort(t *testing.T) { testCases := []struct { container docker.ContainerJSON diff --git a/provider/docker/config_container_swarm_test.go b/provider/docker/config_container_swarm_test.go index ced5b0dc2..a8cbdca4e 100644 --- a/provider/docker/config_container_swarm_test.go +++ b/provider/docker/config_container_swarm_test.go @@ -933,7 +933,7 @@ func TestSwarmGetIPAddress(t *testing.T) { segmentProperties := label.ExtractTraefikLabels(dData.Labels) dData.SegmentLabels = segmentProperties[""] - actual := provider.getIPAddress(dData) + actual := provider.getDeprecatedIPAddress(dData) assert.Equal(t, test.expected, actual) }) } diff --git a/provider/docker/deprecated_config.go b/provider/docker/deprecated_config.go index 0954466a1..32f90c1f6 100644 --- a/provider/docker/deprecated_config.go +++ b/provider/docker/deprecated_config.go @@ -1,8 +1,10 @@ package docker import ( + "context" "math" "strconv" + "strings" "text/template" "github.com/BurntSushi/ty/fun" @@ -19,7 +21,7 @@ func (p *Provider) buildConfigurationV1(containersInspected []dockerData) *types "isBackendLBSwarm": isBackendLBSwarm, // Backend functions - "getIPAddress": p.getIPAddress, + "getIPAddress": p.getIPAddressV1, "getPort": getPortV1, "getWeight": getFuncIntLabelV1(label.TraefikWeight, label.DefaultWeight), "getProtocol": getFuncStringLabelV1(label.TraefikProtocol, label.DefaultProtocol), @@ -202,3 +204,60 @@ func (p Provider) containerFilterV1(container dockerData) bool { return true } + +func (p Provider) getIPAddressV1(container dockerData) string { + if value := label.GetStringValue(container.Labels, labelDockerNetwork, p.Network); value != "" { + networkSettings := container.NetworkSettings + if networkSettings.Networks != nil { + network := networkSettings.Networks[value] + if network != nil { + return network.Addr + } + + log.Warnf("Could not find network named '%s' for container '%s'! Maybe you're missing the project's prefix in the label? Defaulting to first available network.", value, container.Name) + } + } + + if container.NetworkSettings.NetworkMode.IsHost() { + if container.Node != nil { + if container.Node.IPAddress != "" { + return container.Node.IPAddress + } + } + return "127.0.0.1" + } + + if container.NetworkSettings.NetworkMode.IsContainer() { + dockerClient, err := p.createClient() + if err != nil { + log.Warnf("Unable to get IP address for container %s, error: %s", container.Name, err) + return "" + } + + connectedContainer := container.NetworkSettings.NetworkMode.ConnectedContainer() + containerInspected, err := dockerClient.ContainerInspect(context.Background(), connectedContainer) + if err != nil { + log.Warnf("Unable to get IP address for container %s : Failed to inspect container ID %s, error: %s", container.Name, connectedContainer, err) + return "" + } + return p.getIPAddress(parseContainer(containerInspected)) + } + + if p.UseBindPortIP { + port := getPortV1(container) + for netPort, portBindings := range container.NetworkSettings.Ports { + if strings.EqualFold(string(netPort), port+"/TCP") || strings.EqualFold(string(netPort), port+"/UDP") { + for _, p := range portBindings { + return p.HostIP + } + } + } + } + + for _, network := range container.NetworkSettings.Networks { + return network.Addr + } + + log.Warnf("Unable to find the IP address for the container %q.", container.Name) + return "" +} diff --git a/provider/docker/deprecated_container_docker_test.go b/provider/docker/deprecated_container_docker_test.go index 67771473e..b7359980c 100644 --- a/provider/docker/deprecated_container_docker_test.go +++ b/provider/docker/deprecated_container_docker_test.go @@ -898,7 +898,7 @@ func TestDockerGetIPAddressV1(t *testing.T) { t.Parallel() dData := parseContainer(test.container) provider := &Provider{} - actual := provider.getIPAddress(dData) + actual := provider.getDeprecatedIPAddress(dData) if actual != test.expected { t.Errorf("expected %q, got %q", test.expected, actual) } diff --git a/provider/docker/deprecated_container_swarm_test.go b/provider/docker/deprecated_container_swarm_test.go index 7803e91b0..d53adf1ea 100644 --- a/provider/docker/deprecated_container_swarm_test.go +++ b/provider/docker/deprecated_container_swarm_test.go @@ -667,7 +667,7 @@ func TestSwarmGetIPAddressV1(t *testing.T) { SwarmMode: true, } - actual := provider.getIPAddress(dData) + actual := provider.getDeprecatedIPAddress(dData) if actual != test.expected { t.Errorf("expected %q, got %q", test.expected, actual) } diff --git a/server/server.go b/server/server.go index 0869b1300..2739aad5a 100644 --- a/server/server.go +++ b/server/server.go @@ -40,6 +40,59 @@ import ( var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0) +func newHijackConnectionTracker() *hijackConnectionTracker { + return &hijackConnectionTracker{ + conns: make(map[net.Conn]struct{}), + } +} + +type hijackConnectionTracker struct { + conns map[net.Conn]struct{} + lock sync.RWMutex +} + +// AddHijackedConnection add a connection in the tracked connections list +func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) { + h.lock.Lock() + defer h.lock.Unlock() + h.conns[conn] = struct{}{} +} + +// RemoveHijackedConnection remove a connection from the tracked connections list +func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) { + h.lock.Lock() + defer h.lock.Unlock() + delete(h.conns, conn) +} + +// Shutdown wait for the connection closing +func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + h.lock.RLock() + if len(h.conns) == 0 { + return nil + } + h.lock.RUnlock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// Close close all the connections in the tracked connections list +func (h *hijackConnectionTracker) Close() { + for conn := range h.conns { + if err := conn.Close(); err != nil { + log.Errorf("Error while closing Hijacked conn: %v", err) + } + delete(h.conns, conn) + } +} + // Server is the reverse-proxy/load-balancer engine type Server struct { serverEntryPoints serverEntryPoints @@ -74,12 +127,41 @@ type EntryPoint struct { type serverEntryPoints map[string]*serverEntryPoint type serverEntryPoint struct { - httpServer *h2c.Server - listener net.Listener - httpRouter *middlewares.HandlerSwitcher - certs *traefiktls.CertificateStore - onDemandListener func(string) (*tls.Certificate, error) - tlsALPNGetter func(string) (*tls.Certificate, error) + httpServer *h2c.Server + listener net.Listener + httpRouter *middlewares.HandlerSwitcher + certs *traefiktls.CertificateStore + onDemandListener func(string) (*tls.Certificate, error) + tlsALPNGetter func(string) (*tls.Certificate, error) + hijackConnectionTracker *hijackConnectionTracker +} + +func (s serverEntryPoint) Shutdown(ctx context.Context) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := s.httpServer.Shutdown(ctx); err != nil { + if ctx.Err() == context.DeadlineExceeded { + log.Debugf("Wait server shutdown is over due to: %s", err) + err = s.httpServer.Close() + if err != nil { + log.Error(err) + } + } + } + }() + wg.Add(1) + go func() { + defer wg.Done() + if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil { + if ctx.Err() == context.DeadlineExceeded { + log.Debugf("Wait hijack connection is over due to: %s", err) + s.hijackConnectionTracker.Close() + } + } + }() + wg.Wait() } // NewServer returns an initialized Server. @@ -187,13 +269,7 @@ func (s *Server) Stop() { graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut) ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut) log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName) - if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil { - log.Debugf("Wait is over due to: %s", err) - err = serverEntryPoint.httpServer.Close() - if err != nil { - log.Error(err) - } - } + serverEntryPoint.Shutdown(ctx) cancel() log.Debugf("Entrypoint %s closed", serverEntryPointName) }(sepn, sep) @@ -447,6 +523,16 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer serverEntryPoint.httpServer = newSrv serverEntryPoint.listener = listener + serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker() + serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateHijacked: + serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn) + case http.StateClosed: + serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn) + } + } + return serverEntryPoint } diff --git a/server/server_configuration.go b/server/server_configuration.go index 5b4030b65..41ef7ab2c 100644 --- a/server/server_configuration.go +++ b/server/server_configuration.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net" "net/http" "reflect" "sort" @@ -245,6 +246,15 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration forward.Rewriter(rewriter), forward.ResponseModifier(responseModifier), forward.BufferPool(s.bufferPool), + forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) { + server := req.Context().Value(http.ServerContextKey).(*http.Server) + if server != nil { + connState := server.ConnState + if connState != nil { + connState(conn, http.StateClosed) + } + } + }), ) if err != nil { return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err) diff --git a/vendor/github.com/vulcand/oxy/forward/fwd.go b/vendor/github.com/vulcand/oxy/forward/fwd.go index ec4bea59f..337d5eff5 100644 --- a/vendor/github.com/vulcand/oxy/forward/fwd.go +++ b/vendor/github.com/vulcand/oxy/forward/fwd.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -126,6 +127,14 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter { } } +// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed +func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketConnectionClosedHook = hook + return nil + } +} + // ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { @@ -188,7 +197,8 @@ type httpForwarder struct { log OxyLogger - bufferPool httputil.BufferPool + bufferPool httputil.BufferPool + websocketConnectionClosedHook func(req *http.Request, conn net.Conn) } const defaultFlushInterval = time.Duration(100) * time.Millisecond @@ -374,8 +384,13 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) return } - defer underlyingConn.Close() - defer targetConn.Close() + defer func() { + underlyingConn.Close() + targetConn.Close() + if f.websocketConnectionClosedHook != nil { + f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn()) + } + }() errClient := make(chan error, 1) errBackend := make(chan error, 1)