Skip to content

Commit 6f0b4cc

Browse files
authored
Merge pull request #2004 from kishorj/vpc-uts
Add unit tests for VPC resolver
2 parents 4cb7510 + 87cefb1 commit 6f0b4cc

File tree

2 files changed

+126
-2
lines changed

2 files changed

+126
-2
lines changed

pkg/networking/vpc_resolver.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ type defaultVPCResolver struct {
3434
}
3535

3636
func (r *defaultVPCResolver) ResolveCIDRs(ctx context.Context) ([]string, error) {
37-
vpcs, err := r.ec2Client.DescribeVpcs(&ec2.DescribeVpcsInput{
37+
vpcs, err := r.ec2Client.DescribeVpcsWithContext(ctx, &ec2.DescribeVpcsInput{
3838
VpcIds: []*string{awssdk.String(r.vpcID)},
3939
})
4040
if err != nil {
@@ -51,5 +51,4 @@ func (r *defaultVPCResolver) ResolveCIDRs(ctx context.Context) ([]string, error)
5151
}
5252

5353
return vpcCIDRs, nil
54-
5554
}

pkg/networking/vpc_resolver_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package networking
2+
3+
import (
4+
"context"
5+
awssdk "github.com/aws/aws-sdk-go/aws"
6+
ec2sdk "github.com/aws/aws-sdk-go/service/ec2"
7+
"github.com/golang/mock/gomock"
8+
"github.com/pkg/errors"
9+
"github.com/stretchr/testify/assert"
10+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
11+
"testing"
12+
)
13+
14+
func Test_defaultVPCResolver_ResolveCIDRs(t *testing.T) {
15+
type descriveVpcsCall struct {
16+
input *ec2sdk.DescribeVpcsInput
17+
output *ec2sdk.DescribeVpcsOutput
18+
err error
19+
}
20+
tests := []struct {
21+
name string
22+
vpcID string
23+
want []string
24+
wantErr error
25+
descriveVpcsCall descriveVpcsCall
26+
}{
27+
{
28+
name: "vpc cidr discovery",
29+
vpcID: "vpc-01xxx2",
30+
want: []string{"192.160.0.0/16"},
31+
wantErr: nil,
32+
descriveVpcsCall: descriveVpcsCall{
33+
input: &ec2sdk.DescribeVpcsInput{
34+
VpcIds: []*string{awssdk.String("vpc-01xxx2")},
35+
},
36+
output: &ec2sdk.DescribeVpcsOutput{
37+
Vpcs: []*ec2sdk.Vpc{
38+
{
39+
CidrBlockAssociationSet: []*ec2sdk.VpcCidrBlockAssociation{
40+
{
41+
CidrBlock: awssdk.String("192.160.0.0/16"),
42+
},
43+
},
44+
},
45+
},
46+
},
47+
},
48+
},
49+
{
50+
name: "unable to describe VPC",
51+
vpcID: "vpc-01xxx3",
52+
wantErr: errors.Wrapf(errors.New("aws error"), "unable to describe VPC"),
53+
descriveVpcsCall: descriveVpcsCall{
54+
input: &ec2sdk.DescribeVpcsInput{
55+
VpcIds: []*string{awssdk.String("vpc-01xxx3")},
56+
},
57+
err: errors.New("aws error"),
58+
},
59+
},
60+
{
61+
name: "unable to find matching VPC",
62+
vpcID: "vpc-01xxx4",
63+
wantErr: errors.New("unable to find matching VPC \"vpc-01xxx4\""),
64+
descriveVpcsCall: descriveVpcsCall{
65+
input: &ec2sdk.DescribeVpcsInput{
66+
VpcIds: []*string{awssdk.String("vpc-01xxx4")},
67+
},
68+
output: &ec2sdk.DescribeVpcsOutput{},
69+
},
70+
},
71+
{
72+
name: "multiple CIDRs",
73+
vpcID: "vpc-01xxx2",
74+
want: []string{"192.160.0.0/16", "100.64.0.0/16", "100.65.0.0/16", "100.66.0.0/24"},
75+
wantErr: nil,
76+
descriveVpcsCall: descriveVpcsCall{
77+
input: &ec2sdk.DescribeVpcsInput{
78+
VpcIds: []*string{awssdk.String("vpc-01xxx2")},
79+
},
80+
output: &ec2sdk.DescribeVpcsOutput{
81+
Vpcs: []*ec2sdk.Vpc{
82+
{
83+
CidrBlockAssociationSet: []*ec2sdk.VpcCidrBlockAssociation{
84+
{
85+
CidrBlock: awssdk.String("192.160.0.0/16"),
86+
},
87+
{
88+
CidrBlock: awssdk.String("100.64.0.0/16"),
89+
},
90+
{
91+
CidrBlock: awssdk.String("100.65.0.0/16"),
92+
},
93+
{
94+
CidrBlock: awssdk.String("100.66.0.0/24"),
95+
},
96+
},
97+
},
98+
},
99+
},
100+
},
101+
},
102+
}
103+
for _, tt := range tests {
104+
t.Run(tt.name, func(t *testing.T) {
105+
ctrl := gomock.NewController(t)
106+
defer ctrl.Finish()
107+
108+
ec2Client := services.NewMockEC2(ctrl)
109+
ec2Client.EXPECT().DescribeVpcsWithContext(gomock.Any(), tt.descriveVpcsCall.input).Return(
110+
tt.descriveVpcsCall.output, tt.descriveVpcsCall.err)
111+
vpcResolver := &defaultVPCResolver{
112+
ec2Client: ec2Client,
113+
vpcID: tt.vpcID,
114+
}
115+
got, err := vpcResolver.ResolveCIDRs(context.Background())
116+
if tt.wantErr != nil {
117+
assert.EqualError(t, err, tt.wantErr.Error())
118+
} else {
119+
assert.NoError(t, err)
120+
assert.Equal(t, tt.want, got)
121+
}
122+
123+
})
124+
}
125+
}

0 commit comments

Comments
 (0)