diff --git a/integration/fixtures/grpc/config_retry.toml b/integration/fixtures/grpc/config_retry.toml new file mode 100644 index 000000000..872ab8885 --- /dev/null +++ b/integration/fixtures/grpc/config_retry.toml @@ -0,0 +1,34 @@ +[serversTransport] +rootCAs = [ """{{ .CertContent }}""" ] + +[entryPoints] + [entryPoints.https] + address = ":4443" + [entryPoints.https.tls] + [entryPoints.https.tls.DefaultCertificate] + certFile = """{{ .CertContent }}""" + keyFile = """{{ .KeyContent }}""" + + +[api] + +[providers] + [providers.file] + +[routers] + [routers.router1] + rule = "Host(`127.0.0.1`)" + service = "service1" + middlewares = ["retryer"] + +[middlewares] + [middlewares.retryer.retry] + Attempts = 2 + +[services] + [services.service1.loadbalancer] + [services.service1.loadbalancer.responseForwarding] + flushInterval="1ms" + [[services.service1.loadbalancer.servers]] + url = "https://127.0.0.1:{{ .GRPCServerPort }}" + weight = 1 diff --git a/integration/grpc_test.go b/integration/grpc_test.go index 990e27bee..2e7e2e1d1 100644 --- a/integration/grpc_test.go +++ b/integration/grpc_test.go @@ -423,3 +423,45 @@ func (s *GRPCSuite) TestGRPCBufferWithFlushInterval(c *check.C) { }) c.Assert(err, check.IsNil) } + +func (s *GRPCSuite) TestGRPCWithRetry(c *check.C) { + lis, err := net.Listen("tcp", ":0") + _, port, err := net.SplitHostPort(lis.Addr().String()) + c.Assert(err, check.IsNil) + + go func() { + err := startGRPCServer(lis, &myserver{}) + c.Log(err) + c.Assert(err, check.IsNil) + }() + + file := s.adaptFile(c, "fixtures/grpc/config_retry.toml", struct { + CertContent string + KeyContent string + GRPCServerPort string + }{ + CertContent: string(LocalhostCert), + KeyContent: string(LocalhostKey), + GRPCServerPort: port, + }) + + defer os.Remove(file) + cmd, display := s.traefikCmd(withConfigFile(file)) + defer display(c) + + err = cmd.Start() + c.Assert(err, check.IsNil) + defer cmd.Process.Kill() + + // wait for Traefik + err = try.GetRequest("http://127.0.0.1:8080/api/providers/file/routers", 1*time.Second, try.BodyContains("Host(`127.0.0.1`)")) + c.Assert(err, check.IsNil) + + var response string + err = try.Do(1*time.Second, func() error { + response, err = callHelloClientGRPC("World", true) + return err + }) + c.Assert(err, check.IsNil) + c.Assert(response, check.Equals, "Hello World") +} diff --git a/middlewares/retry/retry.go b/middlewares/retry/retry.go index 11cc337bf..7c0d0a835 100644 --- a/middlewares/retry/retry.go +++ b/middlewares/retry/retry.go @@ -132,6 +132,7 @@ type responseWriterWithoutCloseNotify struct { responseWriter http.ResponseWriter headers http.Header shouldRetry bool + written bool } func (r *responseWriterWithoutCloseNotify) ShouldRetry() bool { @@ -143,6 +144,9 @@ func (r *responseWriterWithoutCloseNotify) DisableRetries() { } func (r *responseWriterWithoutCloseNotify) Header() http.Header { + if r.written { + return r.responseWriter.Header() + } return r.headers } @@ -177,6 +181,7 @@ func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) { } r.responseWriter.WriteHeader(code) + r.written = true } func (r *responseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { diff --git a/old/middlewares/retry.go b/old/middlewares/retry.go index b6407ac0e..062bd2240 100644 --- a/old/middlewares/retry.go +++ b/old/middlewares/retry.go @@ -110,6 +110,7 @@ type retryResponseWriterWithoutCloseNotify struct { responseWriter http.ResponseWriter headers http.Header shouldRetry bool + written bool } func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool { @@ -121,6 +122,9 @@ func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() { } func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header { + if rr.written { + return rr.responseWriter.Header() + } return rr.headers } @@ -155,6 +159,7 @@ func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) { } rr.responseWriter.WriteHeader(code) + rr.written = true } func (rr *retryResponseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {