Skip to content

Refactor custom AWS endpoint resolver #2270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"sigs.k8s.io/aws-load-balancer-controller/controllers/ingress"
"sigs.k8s.io/aws-load-balancer-controller/controllers/service"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
"sigs.k8s.io/aws-load-balancer-controller/pkg/config"
"sigs.k8s.io/aws-load-balancer-controller/pkg/inject"
Expand Down Expand Up @@ -173,7 +172,6 @@ func loadControllerConfig() (config.ControllerConfig, error) {
controllerCFG := config.ControllerConfig{
AWSConfig: aws.CloudConfig{
ThrottleConfig: defaultAWSThrottleCFG,
AWSEndpointResolver: &endpoints.AWSEndpointResolver{},
},
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/aws/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"os"
epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/metrics"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
Expand Down Expand Up @@ -43,7 +44,8 @@ type Cloud interface {

// NewCloud constructs new Cloud implementation.
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer) (Cloud, error) {
metadataCFG := aws.NewConfig().WithEndpointResolver(cfg.AWSEndpointResolver)
endpointsResolver := epresolver.NewResolver(cfg.AWSEndpoints)
metadataCFG := aws.NewConfig().WithEndpointResolver(endpointsResolver)
metadataSess := session.Must(session.NewSession(metadataCFG))
metadata := services.NewEC2Metadata(metadataSess)
if len(cfg.VpcID) == 0 {
Expand All @@ -69,7 +71,7 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer) (Cloud,
}
cfg.Region = region
}
awsCFG := aws.NewConfig().WithRegion(cfg.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).WithMaxRetries(cfg.MaxRetries).WithEndpointResolver(cfg.AWSEndpointResolver)
awsCFG := aws.NewConfig().WithRegion(cfg.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).WithMaxRetries(cfg.MaxRetries).WithEndpointResolver(endpointsResolver)
sess := session.Must(session.NewSession(awsCFG))
injectUserAgent(&sess.Handlers)

Expand Down
7 changes: 3 additions & 4 deletions pkg/aws/cloud_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package aws

import (
"github.com/spf13/pflag"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
)

Expand Down Expand Up @@ -30,14 +29,14 @@ type CloudConfig struct {
// Max retries configuration for AWS APIs
MaxRetries int

// AWS endpoint configuration
AWSEndpointResolver *endpoints.AWSEndpointResolver
// AWS endpoints configuration
AWSEndpoints map[string]string
}

func (cfg *CloudConfig) BindFlags(fs *pflag.FlagSet) {
fs.StringVar(&cfg.Region, flagAWSRegion, defaultRegion, "AWS Region for the kubernetes cluster")
fs.Var(cfg.ThrottleConfig, flagAWSAPIThrottle, "throttle settings for AWS APIs, format: serviceID1:operationRegex1=rate:burst,serviceID2:operationRegex2=rate:burst")
fs.StringVar(&cfg.VpcID, flagAWSVpcID, defaultVpcID, "AWS VPC ID for the Kubernetes cluster")
fs.IntVar(&cfg.MaxRetries, flagAWSMaxRetries, defaultAPIMaxRetries, "Maximum retries for AWS APIs")
fs.Var(cfg.AWSEndpointResolver, flagAWSAPIEndpoints, "Custom AWS endpoint configuration, format: serviceID1=URL1,serviceID2=URL2")
fs.StringToStringVar(&cfg.AWSEndpoints, flagAWSAPIEndpoints, nil, "Custom AWS endpoint configuration, format: serviceID1=URL1,serviceID2=URL2")
}
76 changes: 11 additions & 65 deletions pkg/aws/endpoints/resolver.go
Original file line number Diff line number Diff line change
@@ -1,84 +1,30 @@
package endpoints

import (
"fmt"
"net/url"
"sort"
"strings"

awsendpoints "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/pkg/errors"
"github.com/spf13/pflag"
)

var _ pflag.Value = &AWSEndpointResolver{}

// AWSEndpointResolver is an AWS endpoints.Resolver that allows to customize AWS API endpoints.
// It can be configured using the following format "${AWSServiceID}=${URL}"
// e.g. "ec2=https://ec2.domain.com,elasticloadbalancing=https://elbv2.domain.com"
type AWSEndpointResolver struct {
configuration map[string]string
}

func (c *AWSEndpointResolver) String() string {
if c == nil {
return ""
}

var configs []string
var serviceIDs []string
for serviceID := range c.configuration {
serviceIDs = append(serviceIDs, serviceID)
}
sort.Strings(serviceIDs)
for _, serviceID := range serviceIDs {
configs = append(configs, fmt.Sprintf("%s=%s", serviceID, c.configuration[serviceID]))
func NewResolver(configuration map[string]string) *resolver {
return &resolver{
configuration: configuration,
}
return strings.Join(configs, ",")
}

func (c *AWSEndpointResolver) Set(val string) error {
configurationOverride := make(map[string]string)
var _ awsendpoints.Resolver = &resolver{}

if val != "" {
configPairs := strings.Split(val, ",")
for _, pair := range configPairs {
kv := strings.Split(pair, "=")
if len(kv) != 2 {
return errors.Errorf("%s must be formatted as serviceID=URL", pair)
}
serviceID := kv[0]
urlStr := kv[1]
url, err := url.Parse(urlStr)
if err != nil {
return errors.Errorf("%s must be a valid url", urlStr)
}
if !url.IsAbs() {
return errors.Errorf("%s must be an absolute url", urlStr)
}
configurationOverride[serviceID] = url.String()
}
}

if c.configuration == nil {
c.configuration = make(map[string]string)
}
for k, v := range configurationOverride {
c.configuration[k] = v
}
return nil
}

func (c *AWSEndpointResolver) Type() string {
return "awsEndpointResolver"
// resolver is an AWS endpoints.Resolver that allows to customize AWS API endpoints.
// It can be configured using the following format "${AWSServiceID}=${URL}"
// e.g. "ec2=https://ec2.domain.com,elasticloadbalancing=https://elbv2.domain.com"
type resolver struct {
configuration map[string]string
}

func (c *AWSEndpointResolver) EndpointFor(service, region string, opts ...func(*awsendpoints.Options)) (awsendpoints.ResolvedEndpoint, error) {
func (c *resolver) EndpointFor(service, region string, opts ...func(*awsendpoints.Options)) (awsendpoints.ResolvedEndpoint, error) {
customEndpoint := c.configuration[service]
if len(customEndpoint) != 0 {
return awsendpoints.ResolvedEndpoint{
URL: customEndpoint,
}, nil
}
return awsendpoints.DefaultResolver().EndpointFor(service, region, opts...)
}
}
146 changes: 3 additions & 143 deletions pkg/aws/endpoints/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,155 +4,15 @@ import (
"testing"

awsendpoints "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

func TestAWSEndpointResolver_String(t *testing.T) {
type fields struct {
configuration map[string]string
}
tests := []struct {
name string
fields fields
want string
}{
{
name: "non-empty configuration",
fields: fields{
configuration: map[string]string{
awsendpoints.Ec2ServiceID: "https://ec2.domain.com",
awsendpoints.ElasticloadbalancingServiceID: "https://elbv2.domain.com",
},
},
want: "ec2=https://ec2.domain.com,elasticloadbalancing=https://elbv2.domain.com",
},
{
name: "nil configuration",
fields: fields{
configuration: nil,
},
want: "",
},
{
name: "empty configuration",
fields: fields{
configuration: nil,
},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AWSEndpointResolver{
configuration: tt.fields.configuration,
}
got := c.String()
assert.Equal(t, tt.want, got)
})
}
}

func TestAWSEndpointResolver_Set(t *testing.T) {
type fields struct {
configuration map[string]string
}
type args struct {
val string
}
tests := []struct {
name string
fields fields
args args
want AWSEndpointResolver
wantErr error
}{
{
name: "when default value is nil",
fields: fields{
configuration: nil,
},
args: args{
val: "ec2=https://ec2.domain.com,elasticloadbalancing=https://elbv2.domain.com",
},
want: AWSEndpointResolver{
configuration: map[string]string{
awsendpoints.Ec2ServiceID: "https://ec2.domain.com",
awsendpoints.ElasticloadbalancingServiceID: "https://elbv2.domain.com",
},
},
},
{
name: "when val is empty",
fields: fields{
configuration: map[string]string{},
},
args: args{
val: "",
},
want: AWSEndpointResolver{
configuration: map[string]string{},
},
},
{
name: "when val is not valid format - case 1",
fields: fields{
configuration: map[string]string{},
},
args: args{
val: "a=b=c",
},
wantErr: errors.Errorf("a=b=c must be formatted as serviceID=URL"),
},
{
name: "when url is not absolute",
fields: fields{
configuration: map[string]string{},
},
args: args{
val: "a=/relative/url",
},
wantErr: errors.Errorf("/relative/url must be an absolute url"),
},
{
name: "when url is invalid",
fields: fields{
configuration: map[string]string{},
},
args: args{
val: "a=invalid\turl",
},
wantErr: errors.Errorf("invalid\turl must be a valid url"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AWSEndpointResolver{
configuration: tt.fields.configuration,
}
err := c.Set(tt.args.val)
if tt.wantErr != nil {
assert.EqualError(t, err, tt.wantErr.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, *c)
}
})
}
}

func TestAWSEndpointResolver_Type(t *testing.T) {
c := &AWSEndpointResolver{}
got := c.Type()
assert.Equal(t, "awsEndpointResolver", got)
}

func TestAWSEndpointResolver_EndpointFor(t *testing.T) {
configuration := map[string]string{
awsendpoints.Ec2ServiceID: "https://ec2.domain.com",
awsendpoints.Ec2ServiceID: "https://ec2.domain.com",
awsendpoints.ElasticloadbalancingServiceID: "https://elbv2.domain.com",
}
c := &AWSEndpointResolver{
c := &resolver{
configuration: configuration,
}

Expand Down Expand Up @@ -185,7 +45,7 @@ func TestAWSEndpointResolver_EndpointFor(t *testing.T) {
want: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
res, err := c.EndpointFor(tt.args.val, testRegion)
Expand Down
12 changes: 6 additions & 6 deletions pkg/targetgroupbinding/networking_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,8 +878,8 @@ func Test_defaultNetworkingManager_computeRestrictedIngressPermissionsPerSG(t *t
{
Permission: ec2sdk.IpPermission{
IpProtocol: awssdk.String("tcp"),
FromPort: nil,
ToPort: nil,
FromPort: nil,
ToPort: nil,
UserIdGroupPairs: []*ec2sdk.UserIdGroupPair{
{GroupId: awssdk.String("group-1")},
},
Expand Down Expand Up @@ -914,8 +914,8 @@ func Test_defaultNetworkingManager_computeRestrictedIngressPermissionsPerSG(t *t
{
Permission: ec2sdk.IpPermission{
IpProtocol: awssdk.String("tcp"),
FromPort: nil,
ToPort: nil,
FromPort: nil,
ToPort: nil,
UserIdGroupPairs: []*ec2sdk.UserIdGroupPair{
{GroupId: awssdk.String("group-1")},
},
Expand All @@ -928,8 +928,8 @@ func Test_defaultNetworkingManager_computeRestrictedIngressPermissionsPerSG(t *t
{
Permission: ec2sdk.IpPermission{
IpProtocol: awssdk.String("tcp"),
FromPort: nil,
ToPort: nil,
FromPort: nil,
ToPort: nil,
UserIdGroupPairs: []*ec2sdk.UserIdGroupPair{
{GroupId: awssdk.String("group-2")},
},
Expand Down
2 changes: 0 additions & 2 deletions test/framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"k8s.io/client-go/rest"
elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
"sigs.k8s.io/aws-load-balancer-controller/test/framework/controller"
"sigs.k8s.io/aws-load-balancer-controller/test/framework/helm"
Expand Down Expand Up @@ -77,7 +76,6 @@ func InitFramework() (*Framework, error) {
VpcID: globalOptions.AWSVPCID,
MaxRetries: 3,
ThrottleConfig: throttle.NewDefaultServiceOperationsThrottleConfig(),
AWSEndpointResolver: &endpoints.AWSEndpointResolver{},
}, nil)
if err != nil {
return nil, err
Expand Down