Skip to content

Commit 862890a

Browse files
papigerskishorj
andauthored
add custom aws endpoint configuration (#2179)
Co-authored-by: Kishor Joshi <[email protected]>
1 parent b40c98b commit 862890a

File tree

9 files changed

+313
-4
lines changed

9 files changed

+313
-4
lines changed

docs/deploy/configurations.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ Currently, you can set only 1 namespace to watch in this flag. See [this Kuberne
7171
|aws-max-retries | int | 10 | Maximum retries for AWS APIs |
7272
|aws-region | string | [instance metadata](#instance-metadata) | AWS Region for the kubernetes cluster |
7373
|aws-vpc-id | string | [instance metadata](#instance-metadata) | AWS VPC ID for the Kubernetes cluster |
74+
|aws-api-endpoints | AWS API Endpoints Config | | AWS API endpoints mapping, format: serviceID1=URL1,serviceID2=URL2 |
7475
|cluster-name | string | | Kubernetes cluster name|
7576
|default-tags | stringMap | | AWS Tags that will be applied to all AWS resources managed by this controller. Specified Tags takes highest priority |
7677
|default-ssl-policy | string | ELBSecurityPolicy-2016-08 | Default SSL Policy that will be applied to all Ingresses or Services that do not have the SSL Policy annotation |

helm/aws-load-balancer-controller/templates/deployment.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ spec:
5959
{{- if .Values.vpcId }}
6060
- --aws-vpc-id={{ .Values.vpcId }}
6161
{{- end }}
62+
{{- if .Values.awsApiEndpoints }}
63+
- --aws-api-endpoints={{ .Values.awsApiEndpoints }}
64+
{{- end }}
6265
{{- if .Values.awsMaxRetries }}
6366
- --aws-max-retries={{ .Values.awsMaxRetries }}
6467
{{- end }}

helm/aws-load-balancer-controller/values.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ region:
9393
# The VPC ID for the Kubernetes cluster. Set this manually when your pods are unable to use the metadata service to determine this automatically
9494
vpcId:
9595

96+
# Custom AWS API Endpoints (serviceID1=URL1,serviceID2=URL2)
97+
awsApiEndpoints:
98+
9699
# Maximum retries for AWS APIs (default 10)
97100
awsMaxRetries:
98101

main.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"sigs.k8s.io/aws-load-balancer-controller/controllers/ingress"
3232
"sigs.k8s.io/aws-load-balancer-controller/controllers/service"
3333
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws"
34+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
3435
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
3536
"sigs.k8s.io/aws-load-balancer-controller/pkg/config"
3637
"sigs.k8s.io/aws-load-balancer-controller/pkg/inject"
@@ -170,7 +171,10 @@ func main() {
170171
func loadControllerConfig() (config.ControllerConfig, error) {
171172
defaultAWSThrottleCFG := throttle.NewDefaultServiceOperationsThrottleConfig()
172173
controllerCFG := config.ControllerConfig{
173-
AWSConfig: aws.CloudConfig{ThrottleConfig: defaultAWSThrottleCFG},
174+
AWSConfig: aws.CloudConfig{
175+
ThrottleConfig: defaultAWSThrottleCFG,
176+
AWSEndpointResolver: &endpoints.AWSEndpointResolver{},
177+
},
174178
}
175179

176180
fs := pflag.NewFlagSet("", pflag.ExitOnError)

pkg/aws/cloud.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ type Cloud interface {
4343

4444
// NewCloud constructs new Cloud implementation.
4545
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer) (Cloud, error) {
46-
metadataSess := session.Must(session.NewSession(aws.NewConfig()))
46+
metadataCFG := aws.NewConfig().WithEndpointResolver(cfg.AWSEndpointResolver)
47+
metadataSess := session.Must(session.NewSession(metadataCFG))
4748
metadata := services.NewEC2Metadata(metadataSess)
4849
if len(cfg.VpcID) == 0 {
4950
vpcId, err := metadata.VpcID()
@@ -68,8 +69,7 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer) (Cloud,
6869
}
6970
cfg.Region = region
7071
}
71-
72-
awsCFG := aws.NewConfig().WithRegion(cfg.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).WithMaxRetries(cfg.MaxRetries)
72+
awsCFG := aws.NewConfig().WithRegion(cfg.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).WithMaxRetries(cfg.MaxRetries).WithEndpointResolver(cfg.AWSEndpointResolver)
7373
sess := session.Must(session.NewSession(awsCFG))
7474
injectUserAgent(&sess.Handlers)
7575

pkg/aws/cloud_config.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package aws
22

33
import (
44
"github.com/spf13/pflag"
5+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
56
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
67
)
78

@@ -10,6 +11,7 @@ const (
1011
flagAWSAPIThrottle = "aws-api-throttle"
1112
flagAWSVpcID = "aws-vpc-id"
1213
flagAWSMaxRetries = "aws-max-retries"
14+
flagAWSAPIEndpoints = "aws-api-endpoints"
1315
defaultVpcID = ""
1416
defaultRegion = ""
1517
defaultAPIMaxRetries = 10
@@ -27,11 +29,15 @@ type CloudConfig struct {
2729

2830
// Max retries configuration for AWS APIs
2931
MaxRetries int
32+
33+
// AWS endpoint configuration
34+
AWSEndpointResolver *endpoints.AWSEndpointResolver
3035
}
3136

3237
func (cfg *CloudConfig) BindFlags(fs *pflag.FlagSet) {
3338
fs.StringVar(&cfg.Region, flagAWSRegion, defaultRegion, "AWS Region for the kubernetes cluster")
3439
fs.Var(cfg.ThrottleConfig, flagAWSAPIThrottle, "throttle settings for AWS APIs, format: serviceID1:operationRegex1=rate:burst,serviceID2:operationRegex2=rate:burst")
3540
fs.StringVar(&cfg.VpcID, flagAWSVpcID, defaultVpcID, "AWS VPC ID for the Kubernetes cluster")
3641
fs.IntVar(&cfg.MaxRetries, flagAWSMaxRetries, defaultAPIMaxRetries, "Maximum retries for AWS APIs")
42+
fs.Var(cfg.AWSEndpointResolver, flagAWSAPIEndpoints, "Custom AWS endpoint configuration, format: serviceID1=URL1,serviceID2=URL2")
3743
}

pkg/aws/endpoints/resolver.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package endpoints
2+
3+
import (
4+
"fmt"
5+
"net/url"
6+
"sort"
7+
"strings"
8+
9+
awsendpoints "github.com/aws/aws-sdk-go/aws/endpoints"
10+
"github.com/pkg/errors"
11+
"github.com/spf13/pflag"
12+
)
13+
14+
var _ pflag.Value = &AWSEndpointResolver{}
15+
16+
// AWSEndpointResolver is an AWS endpoints.Resolver that allows to customize AWS API endpoints.
17+
// It can be configured using the following format "${AWSServiceID}=${URL}"
18+
// e.g. "ec2=https://ec2.domain.com,elasticloadbalancing=https://elbv2.domain.com"
19+
type AWSEndpointResolver struct {
20+
configuration map[string]string
21+
}
22+
23+
func (c *AWSEndpointResolver) String() string {
24+
if c == nil {
25+
return ""
26+
}
27+
28+
var configs []string
29+
var serviceIDs []string
30+
for serviceID := range c.configuration {
31+
serviceIDs = append(serviceIDs, serviceID)
32+
}
33+
sort.Strings(serviceIDs)
34+
for _, serviceID := range serviceIDs {
35+
configs = append(configs, fmt.Sprintf("%s=%s", serviceID, c.configuration[serviceID]))
36+
}
37+
return strings.Join(configs, ",")
38+
}
39+
40+
func (c *AWSEndpointResolver) Set(val string) error {
41+
configurationOverride := make(map[string]string)
42+
43+
if val != "" {
44+
configPairs := strings.Split(val, ",")
45+
for _, pair := range configPairs {
46+
kv := strings.Split(pair, "=")
47+
if len(kv) != 2 {
48+
return errors.Errorf("%s must be formatted as serviceID=URL", pair)
49+
}
50+
serviceID := kv[0]
51+
urlStr := kv[1]
52+
url, err := url.Parse(urlStr)
53+
if err != nil {
54+
return errors.Errorf("%s must be a valid url", urlStr)
55+
}
56+
if !url.IsAbs() {
57+
return errors.Errorf("%s must be an absolute url", urlStr)
58+
}
59+
configurationOverride[serviceID] = url.String()
60+
}
61+
}
62+
63+
if c.configuration == nil {
64+
c.configuration = make(map[string]string)
65+
}
66+
for k, v := range configurationOverride {
67+
c.configuration[k] = v
68+
}
69+
return nil
70+
}
71+
72+
func (c *AWSEndpointResolver) Type() string {
73+
return "awsEndpointResolver"
74+
}
75+
76+
func (c *AWSEndpointResolver) EndpointFor(service, region string, opts ...func(*awsendpoints.Options)) (awsendpoints.ResolvedEndpoint, error) {
77+
customEndpoint := c.configuration[service]
78+
if len(customEndpoint) != 0 {
79+
return awsendpoints.ResolvedEndpoint{
80+
URL: customEndpoint,
81+
}, nil
82+
}
83+
return awsendpoints.DefaultResolver().EndpointFor(service, region, opts...)
84+
}

pkg/aws/endpoints/resolver_test.go

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
package endpoints
2+
3+
import (
4+
"testing"
5+
6+
awsendpoints "github.com/aws/aws-sdk-go/aws/endpoints"
7+
"github.com/pkg/errors"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestAWSEndpointResolver_String(t *testing.T) {
12+
type fields struct {
13+
configuration map[string]string
14+
}
15+
tests := []struct {
16+
name string
17+
fields fields
18+
want string
19+
}{
20+
{
21+
name: "non-empty configuration",
22+
fields: fields{
23+
configuration: map[string]string{
24+
awsendpoints.Ec2ServiceID: "https://ec2.domain.com",
25+
awsendpoints.ElasticloadbalancingServiceID: "https://elbv2.domain.com",
26+
},
27+
},
28+
want: "ec2=https://ec2.domain.com,elasticloadbalancing=https://elbv2.domain.com",
29+
},
30+
{
31+
name: "nil configuration",
32+
fields: fields{
33+
configuration: nil,
34+
},
35+
want: "",
36+
},
37+
{
38+
name: "empty configuration",
39+
fields: fields{
40+
configuration: nil,
41+
},
42+
want: "",
43+
},
44+
}
45+
for _, tt := range tests {
46+
t.Run(tt.name, func(t *testing.T) {
47+
c := &AWSEndpointResolver{
48+
configuration: tt.fields.configuration,
49+
}
50+
got := c.String()
51+
assert.Equal(t, tt.want, got)
52+
})
53+
}
54+
}
55+
56+
func TestAWSEndpointResolver_Set(t *testing.T) {
57+
type fields struct {
58+
configuration map[string]string
59+
}
60+
type args struct {
61+
val string
62+
}
63+
tests := []struct {
64+
name string
65+
fields fields
66+
args args
67+
want AWSEndpointResolver
68+
wantErr error
69+
}{
70+
{
71+
name: "when default value is nil",
72+
fields: fields{
73+
configuration: nil,
74+
},
75+
args: args{
76+
val: "ec2=https://ec2.domain.com,elasticloadbalancing=https://elbv2.domain.com",
77+
},
78+
want: AWSEndpointResolver{
79+
configuration: map[string]string{
80+
awsendpoints.Ec2ServiceID: "https://ec2.domain.com",
81+
awsendpoints.ElasticloadbalancingServiceID: "https://elbv2.domain.com",
82+
},
83+
},
84+
},
85+
{
86+
name: "when val is empty",
87+
fields: fields{
88+
configuration: map[string]string{},
89+
},
90+
args: args{
91+
val: "",
92+
},
93+
want: AWSEndpointResolver{
94+
configuration: map[string]string{},
95+
},
96+
},
97+
{
98+
name: "when val is not valid format - case 1",
99+
fields: fields{
100+
configuration: map[string]string{},
101+
},
102+
args: args{
103+
val: "a=b=c",
104+
},
105+
wantErr: errors.Errorf("a=b=c must be formatted as serviceID=URL"),
106+
},
107+
{
108+
name: "when url is not absolute",
109+
fields: fields{
110+
configuration: map[string]string{},
111+
},
112+
args: args{
113+
val: "a=/relative/url",
114+
},
115+
wantErr: errors.Errorf("/relative/url must be an absolute url"),
116+
},
117+
{
118+
name: "when url is invalid",
119+
fields: fields{
120+
configuration: map[string]string{},
121+
},
122+
args: args{
123+
val: "a=invalid\turl",
124+
},
125+
wantErr: errors.Errorf("invalid\turl must be a valid url"),
126+
},
127+
}
128+
for _, tt := range tests {
129+
t.Run(tt.name, func(t *testing.T) {
130+
c := &AWSEndpointResolver{
131+
configuration: tt.fields.configuration,
132+
}
133+
err := c.Set(tt.args.val)
134+
if tt.wantErr != nil {
135+
assert.EqualError(t, err, tt.wantErr.Error())
136+
} else {
137+
assert.NoError(t, err)
138+
assert.Equal(t, tt.want, *c)
139+
}
140+
})
141+
}
142+
}
143+
144+
func TestAWSEndpointResolver_Type(t *testing.T) {
145+
c := &AWSEndpointResolver{}
146+
got := c.Type()
147+
assert.Equal(t, "awsEndpointResolver", got)
148+
}
149+
150+
func TestAWSEndpointResolver_EndpointFor(t *testing.T) {
151+
configuration := map[string]string{
152+
awsendpoints.Ec2ServiceID: "https://ec2.domain.com",
153+
awsendpoints.ElasticloadbalancingServiceID: "https://elbv2.domain.com",
154+
}
155+
c := &AWSEndpointResolver{
156+
configuration: configuration,
157+
}
158+
159+
testRegion := "region"
160+
161+
type args struct {
162+
val string
163+
}
164+
165+
tests := []struct {
166+
name string
167+
args args
168+
want *awsendpoints.ResolvedEndpoint
169+
wantErr error
170+
}{
171+
{
172+
name: "when custom endpoint is configured",
173+
args: args{
174+
val: awsendpoints.Ec2ServiceID,
175+
},
176+
want: &awsendpoints.ResolvedEndpoint{
177+
URL: configuration[awsendpoints.Ec2ServiceID],
178+
},
179+
},
180+
{
181+
name: "when custom endpoint is unconfigured",
182+
args: args{
183+
val: awsendpoints.WafServiceID,
184+
},
185+
want: nil,
186+
},
187+
}
188+
189+
for _, tt := range tests {
190+
t.Run(tt.name, func(t *testing.T) {
191+
res, err := c.EndpointFor(tt.args.val, testRegion)
192+
if tt.wantErr != nil {
193+
assert.EqualError(t, err, tt.wantErr.Error())
194+
} else {
195+
assert.NoError(t, err)
196+
if tt.want != nil {
197+
assert.Equal(t, *tt.want, res)
198+
} else {
199+
defaultEndpoint, err := awsendpoints.DefaultResolver().EndpointFor(tt.args.val, testRegion)
200+
assert.NoError(t, err)
201+
assert.Equal(t, defaultEndpoint, res)
202+
}
203+
}
204+
})
205+
}
206+
}

0 commit comments

Comments
 (0)