Skip to content

Commit 32a6572

Browse files
committed
adding cache for cert discovery
1 parent 125e843 commit 32a6572

File tree

6 files changed

+501
-318
lines changed

6 files changed

+501
-318
lines changed

internal/alb/ls/cert_discovery.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package ls
2+
3+
import (
4+
"context"
5+
"reflect"
6+
"strings"
7+
"time"
8+
9+
"github.com/aws/aws-sdk-go/service/acm"
10+
"github.com/kubernetes-sigs/aws-alb-ingress-controller/internal/aws"
11+
"github.com/kubernetes-sigs/aws-alb-ingress-controller/internal/ingress/errors"
12+
"github.com/kubernetes-sigs/aws-alb-ingress-controller/internal/utils"
13+
"k8s.io/apimachinery/pkg/util/sets"
14+
)
15+
16+
const (
17+
// the domain names for imported certificate will be cached for 1 minute.(cache invalidation is hard problem right? :D)
18+
importedCertDomainsCacheDuration = 1 * time.Minute
19+
)
20+
21+
type CertDiscovery interface {
22+
// Discover will try to find valid certificates for each tlsHost.
23+
Discover(ctx context.Context, tlsHosts sets.String) ([]string, error)
24+
}
25+
26+
func NewACMCertDiscovery(cloud aws.CloudAPI) CertDiscovery {
27+
return &acmCertDiscovery{
28+
cloud: cloud,
29+
certDomainsCache: utils.NewCache(),
30+
}
31+
}
32+
33+
type acmCertDiscovery struct {
34+
cloud aws.CloudAPI
35+
certDomainsCache utils.Cache
36+
}
37+
38+
func (d *acmCertDiscovery) Discover(ctx context.Context, tlsHosts sets.String) ([]string, error) {
39+
domainsByCertArn, err := d.loadDomainsForCertificates(ctx)
40+
if err != nil {
41+
return nil, err
42+
}
43+
certArns := sets.NewString()
44+
for host := range tlsHosts {
45+
certArnsForHost := sets.NewString()
46+
for certArn, domains := range domainsByCertArn {
47+
for domain := range domains {
48+
if d.domainMatchesHost(domain, host) {
49+
certArnsForHost.Insert(certArn)
50+
break
51+
}
52+
}
53+
}
54+
if len(certArnsForHost) > 1 {
55+
return nil, errors.Errorf("multiple certificate found for host: %s, certARNs: %v", host, certArnsForHost.List())
56+
}
57+
if len(certArnsForHost) == 0 {
58+
return nil, errors.Errorf("none certificate found for host: %s", host)
59+
}
60+
certArns = certArns.Union(certArnsForHost)
61+
}
62+
return certArns.List(), nil
63+
}
64+
65+
func (d *acmCertDiscovery) loadDomainsForCertificates(ctx context.Context) (map[string]sets.String, error) {
66+
certSummaries, err := d.cloud.ListCertificates(ctx, &acm.ListCertificatesInput{
67+
CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusIssued}),
68+
})
69+
if err != nil {
70+
return nil, err
71+
}
72+
domainsByCertArn := make(map[string]sets.String, len(certSummaries))
73+
for _, certSummary := range certSummaries {
74+
certArn := aws.StringValue(certSummary.CertificateArn)
75+
certDomains, err := d.loadDomainsForCertificate(ctx, certArn)
76+
if err != nil {
77+
return nil, err
78+
}
79+
domainsByCertArn[certArn] = certDomains
80+
}
81+
d.certDomainsCache.Shrink(sets.StringKeySet(domainsByCertArn))
82+
return domainsByCertArn, nil
83+
}
84+
85+
func (d *acmCertDiscovery) loadDomainsForCertificate(ctx context.Context, certArn string) (sets.String, error) {
86+
if domains, ok := d.certDomainsCache.Get(certArn); ok {
87+
return domains.(sets.String), nil
88+
}
89+
certDetail, err := d.cloud.DescribeCertificate(ctx, certArn)
90+
if err != nil {
91+
return nil, err
92+
}
93+
domains := sets.NewString(aws.StringValueSlice(certDetail.SubjectAlternativeNames)...)
94+
switch aws.StringValue(certDetail.Type) {
95+
case acm.CertificateTypeAmazonIssued, acm.CertificateTypePrivate:
96+
d.certDomainsCache.Set(certArn, domains, utils.CacheNoExpiration)
97+
case acm.CertificateTypeImported:
98+
d.certDomainsCache.Set(certArn, domains, importedCertDomainsCacheDuration)
99+
}
100+
return domains, nil
101+
}
102+
103+
func (d *acmCertDiscovery) domainMatchesHost(domainName string, tlsHost string) bool {
104+
if strings.HasPrefix(domainName, "*.") {
105+
ds := strings.Split(domainName, ".")
106+
hs := strings.Split(tlsHost, ".")
107+
108+
if len(ds) != len(hs) {
109+
return false
110+
}
111+
112+
return reflect.DeepEqual(ds[1:], hs[1:])
113+
}
114+
115+
return domainName == tlsHost
116+
}
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
package ls
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/aws/aws-sdk-go/service/acm"
9+
"github.com/kubernetes-sigs/aws-alb-ingress-controller/internal/aws"
10+
"github.com/kubernetes-sigs/aws-alb-ingress-controller/mocks"
11+
"github.com/stretchr/testify/assert"
12+
"k8s.io/apimachinery/pkg/util/sets"
13+
)
14+
15+
type listCertificatesCall struct {
16+
input *acm.ListCertificatesInput
17+
output []*acm.CertificateSummary
18+
err error
19+
}
20+
21+
type describeCertificateCall struct {
22+
certArn string
23+
output *acm.CertificateDetail
24+
err error
25+
}
26+
27+
func Test_CertDiscovery_Discover(t *testing.T) {
28+
for _, tc := range []struct {
29+
name string
30+
hosts []string
31+
listCertificateCall *listCertificatesCall
32+
describeCertificateCalls []describeCertificateCall
33+
expectedCerts []string
34+
expectedErr string
35+
}{
36+
{
37+
name: "when ACM has exact match with TLS host",
38+
hosts: []string{"foo.example.com"},
39+
listCertificateCall: &listCertificatesCall{
40+
input: &acm.ListCertificatesInput{CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusIssued})},
41+
output: []*acm.CertificateSummary{{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/yyy")}},
42+
},
43+
describeCertificateCalls: []describeCertificateCall{
44+
{
45+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/yyy",
46+
output: &acm.CertificateDetail{
47+
SubjectAlternativeNames: aws.StringSlice([]string{"foo.example.com"}),
48+
},
49+
},
50+
},
51+
expectedCerts: []string{"arn:aws:acm:us-west-2:xxx:certificate/yyy"},
52+
},
53+
{
54+
name: "when ACM has wildcard match with TLS host",
55+
hosts: []string{"foo.example.com"},
56+
listCertificateCall: &listCertificatesCall{
57+
input: &acm.ListCertificatesInput{CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusIssued})},
58+
output: []*acm.CertificateSummary{{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/yyy")}},
59+
},
60+
describeCertificateCalls: []describeCertificateCall{
61+
{
62+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/yyy",
63+
output: &acm.CertificateDetail{
64+
SubjectAlternativeNames: aws.StringSlice([]string{"*.example.com"}),
65+
},
66+
},
67+
},
68+
expectedCerts: []string{"arn:aws:acm:us-west-2:xxx:certificate/yyy"},
69+
},
70+
{
71+
name: "when ACM has SAN domain match with TLS host",
72+
hosts: []string{"foo.example.com"},
73+
listCertificateCall: &listCertificatesCall{
74+
input: &acm.ListCertificatesInput{CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusIssued})},
75+
output: []*acm.CertificateSummary{{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/yyy")}},
76+
},
77+
describeCertificateCalls: []describeCertificateCall{
78+
{
79+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/yyy",
80+
output: &acm.CertificateDetail{
81+
SubjectAlternativeNames: aws.StringSlice([]string{"bar.example.com", "foo.example.com"}),
82+
},
83+
},
84+
},
85+
expectedCerts: []string{"arn:aws:acm:us-west-2:xxx:certificate/yyy"},
86+
},
87+
{
88+
name: "when ACM has exact match with multiple TLS host",
89+
hosts: []string{"foo.example.com", "bar.example.com"},
90+
listCertificateCall: &listCertificatesCall{
91+
input: &acm.ListCertificatesInput{CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusIssued})},
92+
output: []*acm.CertificateSummary{
93+
{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/yyy")},
94+
{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/zzz")},
95+
},
96+
},
97+
describeCertificateCalls: []describeCertificateCall{
98+
{
99+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/yyy",
100+
output: &acm.CertificateDetail{
101+
SubjectAlternativeNames: aws.StringSlice([]string{"foo.example.com"}),
102+
},
103+
},
104+
{
105+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/zzz",
106+
output: &acm.CertificateDetail{
107+
SubjectAlternativeNames: aws.StringSlice([]string{"bar.example.com"}),
108+
},
109+
},
110+
},
111+
expectedCerts: []string{"arn:aws:acm:us-west-2:xxx:certificate/yyy", "arn:aws:acm:us-west-2:xxx:certificate/zzz"},
112+
},
113+
{
114+
name: "when ACM has multiple match with TLS host",
115+
hosts: []string{"foo.example.com"},
116+
listCertificateCall: &listCertificatesCall{
117+
input: &acm.ListCertificatesInput{CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusIssued})},
118+
output: []*acm.CertificateSummary{
119+
{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/yyy")},
120+
{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/zzz")},
121+
},
122+
},
123+
describeCertificateCalls: []describeCertificateCall{
124+
{
125+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/yyy",
126+
output: &acm.CertificateDetail{
127+
SubjectAlternativeNames: aws.StringSlice([]string{"foo.example.com"}),
128+
},
129+
},
130+
{
131+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/zzz",
132+
output: &acm.CertificateDetail{
133+
SubjectAlternativeNames: aws.StringSlice([]string{"foo.example.com"}),
134+
},
135+
},
136+
},
137+
expectedCerts: nil,
138+
expectedErr: "multiple certificate found for host: foo.example.com, certARNs: [arn:aws:acm:us-west-2:xxx:certificate/yyy arn:aws:acm:us-west-2:xxx:certificate/zzz]",
139+
},
140+
{
141+
name: "when ACM has no match with TLS host",
142+
hosts: []string{"foo.example.com"},
143+
listCertificateCall: &listCertificatesCall{
144+
input: &acm.ListCertificatesInput{CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusIssued})},
145+
output: []*acm.CertificateSummary{
146+
{CertificateArn: aws.String("arn:aws:acm:us-west-2:xxx:certificate/yyy")},
147+
},
148+
},
149+
describeCertificateCalls: []describeCertificateCall{
150+
{
151+
certArn: "arn:aws:acm:us-west-2:xxx:certificate/yyy",
152+
output: &acm.CertificateDetail{
153+
SubjectAlternativeNames: aws.StringSlice([]string{"bar.example.com"}),
154+
},
155+
},
156+
},
157+
expectedCerts: nil,
158+
expectedErr: "none certificate found for host: foo.example.com",
159+
},
160+
} {
161+
t.Run(tc.name, func(t *testing.T) {
162+
ctx := context.Background()
163+
mockedCloud := &mocks.CloudAPI{}
164+
if tc.listCertificateCall != nil {
165+
mockedCloud.On("ListCertificates", ctx, tc.listCertificateCall.input).Return(tc.listCertificateCall.output, tc.listCertificateCall.err)
166+
}
167+
for _, call := range tc.describeCertificateCalls {
168+
mockedCloud.On("DescribeCertificate", ctx, call.certArn).Return(call.output, call.err)
169+
}
170+
171+
certDiscovery := NewACMCertDiscovery(mockedCloud)
172+
certArns, err := certDiscovery.Discover(ctx, sets.NewString(tc.hosts...))
173+
if tc.expectedErr != "" {
174+
assert.EqualError(t, err, tc.expectedErr)
175+
} else {
176+
assert.Nil(t, err)
177+
}
178+
assert.ElementsMatch(t, certArns, tc.expectedCerts)
179+
})
180+
}
181+
}
182+
183+
func Test_domainMatchesHost(t *testing.T) {
184+
var tests = []struct {
185+
domain string
186+
host string
187+
want bool
188+
}{
189+
{"example.com", "example.com", true},
190+
{"example.com", "exampl0.com", false},
191+
192+
// wildcards
193+
{"*.example.com", "foo.example.com", true},
194+
{"*.example.com", "example.com", false},
195+
{"*.exampl0.com", "foo.example.com", false},
196+
197+
// invalid hosts, not sure these are possible
198+
{"*.*.example.com", "foo.bar.example.com", false},
199+
{"foo.*.example.com", "foo.bar.example.com", false},
200+
}
201+
202+
for _, test := range tests {
203+
var msg = "should"
204+
if !test.want {
205+
msg = "should not"
206+
}
207+
208+
d := &acmCertDiscovery{}
209+
t.Run(fmt.Sprintf("%s %s match %s", test.domain, msg, test.host), func(t *testing.T) {
210+
have := d.domainMatchesHost(test.domain, test.host)
211+
if test.want != have {
212+
t.Fail()
213+
}
214+
})
215+
}
216+
}

0 commit comments

Comments
 (0)