1
0
Fork 0

Bump AWS SDK to v2

This commit is contained in:
Eng Zer Jun 2025-03-10 18:50:04 +08:00 committed by GitHub
parent 7cfd10db62
commit 14e400bcd0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 352 additions and 387 deletions

View file

@ -1,10 +1,13 @@
package ecs
import "github.com/aws/aws-sdk-go/service/ecs"
import (
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
)
func instance(ops ...func(*ecsInstance)) ecsInstance {
e := &ecsInstance{
containerDefinition: &ecs.ContainerDefinition{},
containerDefinition: &ecstypes.ContainerDefinition{},
}
for _, op := range ops {
@ -36,7 +39,7 @@ func iMachine(opts ...func(*machine)) func(*ecsInstance) {
}
}
func mState(state string) func(*machine) {
func mState(state ec2types.InstanceStateName) func(*machine) {
return func(m *machine) {
m.state = state
}
@ -48,7 +51,7 @@ func mPrivateIP(ip string) func(*machine) {
}
}
func mHealthStatus(status string) func(*machine) {
func mHealthStatus(status ecstypes.HealthStatus) func(*machine) {
return func(m *machine) {
m.healthStatus = status
}
@ -64,10 +67,10 @@ func mPorts(opts ...func(*portMapping)) func(*machine) {
}
}
func mPort(containerPort, hostPort int32, protocol string) func(*portMapping) {
func mPort(containerPort, hostPort int32, protocol ecstypes.TransportProtocol) func(*portMapping) {
return func(pm *portMapping) {
pm.containerPort = int64(containerPort)
pm.hostPort = int64(hostPort)
pm.containerPort = containerPort
pm.hostPort = hostPort
pm.protocol = protocol
}
}

View file

@ -6,9 +6,9 @@ import (
"fmt"
"net"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
"github.com/docker/go-connections/nat"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/config/label"
@ -163,12 +163,12 @@ func (p *Provider) filterInstance(ctx context.Context, instance ecsInstance) boo
return false
}
if strings.ToLower(instance.machine.state) != ec2.InstanceStateNameRunning {
if instance.machine.state != ec2types.InstanceStateNameRunning {
logger.Debugf("Filtering ecs instance with an incorrect state %s (%s) (state = %s)", instance.Name, instance.ID, instance.machine.state)
return false
}
if instance.machine.healthStatus == "UNHEALTHY" {
if instance.machine.healthStatus == ecstypes.HealthStatusUnhealthy {
logger.Debugf("Filtering unhealthy ecs instance %s (%s)", instance.Name, instance.ID)
return false
}
@ -293,9 +293,9 @@ func (p *Provider) getIPPort(instance ecsInstance, serverPort string) (string, s
func getPort(instance ecsInstance, serverPort string) string {
if len(serverPort) > 0 {
for _, port := range instance.machine.ports {
containerPort := strconv.FormatInt(port.containerPort, 10)
containerPort := strconv.FormatInt(int64(port.containerPort), 10)
if serverPort == containerPort {
return strconv.FormatInt(port.hostPort, 10)
return strconv.FormatInt(int64(port.hostPort), 10)
}
}
@ -304,7 +304,7 @@ func getPort(instance ecsInstance, serverPort string) string {
var ports []nat.Port
for _, port := range instance.machine.ports {
natPort, err := nat.NewPort(port.protocol, strconv.FormatInt(port.hostPort, 10))
natPort, err := nat.NewPort(string(port.protocol), strconv.FormatInt(int64(port.hostPort), 10))
if err != nil {
continue
}

View file

@ -4,7 +4,8 @@ import (
"context"
"testing"
"github.com/aws/aws-sdk-go/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
@ -29,10 +30,10 @@ func TestDefaultRule(t *testing.T) {
id("1"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("10.0.0.1"),
mPorts(
mPort(0, 1337, "TCP"),
mPort(0, 1337, ecstypes.TransportProtocolTcp),
),
),
),
@ -83,10 +84,10 @@ func TestDefaultRule(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -139,10 +140,10 @@ func TestDefaultRule(t *testing.T) {
"traefik.domain": "foo.bar",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -193,10 +194,10 @@ func TestDefaultRule(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -241,10 +242,10 @@ func TestDefaultRule(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -289,10 +290,10 @@ func TestDefaultRule(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -380,10 +381,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.test": "",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -418,10 +419,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tcp.services.test": "",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -456,10 +457,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.udp.services.test": "",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -492,10 +493,10 @@ func Test_buildConfiguration(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -545,10 +546,10 @@ func Test_buildConfiguration(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -556,10 +557,10 @@ func Test_buildConfiguration(t *testing.T) {
name("Test2"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -625,10 +626,10 @@ func Test_buildConfiguration(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -637,10 +638,10 @@ func Test_buildConfiguration(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -695,10 +696,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -752,10 +753,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.service": "Service1",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -806,10 +807,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foo.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -861,10 +862,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -917,10 +918,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service2.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -977,10 +978,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.service": "Service1",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1032,10 +1033,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1046,10 +1047,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "false",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1091,10 +1092,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "false",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1105,10 +1106,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1119,10 +1120,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1164,10 +1165,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1178,10 +1179,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1236,10 +1237,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "42",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1298,10 +1299,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "42",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1312,10 +1313,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "42",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1377,10 +1378,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "42",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1391,10 +1392,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "41",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1450,10 +1451,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "42",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1464,10 +1465,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "41",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1478,10 +1479,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "40",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.3"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1540,10 +1541,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foo.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1554,10 +1555,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`bar.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1607,10 +1608,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foo.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1621,10 +1622,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`bar.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1635,10 +1636,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foobar.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.3"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1691,10 +1692,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foo.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1706,10 +1707,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foo.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1763,10 +1764,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foo.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1776,10 +1777,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Router1.rule": "Host(`foo.com`)",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1835,10 +1836,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.wrong.label": "42",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -1891,10 +1892,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.LoadBalancer.server.port": "80",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(80, 8080, "tcp"),
mPort(80, 8080, ecstypes.TransportProtocolTcp),
),
),
),
@ -1947,10 +1948,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.LoadBalancer.server.port": "8040",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(80, 8080, "tcp"),
mPort(80, 8080, ecstypes.TransportProtocolTcp),
),
),
),
@ -2007,11 +2008,11 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service2.LoadBalancer.server.port": "4444",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(4444, 32123, "tcp"),
mPort(4445, 32124, "tcp"),
mPort(4444, 32123, ecstypes.TransportProtocolTcp),
mPort(4445, 32124, ecstypes.TransportProtocolTcp),
),
),
),
@ -2077,10 +2078,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service2.LoadBalancer.server.port": "8080",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2134,7 +2135,7 @@ func Test_buildConfiguration(t *testing.T) {
name("Test"),
labels(map[string]string{}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(),
),
@ -2170,7 +2171,7 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware1.inflightreq.amount": "42",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(),
),
@ -2206,10 +2207,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.enable": "false",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2244,11 +2245,11 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.enable": "false",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mHealthStatus("UNHEALTHY"),
mHealthStatus(ecstypes.HealthStatusUnhealthy),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2283,10 +2284,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.enable": "false",
}),
iMachine(
mState(ec2.InstanceStateNamePending),
mState(ec2types.InstanceStateNamePending),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2321,10 +2322,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tags": "foo",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2360,10 +2361,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tags": "foo",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2417,10 +2418,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.routers.Test.middlewares": "Middleware1",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2484,10 +2485,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tcp.routers.Test.middlewares": "Middleware1",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2546,10 +2547,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tcp.routers.foo.tls": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2601,10 +2602,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.udp.routers.foo.entrypoints": "mydns",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "udp"),
mPort(0, 80, ecstypes.TransportProtocolUdp),
),
),
),
@ -2654,10 +2655,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tcp.routers.foo.tls": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2705,10 +2706,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tcp.services.foo.loadbalancer.server.port": "80",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(80, 8080, "tcp"),
mPort(80, 8080, ecstypes.TransportProtocolTcp),
),
),
),
@ -2763,10 +2764,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.udp.services.foo.loadbalancer.server.port": "80",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(80, 8080, "udp"),
mPort(80, 8080, ecstypes.TransportProtocolUdp),
),
),
),
@ -2818,10 +2819,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2834,10 +2835,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.http.services.Service1.loadbalancer.passhostheader": "true",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.2"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2910,10 +2911,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.udp.services.foo.loadbalancer.server.port": "8080",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),
@ -2959,10 +2960,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tcp.services.foo.loadbalancer.terminationdelay": "200",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(80, 8080, "tcp"),
mPort(80, 8080, ecstypes.TransportProtocolTcp),
),
),
),
@ -3010,10 +3011,10 @@ func Test_buildConfiguration(t *testing.T) {
"traefik.tls.stores.default.defaultgeneratedcert.domain.sans": "foobar, fiibar",
}),
iMachine(
mState(ec2.InstanceStateNameRunning),
mState(ec2types.InstanceStateNameRunning),
mPrivateIP("127.0.0.1"),
mPorts(
mPort(0, 80, "tcp"),
mPort(0, 80, ecstypes.TransportProtocolTcp),
),
),
),

View file

@ -3,21 +3,22 @@ package ecs
import (
"context"
"fmt"
"os"
"iter"
"slices"
"strings"
"text/template"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/ecs"
ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
"github.com/aws/aws-sdk-go-v2/service/ssm"
ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types"
"github.com/aws/smithy-go/logging"
"github.com/cenkalti/backoff/v4"
"github.com/patrickmn/go-cache"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
@ -47,29 +48,29 @@ type Provider struct {
type ecsInstance struct {
Name string
ID string
containerDefinition *ecs.ContainerDefinition
containerDefinition *ecstypes.ContainerDefinition
machine *machine
Labels map[string]string
ExtraConf configuration
}
type portMapping struct {
containerPort int64
hostPort int64
protocol string
containerPort int32
hostPort int32
protocol ecstypes.TransportProtocol
}
type machine struct {
state string
state ec2types.InstanceStateName
privateIP string
ports []portMapping
healthStatus string
healthStatus ecstypes.HealthStatus
}
type awsClient struct {
ecs *ecs.ECS
ec2 *ec2.EC2
ssm *ssm.SSM
ecs *ecs.Client
ec2 *ec2.Client
ssm *ssm.Client
}
// DefaultTemplateRule The default template for the default rule.
@ -100,56 +101,40 @@ func (p *Provider) Init() error {
return nil
}
func (p *Provider) createClient(logger log.Logger) (*awsClient, error) {
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
func (p *Provider) createClient(ctx context.Context, logger log.Logger) (*awsClient, error) {
optFns := []func(*config.LoadOptions) error{
config.WithLogger(logging.LoggerFunc(func(_ logging.Classification, format string, args ...interface{}) {
logger.Debugf(format, args...)
})),
}
if p.Region != "" {
optFns = append(optFns, config.WithRegion(p.Region))
} else {
logger.Infoln("No region provided, will retrieve region from the EC2 Metadata service")
optFns = append(optFns, config.WithEC2IMDSRegion())
}
if p.AccessKeyID != "" && p.SecretAccessKey != "" {
// From https://docs.aws.amazon.com/sdk-for-go/v2/developer-guide/configure-gosdk.html#specify-credentials-programmatically:
// "If you explicitly provide credentials, as in this example, the SDK uses only those credentials."
// this makes sure that user-defined credentials always have the highest priority
staticCreds := aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(p.AccessKeyID, p.SecretAccessKey, ""))
optFns = append(optFns, config.WithCredentialsProvider(staticCreds))
// If the access key and secret access key are not provided, config.LoadDefaultConfig
// will look for the credentials in the default credential chain.
// See https://docs.aws.amazon.com/sdk-for-go/v2/developer-guide/configure-gosdk.html#specifying-credentials.
}
cfg, err := config.LoadDefaultConfig(ctx, optFns...)
if err != nil {
return nil, err
}
ec2meta := ec2metadata.New(sess)
if p.Region == "" && ec2meta.Available() {
logger.Infoln("No region provided, querying instance metadata endpoint...")
identity, err := ec2meta.GetInstanceIdentityDocument()
if err != nil {
return nil, err
}
p.Region = identity.Region
}
cfg := aws.NewConfig().
WithCredentials(credentials.NewChainCredentials([]credentials.Provider{
&credentials.StaticProvider{
Value: credentials.Value{
AccessKeyID: p.AccessKeyID,
SecretAccessKey: p.SecretAccessKey,
},
},
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{},
defaults.RemoteCredProvider(*(defaults.Config()), defaults.Handlers()),
stscreds.NewWebIdentityRoleProviderWithOptions(
sts.New(sess),
os.Getenv("AWS_ROLE_ARN"),
"",
stscreds.FetchTokenPath(os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE")),
),
}))
// Set the region if it is defined by the user or resolved from the EC2 metadata.
if p.Region != "" {
cfg.Region = &p.Region
}
cfg.WithLogger(aws.LoggerFunc(func(args ...interface{}) {
logger.Debug(args...)
}))
return &awsClient{
ecs.New(sess, cfg),
ec2.New(sess, cfg),
ssm.New(sess, cfg),
ecs.NewFromConfig(cfg),
ec2.NewFromConfig(cfg),
ssm.NewFromConfig(cfg),
}, nil
}
@ -160,7 +145,7 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe.
logger := log.FromContext(ctxLog)
operation := func() error {
awsClient, err := p.createClient(logger)
awsClient, err := p.createClient(ctxLog, logger)
if err != nil {
return fmt.Errorf("unable to create AWS client: %w", err)
}
@ -218,28 +203,19 @@ func (p *Provider) loadConfiguration(ctx context.Context, client *awsClient, con
func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsInstance, error) {
logger := log.FromContext(ctx)
var clustersArn []*string
var clusters []string
if p.AutoDiscoverClusters {
input := &ecs.ListClustersInput{}
for {
result, err := client.ecs.ListClusters(input)
paginator := ecs.NewListClustersPaginator(client.ecs, input)
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, err
}
if result != nil {
clustersArn = append(clustersArn, result.ClusterArns...)
input.NextToken = result.NextToken
if result.NextToken == nil {
break
}
} else {
break
}
}
for _, cArn := range clustersArn {
clusters = append(clusters, *cArn)
clusters = append(clusters, page.ClusterArns...)
}
} else {
clusters = p.Clusters
@ -251,13 +227,19 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
for _, c := range clusters {
input := &ecs.ListTasksInput{
Cluster: &c,
DesiredStatus: aws.String(ecs.DesiredStatusRunning),
DesiredStatus: ecstypes.DesiredStatusRunning,
}
tasks := make(map[string]*ecs.Task)
err := client.ecs.ListTasksPagesWithContext(ctx, input, func(page *ecs.ListTasksOutput, lastPage bool) bool {
tasks := make(map[string]ecstypes.Task)
paginator := ecs.NewListTasksPaginator(client.ecs, input)
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("listing tasks: %w", err)
}
if len(page.TaskArns) > 0 {
resp, err := client.ecs.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
resp, err := client.ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{
Tasks: page.TaskArns,
Cluster: &c,
})
@ -265,16 +247,12 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
logger.Errorf("Unable to describe tasks for %v", page.TaskArns)
} else {
for _, t := range resp.Tasks {
if aws.StringValue(t.LastStatus) == ecs.DesiredStatusRunning {
tasks[aws.StringValue(t.TaskArn)] = t
if aws.ToString(t.LastStatus) == string(ecstypes.DesiredStatusRunning) {
tasks[aws.ToString(t.TaskArn)] = t
}
}
}
}
return !lastPage
})
if err != nil {
return nil, fmt.Errorf("listing tasks: %w", err)
}
// Skip to the next cluster if there are no tasks found on
@ -288,7 +266,7 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
return nil, err
}
miInstances := make(map[string]*ssm.InstanceInformation)
miInstances := make(map[string]ssmtypes.InstanceInformation)
if p.ECSAnywhere {
// Try looking up for instances on ECS Anywhere
miInstances, err = p.lookupMiInstances(ctx, client, &c, tasks)
@ -303,74 +281,67 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
}
for key, task := range tasks {
containerInstance := ec2Instances[aws.StringValue(task.ContainerInstanceArn)]
containerInstance, hasContainerInstance := ec2Instances[aws.ToString(task.ContainerInstanceArn)]
taskDef := taskDefinitions[key]
for _, container := range task.Containers {
var containerDefinition *ecs.ContainerDefinition
var containerDefinition *ecstypes.ContainerDefinition
for _, def := range taskDef.ContainerDefinitions {
if aws.StringValue(container.Name) == aws.StringValue(def.Name) {
containerDefinition = def
if aws.ToString(container.Name) == aws.ToString(def.Name) {
containerDefinition = &def
break
}
}
if containerDefinition == nil {
logger.Debugf("Unable to find container definition for %s", aws.StringValue(container.Name))
logger.Debugf("Unable to find container definition for %s", aws.ToString(container.Name))
continue
}
var mach *machine
if aws.StringValue(taskDef.NetworkMode) == "awsvpc" && len(task.Attachments) != 0 {
if taskDef.NetworkMode == ecstypes.NetworkModeAwsvpc && len(task.Attachments) != 0 {
if len(container.NetworkInterfaces) == 0 {
logger.Errorf("Skip container %s: no network interfaces", aws.StringValue(container.Name))
logger.Errorf("Skip container %s: no network interfaces", aws.ToString(container.Name))
continue
}
var ports []portMapping
for _, mapping := range containerDefinition.PortMappings {
if mapping != nil {
protocol := "TCP"
if aws.StringValue(mapping.Protocol) == "udp" {
protocol = "UDP"
}
ports = append(ports, portMapping{
hostPort: aws.Int64Value(mapping.HostPort),
containerPort: aws.Int64Value(mapping.ContainerPort),
protocol: protocol,
})
}
ports = append(ports, portMapping{
hostPort: aws.ToInt32(mapping.HostPort),
containerPort: aws.ToInt32(mapping.ContainerPort),
protocol: mapping.Protocol,
})
}
mach = &machine{
privateIP: aws.StringValue(container.NetworkInterfaces[0].PrivateIpv4Address),
privateIP: aws.ToString(container.NetworkInterfaces[0].PrivateIpv4Address),
ports: ports,
state: aws.StringValue(task.LastStatus),
healthStatus: aws.StringValue(task.HealthStatus),
state: ec2types.InstanceStateName(strings.ToLower(aws.ToString(task.LastStatus))),
healthStatus: task.HealthStatus,
}
} else {
miContainerInstance := miInstances[aws.StringValue(task.ContainerInstanceArn)]
if containerInstance == nil && miContainerInstance == nil {
logger.Errorf("Unable to find container instance information for %s", aws.StringValue(container.Name))
miContainerInstance, hasMiContainerInstance := miInstances[aws.ToString(task.ContainerInstanceArn)]
if !hasContainerInstance && !hasMiContainerInstance {
logger.Errorf("Unable to find container instance information for %s", aws.ToString(container.Name))
continue
}
var ports []portMapping
for _, mapping := range container.NetworkBindings {
if mapping != nil {
ports = append(ports, portMapping{
hostPort: aws.Int64Value(mapping.HostPort),
containerPort: aws.Int64Value(mapping.ContainerPort),
})
}
ports = append(ports, portMapping{
hostPort: aws.ToInt32(mapping.HostPort),
containerPort: aws.ToInt32(mapping.ContainerPort),
protocol: mapping.Protocol,
})
}
var privateIPAddress, stateName string
if containerInstance != nil {
privateIPAddress = aws.StringValue(containerInstance.PrivateIpAddress)
stateName = aws.StringValue(containerInstance.State.Name)
} else if miContainerInstance != nil {
privateIPAddress = aws.StringValue(miContainerInstance.IPAddress)
stateName = aws.StringValue(task.LastStatus)
var privateIPAddress string
var stateName ec2types.InstanceStateName
if hasContainerInstance {
privateIPAddress = aws.ToString(containerInstance.PrivateIpAddress)
stateName = containerInstance.State.Name
} else if hasMiContainerInstance {
privateIPAddress = aws.ToString(miContainerInstance.IPAddress)
stateName = ec2types.InstanceStateName(strings.ToLower(aws.ToString(task.LastStatus)))
}
mach = &machine{
@ -381,11 +352,11 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
}
instance := ecsInstance{
Name: fmt.Sprintf("%s-%s", strings.Replace(aws.StringValue(task.Group), ":", "-", 1), *container.Name),
Name: fmt.Sprintf("%s-%s", strings.Replace(aws.ToString(task.Group), ":", "-", 1), aws.ToString(container.Name)),
ID: key[len(key)-12:],
containerDefinition: containerDefinition,
machine: mach,
Labels: aws.StringValueMap(containerDefinition.DockerLabels),
Labels: containerDefinition.DockerLabels,
}
extraConf, err := p.getConfiguration(instance)
@ -403,21 +374,21 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
return instances, nil
}
func (p *Provider) lookupMiInstances(ctx context.Context, client *awsClient, clusterName *string, ecsDatas map[string]*ecs.Task) (map[string]*ssm.InstanceInformation, error) {
func (p *Provider) lookupMiInstances(ctx context.Context, client *awsClient, clusterName *string, ecsDatas map[string]ecstypes.Task) (map[string]ssmtypes.InstanceInformation, error) {
instanceIDs := make(map[string]string)
miInstances := make(map[string]*ssm.InstanceInformation)
miInstances := make(map[string]ssmtypes.InstanceInformation)
var containerInstancesArns []*string
var instanceArns []*string
var containerInstancesArns []string
var instanceArns []string
for _, task := range ecsDatas {
if task.ContainerInstanceArn != nil {
containerInstancesArns = append(containerInstancesArns, task.ContainerInstanceArn)
containerInstancesArns = append(containerInstancesArns, *task.ContainerInstanceArn)
}
}
for _, arns := range p.chunkIDs(containerInstancesArns) {
resp, err := client.ecs.DescribeContainerInstancesWithContext(ctx, &ecs.DescribeContainerInstancesInput{
for arns := range chunkIDs(containerInstancesArns) {
resp, err := client.ecs.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{
ContainerInstances: arns,
Cluster: clusterName,
})
@ -426,23 +397,21 @@ func (p *Provider) lookupMiInstances(ctx context.Context, client *awsClient, clu
}
for _, container := range resp.ContainerInstances {
instanceIDs[aws.StringValue(container.Ec2InstanceId)] = aws.StringValue(container.ContainerInstanceArn)
instanceIDs[aws.ToString(container.Ec2InstanceId)] = aws.ToString(container.ContainerInstanceArn)
// Disallow EC2 Instance IDs
// This prevents considering EC2 instances in ECS
// and getting InvalidInstanceID.Malformed error when calling the describe-instances endpoint.
if !strings.HasPrefix(aws.StringValue(container.Ec2InstanceId), "mi-") {
continue
if strings.HasPrefix(aws.ToString(container.Ec2InstanceId), "mi-") {
instanceArns = append(instanceArns, *container.Ec2InstanceId)
}
instanceArns = append(instanceArns, container.Ec2InstanceId)
}
}
if len(instanceArns) > 0 {
for _, ids := range p.chunkIDs(instanceArns) {
for ids := range chunkIDs(instanceArns) {
input := &ssm.DescribeInstanceInformationInput{
Filters: []*ssm.InstanceInformationStringFilter{
Filters: []ssmtypes.InstanceInformationStringFilter{
{
Key: aws.String("InstanceIds"),
Values: ids,
@ -450,18 +419,18 @@ func (p *Provider) lookupMiInstances(ctx context.Context, client *awsClient, clu
},
}
err := client.ssm.DescribeInstanceInformationPagesWithContext(ctx, input, func(page *ssm.DescribeInstanceInformationOutput, lastPage bool) bool {
if len(page.InstanceInformationList) > 0 {
for _, i := range page.InstanceInformationList {
if i.InstanceId != nil {
miInstances[instanceIDs[aws.StringValue(i.InstanceId)]] = i
}
paginator := ssm.NewDescribeInstanceInformationPaginator(client.ssm, input)
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("describing instances: %w", err)
}
for _, i := range page.InstanceInformationList {
if i.InstanceId != nil {
miInstances[instanceIDs[aws.ToString(i.InstanceId)]] = i
}
}
return !lastPage
})
if err != nil {
return nil, fmt.Errorf("describing instances: %w", err)
}
}
}
@ -469,21 +438,21 @@ func (p *Provider) lookupMiInstances(ctx context.Context, client *awsClient, clu
return miInstances, nil
}
func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, clusterName *string, ecsDatas map[string]*ecs.Task) (map[string]*ec2.Instance, error) {
func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, clusterName *string, ecsDatas map[string]ecstypes.Task) (map[string]ec2types.Instance, error) {
instanceIDs := make(map[string]string)
ec2Instances := make(map[string]*ec2.Instance)
ec2Instances := make(map[string]ec2types.Instance)
var containerInstancesArns []*string
var instanceArns []*string
var containerInstancesArns []string
var instanceArns []string
for _, task := range ecsDatas {
if task.ContainerInstanceArn != nil {
containerInstancesArns = append(containerInstancesArns, task.ContainerInstanceArn)
containerInstancesArns = append(containerInstancesArns, *task.ContainerInstanceArn)
}
}
for _, arns := range p.chunkIDs(containerInstancesArns) {
resp, err := client.ecs.DescribeContainerInstancesWithContext(ctx, &ecs.DescribeContainerInstancesInput{
for arns := range chunkIDs(containerInstancesArns) {
resp, err := client.ecs.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{
ContainerInstances: arns,
Cluster: clusterName,
})
@ -492,38 +461,38 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
}
for _, container := range resp.ContainerInstances {
instanceIDs[aws.StringValue(container.Ec2InstanceId)] = aws.StringValue(container.ContainerInstanceArn)
instanceIDs[aws.ToString(container.Ec2InstanceId)] = aws.ToString(container.ContainerInstanceArn)
// Disallow Instance IDs of the form mi-*
// This prevents considering external instances in ECS Anywhere setups
// and getting InvalidInstanceID.Malformed error when calling the describe-instances endpoint.
if strings.HasPrefix(aws.StringValue(container.Ec2InstanceId), "mi-") {
if strings.HasPrefix(aws.ToString(container.Ec2InstanceId), "mi-") {
continue
}
instanceArns = append(instanceArns, container.Ec2InstanceId)
if container.Ec2InstanceId != nil {
instanceArns = append(instanceArns, *container.Ec2InstanceId)
}
}
}
if len(instanceArns) > 0 {
for _, ids := range p.chunkIDs(instanceArns) {
for ids := range chunkIDs(instanceArns) {
input := &ec2.DescribeInstancesInput{
InstanceIds: ids,
}
err := client.ec2.DescribeInstancesPagesWithContext(ctx, input, func(page *ec2.DescribeInstancesOutput, lastPage bool) bool {
if len(page.Reservations) > 0 {
for _, r := range page.Reservations {
for _, i := range r.Instances {
if i.InstanceId != nil {
ec2Instances[instanceIDs[aws.StringValue(i.InstanceId)]] = i
}
paginator := ec2.NewDescribeInstancesPaginator(client.ec2, input)
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("describing instances: %w", err)
}
for _, r := range page.Reservations {
for _, i := range r.Instances {
if i.InstanceId != nil {
ec2Instances[instanceIDs[aws.ToString(i.InstanceId)]] = i
}
}
}
return !lastPage
})
if err != nil {
return nil, fmt.Errorf("describing instances: %w", err)
}
}
}
@ -531,16 +500,16 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
return ec2Instances, nil
}
func (p *Provider) lookupTaskDefinitions(ctx context.Context, client *awsClient, taskDefArns map[string]*ecs.Task) (map[string]*ecs.TaskDefinition, error) {
func (p *Provider) lookupTaskDefinitions(ctx context.Context, client *awsClient, taskDefArns map[string]ecstypes.Task) (map[string]*ecstypes.TaskDefinition, error) {
logger := log.FromContext(ctx)
taskDef := make(map[string]*ecs.TaskDefinition)
taskDef := make(map[string]*ecstypes.TaskDefinition)
for arn, task := range taskDefArns {
if definition, ok := existingTaskDefCache.Get(arn); ok {
taskDef[arn] = definition.(*ecs.TaskDefinition)
taskDef[arn] = definition.(*ecstypes.TaskDefinition)
logger.Debugf("Found cached task definition for %s. Skipping the call", arn)
} else {
resp, err := client.ecs.DescribeTaskDefinitionWithContext(ctx, &ecs.DescribeTaskDefinitionInput{
resp, err := client.ecs.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{
TaskDefinition: task.TaskDefinitionArn,
})
if err != nil {
@ -556,16 +525,6 @@ func (p *Provider) lookupTaskDefinitions(ctx context.Context, client *awsClient,
// chunkIDs ECS expects no more than 100 parameters be passed to a API call;
// thus, pack each string into an array capped at 100 elements.
func (p *Provider) chunkIDs(ids []*string) [][]*string {
var chunked [][]*string
for i := 0; i < len(ids); i += 100 {
var sliceEnd int
if i+100 < len(ids) {
sliceEnd = i + 100
} else {
sliceEnd = len(ids)
}
chunked = append(chunked, ids[i:sliceEnd])
}
return chunked
func chunkIDs(ids []string) iter.Seq[[]string] {
return slices.Chunk(ids, 100)
}

View file

@ -3,13 +3,10 @@ package ecs
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
)
func TestChunkIDs(t *testing.T) {
provider := &Provider{}
testCases := []struct {
desc string
count int
@ -71,13 +68,13 @@ func TestChunkIDs(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var IDs []*string
var IDs []string
for range test.count {
IDs = append(IDs, aws.String("a"))
IDs = append(IDs, "a")
}
var outCount []int
for _, el := range provider.chunkIDs(IDs) {
for el := range chunkIDs(IDs) {
outCount = append(outCount, len(el))
}