Skip to content

refactor(manager): small refactors around the manager and token logic #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions internal/utils.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package internal

// IsClosed checks if a channel is closed.
// Returns true only if the channel is actually closed, not just if it has data available.
//
// NOTE: It returns true if the channel is closed as well
// as if the channel is not empty. Used internally
// to check if the channel is closed.
// WARNING: This function will consume one value from the channel if it has pending data.
// Use with caution on channels where consuming data might cause issues.
func IsClosed(ch <-chan struct{}) bool {
select {
case <-ch:
return true
case _, ok := <-ch:
// If ok is false, the channel is closed
// If ok is true, the channel had data (which we just consumed)
return !ok
default:
// Channel is open but has no data available
return false
}

return false
}
17 changes: 16 additions & 1 deletion manager/entraid_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) {
e.listener = listener

go func(listener TokenListener, closed <-chan struct{}) {
// Add panic recovery to prevent crashes
defer func() {
if r := recover(); r != nil {
// Attempt to notify listener of panic, but don't panic again if that fails
func() {
defer func() { _ = recover() }()
listener.OnError(fmt.Errorf("token manager goroutine panic: %v", r))
}()
}
}()
maxDelay := e.retryOptions.MaxDelay
initialDelay := e.retryOptions.InitialDelay

Expand Down Expand Up @@ -223,6 +233,7 @@ func (e *entraidTokenManager) stop() (err error) {
err = fmt.Errorf("failed to stop token manager: %s", r)
}
}()

if e.ctxCancel != nil {
e.ctxCancel()
}
Expand All @@ -232,7 +243,11 @@ func (e *entraidTokenManager) stop() (err error) {
}

e.listener = nil
close(e.closedChan)

// Safely close the channel - only close if not already closed
if !internal.IsClosed(e.closedChan) {
close(e.closedChan)
}

return nil
}
Expand Down
10 changes: 6 additions & 4 deletions manager/entraid_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"github.com/stretchr/testify/assert"
)

const testDurationDelta = float64(5 * time.Millisecond)

func TestDurationToRenewal(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -236,7 +238,7 @@ func TestDurationToRenewal(t *testing.T) {
}

duration := manager.durationToRenewal(tt.token)
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond),
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), testDurationDelta,
"%s: expected %v, got %v", tt.name, tt.expectedDuration, duration)
})
}
Expand Down Expand Up @@ -415,7 +417,7 @@ func TestDurationToRenewalMillisecondPrecision(t *testing.T) {
}

duration := manager.durationToRenewal(tt.token)
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond),
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), testDurationDelta,
"%s: expected %v, got %v", tt.name, tt.expectedDuration, duration)
})
}
Expand Down Expand Up @@ -453,8 +455,8 @@ func TestDurationToRenewalConcurrent(t *testing.T) {
if i == 0 {
firstResult = result
} else {
// All results should be within 10ms of each other
assert.InDelta(t, firstResult.Milliseconds(), result.Milliseconds(), 10)
// All results should be within 5ms of each other
assert.InDelta(t, firstResult.Milliseconds(), result.Milliseconds(), 5)
}
}
}
23 changes: 16 additions & 7 deletions token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,22 @@ import (
var _ auth.Credentials = (*Token)(nil)

// New creates a new token with the specified username, password, raw token, expiration time, received at time, and time to live.
// NOTE: This won't do any validation on the token, expiresOn, receivedAt, or ttl. It will simply create a new token instance.
// The caller is responsible for ensuring the token is valid.
// NOTE: The caller is responsible for ensuring the token is valid.
// If the token is invalid, the behavior is undefined.
// - if expiresOn is zero, New returns nil
// - if receivedAt is zero, it will be set to the current time and TTL will be recalculated
// Expiration time and TTL are used to determine when the token should be refreshed.
// TTL is in milliseconds.
// receivedAt + ttl should be within a millisecond of expiresOn
func New(username, password, rawToken string, expiresOn, receivedAt time.Time, ttl int64) *Token {
if expiresOn.IsZero() {
return nil
}
if receivedAt.IsZero() {
receivedAt = time.Now()
ttl = expiresOn.Sub(receivedAt).Milliseconds()
}

return &Token{
username: username,
password: password,
Expand All @@ -28,6 +38,10 @@ func New(username, password, rawToken string, expiresOn, receivedAt time.Time, t

// Token represents parsed authentication token used to access the Redis server.
// It implements the auth.Credentials interface.
//
// WARNING: Use New() to create a new token.
// Creating a token with Token{} is invalid and will undefined behavior in the TokenManager.
// The zero value of Token is not valid.
type Token struct {
// username is the username of the user.
username string
Expand Down Expand Up @@ -60,11 +74,6 @@ func (t *Token) RawToken() string {

// ReceivedAt returns the time when the token was received.
func (t *Token) ReceivedAt() time.Time {
if t.receivedAt.IsZero() {
// set it to now, recalculate ttl
t.receivedAt = time.Now()
t.ttl = t.expiresOn.Sub(t.receivedAt).Milliseconds()
}
return t.receivedAt
}

Expand Down
11 changes: 6 additions & 5 deletions token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,15 @@ func TestCopyToken(t *testing.T) {
assert.NotEqual(t, token.expiresOn, copiedToken.expiresOn)

// copy nil
copiedToken = copyToken(nil)
assert.Nil(t, copiedToken)
nilToken := copyToken(nil)
assert.Nil(t, nilToken)
// copy empty token
copiedToken = copyToken(&Token{})
assert.NotNil(t, copiedToken)
emptyToken := copyToken(&Token{})
assert.Nil(t, emptyToken)
anotherCopy := copiedToken.Copy()
anotherCopy.rawToken = "changed"
assert.NotEqual(t, copiedToken, anotherCopy)
assert.NotEqual(t, copiedToken.rawToken, anotherCopy.rawToken)
}

func TestTokenReceivedAt(t *testing.T) {
Expand All @@ -124,7 +125,7 @@ func TestTokenReceivedAt(t *testing.T) {
// Check if the copied token is a new instance
assert.NotNil(t, tcopiedToken)

emptyRecievedAt := &Token{}
emptyRecievedAt := New("username", "password", "rawToken", time.Now(), time.Time{}, time.Hour.Milliseconds())
assert.True(t, emptyRecievedAt.ReceivedAt().After(time.Now().Add(-1*time.Hour)))
assert.True(t, emptyRecievedAt.ReceivedAt().Before(time.Now().Add(1*time.Hour)))
}
Expand Down
Loading