Skip to content

Commit 25b9705

Browse files
committed
Rewrite GetClusterSubnets() using EC2 specific API
1 parent b42cd80 commit 25b9705

File tree

6 files changed

+164
-234
lines changed

6 files changed

+164
-234
lines changed

internal/alb/lb/loadbalancer.go

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ func (controller *defaultController) resolveSubnets(ctx context.Context, scheme
380380
}
381381

382382
func (controller *defaultController) clusterSubnets(ctx context.Context, scheme string) ([]string, error) {
383-
var subnetIds []string
384383
var useableSubnets []*ec2.Subnet
385384
var out []string
386385
var key string
@@ -393,27 +392,12 @@ func (controller *defaultController) clusterSubnets(ctx context.Context, scheme
393392
return nil, fmt.Errorf("invalid scheme [%s]", scheme)
394393
}
395394

396-
clusterSubnets, err := controller.cloud.GetClusterSubnets()
395+
clusterSubnets, err := controller.cloud.GetClusterSubnets(key)
397396
if err != nil {
398-
return nil, fmt.Errorf("failed to get AWS tags. Error: %s", err.Error())
397+
return nil, fmt.Errorf("unable to fetch subnets. Error: %s", err.Error())
399398
}
400399

401-
for arn, subnetTags := range clusterSubnets {
402-
for _, tag := range subnetTags {
403-
if aws.StringValue(tag.Key) == key {
404-
p := strings.Split(arn, "/")
405-
subnetID := p[len(p)-1]
406-
subnetIds = append(subnetIds, subnetID)
407-
}
408-
}
409-
}
410-
411-
o, err := controller.cloud.GetSubnetsByNameOrID(ctx, subnetIds)
412-
if err != nil {
413-
return nil, fmt.Errorf("unable to fetch subnets due to %v", err)
414-
}
415-
416-
for _, subnet := range o {
400+
for _, subnet := range clusterSubnets {
417401
if subnetIsUsable(subnet, useableSubnets) {
418402
useableSubnets = append(useableSubnets, subnet)
419403
out = append(out, aws.StringValue(subnet.SubnetId))

internal/aws/ec2.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ type EC2API interface {
4444
// GetSecurityGroupsByName retrieves securityGroups by securityGroupName(SecurityGroup names within vpc are unique)
4545
GetSecurityGroupsByName(context.Context, []string) ([]*ec2.SecurityGroup, error)
4646

47+
// GetClusterSubnets retrieves the subnets associated with the cluster, by matching tags
48+
GetClusterSubnets(string) ([]*ec2.Subnet, error)
49+
4750
// DeleteSecurityGroupByID delete securityGroup by securityGroupID
4851
DeleteSecurityGroupByID(context.Context, string) error
4952

@@ -164,6 +167,26 @@ func (c *Cloud) GetSubnetsByNameOrID(ctx context.Context, nameOrIDs []string) (s
164167
return
165168
}
166169

170+
func (c *Cloud) GetClusterSubnets(tagSubnetType string) ([]*ec2.Subnet, error) {
171+
in := &ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{
172+
{
173+
Name: aws.String("tag:kubernetes.io/cluster/" + c.clusterName),
174+
Values: aws.StringSlice([]string{"owned", "shared"}),
175+
},
176+
{
177+
Name: aws.String("tag:" + tagSubnetType),
178+
Values: aws.StringSlice([]string{"", "1"}),
179+
},
180+
}}
181+
182+
result, err := c.describeSubnetsHelper(in)
183+
if err != nil {
184+
return nil, err
185+
}
186+
187+
return result, nil
188+
}
189+
167190
func (c *Cloud) GetSecurityGroupsByName(ctx context.Context, names []string) (groups []*ec2.SecurityGroup, err error) {
168191
in := &ec2.DescribeSecurityGroupsInput{Filters: []*ec2.Filter{
169192
{
@@ -272,6 +295,16 @@ func (c *Cloud) describeSecurityGroupsHelper(params *ec2.DescribeSecurityGroupsI
272295
return results, err
273296
}
274297

298+
// describeSubnetsHelper is a helper to handle pagination for DescribeSubnets API call
299+
func (c *Cloud) describeSubnetsHelper(params *ec2.DescribeSubnetsInput) (result []*ec2.Subnet, err error) {
300+
err = c.ec2.DescribeSubnetsPages(params, func(output *ec2.DescribeSubnetsOutput, _ bool) bool {
301+
result = append(result, output.Subnets...)
302+
return true
303+
})
304+
305+
return result, err
306+
}
307+
275308
func (c *Cloud) describeInstancesHelper(params *ec2.DescribeInstancesInput) (result []*ec2.Reservation, err error) {
276309
err = c.ec2.DescribeInstancesPages(params, func(output *ec2.DescribeInstancesOutput, _ bool) bool {
277310
result = append(result, output.Reservations...)

internal/aws/ec2_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"errors"
66
"testing"
77

8+
"github.com/stretchr/testify/mock"
9+
810
"github.com/aws/aws-sdk-go/aws"
911
"github.com/aws/aws-sdk-go/service/ec2"
1012
"github.com/kubernetes-sigs/aws-alb-ingress-controller/mocks"
@@ -147,3 +149,122 @@ func TestCloud_RevokeSecurityGroupIngressWithContext(t *testing.T) {
147149
svc.AssertExpectations(t)
148150
})
149151
}
152+
153+
func TestCloud_GetClusterSubnets(t *testing.T) {
154+
clusterName := "clusterName"
155+
internalSubnet1 := &ec2.Subnet{
156+
SubnetId: aws.String("arn:aws:ec2:region:account-id:subnet/subnet-id1"),
157+
Tags: []*ec2.Tag{
158+
{
159+
Key: aws.String("kubernetes.io/cluster/" + clusterName),
160+
Value: aws.String("owned"),
161+
},
162+
{
163+
Key: aws.String("kubernetes.io/role/internal-elb"),
164+
Value: aws.String("1"),
165+
},
166+
},
167+
}
168+
internalSubnet2 := &ec2.Subnet{
169+
SubnetId: aws.String("arn:aws:ec2:region:account-id:subnet/subnet-id2"),
170+
Tags: []*ec2.Tag{
171+
{
172+
Key: aws.String("kubernetes.io/cluster/" + clusterName),
173+
Value: aws.String("owned"),
174+
},
175+
{
176+
Key: aws.String("kubernetes.io/role/internal-elb"),
177+
Value: aws.String(""),
178+
},
179+
},
180+
}
181+
publicSubnet := &ec2.Subnet{
182+
SubnetId: aws.String("arn:aws:ec2:region:account-id:subnet/subnet-id3"),
183+
Tags: []*ec2.Tag{
184+
{
185+
Key: aws.String("kubernetes.io/cluster/" + clusterName),
186+
Value: aws.String("shared"),
187+
},
188+
{
189+
Key: aws.String("kubernetes.io/role/elb"),
190+
Value: aws.String("1"),
191+
},
192+
},
193+
}
194+
195+
for _, tc := range []struct {
196+
Name string
197+
DescribeSubnetsOutput *ec2.DescribeSubnetsOutput
198+
DescribeSubnetsError error
199+
TagSubnetType string
200+
ExpectedResult []*ec2.Subnet
201+
ExpectedError error
202+
}{
203+
{
204+
Name: "No subnets returned",
205+
TagSubnetType: TagNameSubnetInternalELB,
206+
DescribeSubnetsOutput: &ec2.DescribeSubnetsOutput{
207+
NextToken: nil,
208+
Subnets: []*ec2.Subnet{},
209+
},
210+
},
211+
{
212+
Name: "Two internal subnets returned",
213+
TagSubnetType: TagNameSubnetInternalELB,
214+
DescribeSubnetsOutput: &ec2.DescribeSubnetsOutput{
215+
NextToken: nil,
216+
Subnets: []*ec2.Subnet{internalSubnet1, internalSubnet2},
217+
},
218+
ExpectedResult: []*ec2.Subnet{internalSubnet1, internalSubnet2},
219+
},
220+
{
221+
Name: "One public subnet returned",
222+
TagSubnetType: TagNameSubnetPublicELB,
223+
DescribeSubnetsOutput: &ec2.DescribeSubnetsOutput{
224+
NextToken: nil,
225+
Subnets: []*ec2.Subnet{publicSubnet},
226+
},
227+
ExpectedResult: []*ec2.Subnet{publicSubnet},
228+
},
229+
{
230+
Name: "Error from API call",
231+
TagSubnetType: TagNameSubnetPublicELB,
232+
DescribeSubnetsOutput: &ec2.DescribeSubnetsOutput{
233+
NextToken: nil,
234+
Subnets: []*ec2.Subnet{},
235+
},
236+
DescribeSubnetsError: errors.New("Some API error"),
237+
ExpectedError: errors.New("Some API error"),
238+
},
239+
} {
240+
t.Run(tc.Name, func(t *testing.T) {
241+
svc := &mocks.EC2API{}
242+
243+
svc.On("DescribeSubnetsPages",
244+
&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{
245+
{
246+
Name: aws.String("tag:kubernetes.io/cluster/" + clusterName),
247+
Values: aws.StringSlice([]string{"owned", "shared"}),
248+
},
249+
{
250+
Name: aws.String("tag:" + tc.TagSubnetType),
251+
Values: aws.StringSlice([]string{"", "1"}),
252+
}},
253+
},
254+
mock.AnythingOfType("func(*ec2.DescribeSubnetsOutput, bool) bool"),
255+
).Return(tc.DescribeSubnetsError).Run(func(args mock.Arguments) {
256+
arg := args.Get(1).(func(*ec2.DescribeSubnetsOutput, bool) bool)
257+
arg(tc.DescribeSubnetsOutput, false)
258+
})
259+
260+
cloud := &Cloud{
261+
clusterName: clusterName,
262+
ec2: svc,
263+
}
264+
subnets, err := cloud.GetClusterSubnets(tc.TagSubnetType)
265+
assert.Equal(t, tc.ExpectedResult, subnets)
266+
assert.Equal(t, tc.ExpectedError, err)
267+
svc.AssertExpectations(t)
268+
})
269+
}
270+
}

internal/aws/rgt.go

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,9 @@ package aws
22

33
import (
44
"context"
5-
"strings"
6-
7-
"github.com/aws/aws-sdk-go/service/ec2"
85

96
"github.com/aws/aws-sdk-go/aws"
10-
117
"github.com/aws/aws-sdk-go/service/resourcegroupstaggingapi"
12-
13-
util "github.com/kubernetes-sigs/aws-alb-ingress-controller/pkg/util/types"
148
)
159

1610
const (
@@ -20,8 +14,6 @@ const (
2014
)
2115

2216
type ResourceGroupsTaggingAPIAPI interface {
23-
GetClusterSubnets() (map[string]util.EC2Tags, error)
24-
2517
// GetResourcesByFilters fetches resources ARNs by tagFilters and 0 or more resourceTypesFilters
2618
GetResourcesByFilters(tagFilters map[string][]string, resourceTypeFilters ...string) ([]string, error)
2719

@@ -36,76 +28,6 @@ func (c *Cloud) UntagResourcesWithContext(ctx context.Context, i *resourcegroups
3628
return c.rgt.UntagResourcesWithContext(ctx, i)
3729
}
3830

39-
// GetClusterSubnets looks up all subnets in AWS that are tagged for the cluster.
40-
func (c *Cloud) GetClusterSubnets() (map[string]util.EC2Tags, error) {
41-
subnets := make(map[string]util.EC2Tags)
42-
43-
paramSets := []*resourcegroupstaggingapi.GetResourcesInput{
44-
{
45-
ResourcesPerPage: aws.Int64(50),
46-
ResourceTypeFilters: []*string{
47-
aws.String("ec2"),
48-
},
49-
TagFilters: []*resourcegroupstaggingapi.TagFilter{
50-
{
51-
Key: aws.String("kubernetes.io/role/internal-elb"),
52-
Values: []*string{aws.String(""), aws.String("1")},
53-
},
54-
{
55-
Key: aws.String("kubernetes.io/cluster/" + c.clusterName),
56-
Values: []*string{aws.String("owned"), aws.String("shared")},
57-
},
58-
},
59-
},
60-
{
61-
ResourcesPerPage: aws.Int64(50),
62-
ResourceTypeFilters: []*string{
63-
aws.String("ec2"),
64-
},
65-
TagFilters: []*resourcegroupstaggingapi.TagFilter{
66-
{
67-
Key: aws.String("kubernetes.io/role/elb"),
68-
Values: []*string{aws.String(""), aws.String("1")},
69-
},
70-
{
71-
Key: aws.String("kubernetes.io/cluster/" + c.clusterName),
72-
Values: []*string{aws.String("owned"), aws.String("shared")},
73-
},
74-
},
75-
},
76-
}
77-
78-
for _, paramSet := range paramSets {
79-
err := c.rgt.GetResourcesPages(paramSet, func(page *resourcegroupstaggingapi.GetResourcesOutput, lastPage bool) bool {
80-
if page == nil {
81-
return false
82-
}
83-
for _, rtm := range page.ResourceTagMappingList {
84-
switch {
85-
case strings.Contains(*rtm.ResourceARN, ":subnet/"):
86-
subnets[*rtm.ResourceARN] = rgtTagAsEC2Tag(rtm.Tags)
87-
}
88-
}
89-
return true
90-
})
91-
if err != nil {
92-
return nil, err
93-
}
94-
}
95-
96-
return subnets, nil
97-
}
98-
99-
func rgtTagAsEC2Tag(in []*resourcegroupstaggingapi.Tag) (tags util.EC2Tags) {
100-
for _, t := range in {
101-
tags = append(tags, &ec2.Tag{
102-
Key: t.Key,
103-
Value: t.Value,
104-
})
105-
}
106-
return tags
107-
}
108-
10931
func (c *Cloud) GetResourcesByFilters(tagFilters map[string][]string, resourceTypeFilters ...string) ([]string, error) {
11032
var awsTagFilters []*resourcegroupstaggingapi.TagFilter
11133
for k, v := range tagFilters {

0 commit comments

Comments
 (0)