Skip to content

Enhanced exceptions hierarchy for jwt-key resolution #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ public DefaultJwtAuthenticator(JwtKeyResolver jwtKeyResolver) {

@Override
public Mono<Profile> authenticate(String token) {
return Mono.defer(
() -> {
String tokenWithoutSignature = token.substring(0, token.lastIndexOf(".") + 1);
return Mono.defer(() -> authenticate0(token)).onErrorMap(AuthenticationException::new);
}

private Mono<Profile> authenticate0(String token) {
String tokenWithoutSignature = token.substring(0, token.lastIndexOf(".") + 1);

JwtParser parser = Jwts.parser();
JwtParser parser = Jwts.parser();

Jwt<Header, Claims> claims = parser.parseClaimsJwt(tokenWithoutSignature);
Jwt<Header, Claims> claims = parser.parseClaimsJwt(tokenWithoutSignature);

return jwtKeyResolver
.resolve((Map<String, Object>) claims.getHeader())
.map(key -> parser.setSigningKey(key).parseClaimsJws(token).getBody())
.map(this::profileFromClaims);
})
.onErrorMap(AuthenticationException::new);
return jwtKeyResolver
.resolve((Map<String, Object>) claims.getHeader())
.map(key -> parser.setSigningKey(key).parseClaimsJws(token).getBody())
.map(this::profileFromClaims);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public interface JwtAuthenticator extends Authenticator {

/**
* Create a profile from claims.
*
* @param tokenClaims the claims to parse
* @return a profile from the claims
*/
Expand Down
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@
<version>${hamcrest.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-test</artifactId>
Expand Down
20 changes: 4 additions & 16 deletions tokens/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
<artifactId>scalecube-security-tokens</artifactId>

<dependencies>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-api</artifactId>
Expand All @@ -23,21 +27,11 @@
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
</dependency>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<!-- Tests -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>${junit-jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>vault</artifactId>
Expand All @@ -50,12 +44,6 @@
<version>${vault-java-driver.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
Expand All @@ -26,6 +27,9 @@ public final class JwksKeyProvider implements KeyProvider {

private static final Logger LOGGER = LoggerFactory.getLogger(JwksKeyProvider.class);

private static final Duration CONNECT_TIMEOUT = Duration.ofSeconds(10);
private static final Duration READ_TIMEOUT = Duration.ofSeconds(10);

private static final ObjectMapper OBJECT_MAPPER = newObjectMapper();

private final Scheduler scheduler;
Expand All @@ -39,7 +43,7 @@ public final class JwksKeyProvider implements KeyProvider {
* @param jwksUri jwksUri
*/
public JwksKeyProvider(String jwksUri) {
this(jwksUri, newScheduler(), Duration.ofSeconds(10), Duration.ofSeconds(10));
this(jwksUri, newScheduler(), CONNECT_TIMEOUT, READ_TIMEOUT);
}

/**
Expand All @@ -60,38 +64,38 @@ public JwksKeyProvider(

@Override
public Mono<Key> findKey(String kid) {
return Mono.defer(this::callJwksUri)
.map(this::toKeyList)
.flatMap(list -> Mono.justOrEmpty(findRsaKey(list, kid)))
.switchIfEmpty(Mono.error(new KeyProviderException("Key was not found, kid: " + kid)))
return computeKey(kid)
.switchIfEmpty(Mono.error(new KeyNotFoundException("Key was not found, kid: " + kid)))
.doOnSubscribe(s -> LOGGER.debug("[findKey] Looking up key in jwks, kid: {}", kid))
.subscribeOn(scheduler)
.publishOn(scheduler);
.subscribeOn(scheduler);
}

private Mono<Key> computeKey(String kid) {
return Mono.fromCallable(this::computeKeyList)
.flatMap(list -> Mono.justOrEmpty(findRsaKey(list, kid)))
.onErrorMap(th -> th instanceof KeyProviderException ? th : new KeyProviderException(th));
}

private Mono<InputStream> callJwksUri() {
return Mono.fromCallable(
() -> {
HttpURLConnection httpClient = (HttpURLConnection) new URL(jwksUri).openConnection();
httpClient.setConnectTimeout((int) connectTimeoutMillis);
httpClient.setReadTimeout((int) readTimeoutMillis);

int responseCode = httpClient.getResponseCode();
if (responseCode != 200) {
LOGGER.error("[callJwksUri][{}] Not expected response code: {}", jwksUri, responseCode);
throw new KeyProviderException("Not expected response code: " + responseCode);
}

return httpClient.getInputStream();
});
private JwkInfoList computeKeyList() throws IOException {
HttpURLConnection httpClient = (HttpURLConnection) new URL(jwksUri).openConnection();
httpClient.setConnectTimeout((int) connectTimeoutMillis);
httpClient.setReadTimeout((int) readTimeoutMillis);

int responseCode = httpClient.getResponseCode();
if (responseCode != 200) {
LOGGER.error("[computeKey][{}] Not expected response code: {}", jwksUri, responseCode);
throw new KeyProviderException("Not expected response code: " + responseCode);
}

return toKeyList(httpClient.getInputStream());
}

private JwkInfoList toKeyList(InputStream stream) {
private static JwkInfoList toKeyList(InputStream stream) {
try (InputStream inputStream = new BufferedInputStream(stream)) {
return OBJECT_MAPPER.readValue(inputStream, JwkInfoList.class);
} catch (IOException e) {
LOGGER.error("[toKeyList] Exception occurred: {}", e.toString());
throw new KeyProviderException(e);
throw Exceptions.propagate(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public final class JwtTokenResolverImpl implements JwtTokenResolver {

private static final Logger LOGGER = LoggerFactory.getLogger(JwtTokenResolver.class);

private static final Duration CLEANUP_INTERVAL = Duration.ofSeconds(60);

private final KeyProvider keyProvider;
private final JwtTokenParserFactory tokenParserFactory;
private final Scheduler scheduler;
Expand All @@ -31,7 +33,7 @@ public final class JwtTokenResolverImpl implements JwtTokenResolver {
* @param keyProvider key provider
*/
public JwtTokenResolverImpl(KeyProvider keyProvider) {
this(keyProvider, new JsonwebtokenParserFactory(), newScheduler(), Duration.ofSeconds(60));
this(keyProvider, new JsonwebtokenParserFactory(), newScheduler(), CLEANUP_INTERVAL);
}

/**
Expand All @@ -49,8 +51,8 @@ public JwtTokenResolverImpl(
Duration cleanupInterval) {
this.keyProvider = keyProvider;
this.tokenParserFactory = tokenParserFactory;
this.cleanupInterval = cleanupInterval;
this.scheduler = scheduler;
this.cleanupInterval = cleanupInterval;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.scalecube.security.tokens.jwt;

public final class KeyNotFoundException extends RuntimeException {

public KeyNotFoundException(String s) {
super(s);
}

@Override
public synchronized Throwable fillInStackTrace() {
return this;
}

@Override
public String toString() {
return getClass().getSimpleName() + "{errorMessage=" + getMessage() + '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@

public final class KeyProviderException extends RuntimeException {

public KeyProviderException() {}

public KeyProviderException(String s) {
super(s);
}

public KeyProviderException(String s, Throwable throwable) {
super(s, throwable);
}

public KeyProviderException(Throwable throwable) {
super(throwable);
public KeyProviderException(Throwable cause) {
super(cause);
}

@Override
Expand Down
17 changes: 2 additions & 15 deletions tokens/src/main/java/io/scalecube/security/tokens/jwt/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@ private Utils() {
// Do not instantiate
}

/**
* Turns b64 url encoded {@code n} and {@code e} into RSA public key.
*
* @param n modulus (b64 url encoded)
* @param e exponent (b64 url encoded)
* @return RSA public key instance
*/
public static Key toRsaPublicKey(String n, String e) {
static Key toRsaPublicKey(String n, String e) {
Decoder b64Decoder = Base64.getUrlDecoder();
BigInteger modulus = new BigInteger(1, b64Decoder.decode(n));
BigInteger exponent = new BigInteger(1, b64Decoder.decode(e));
Expand All @@ -34,13 +27,7 @@ public static Key toRsaPublicKey(String n, String e) {
}
}

/**
* Mask sensitive data by replacing part of string with an asterisk symbol.
*
* @param data sensitive data to be masked
* @return masked data
*/
public static String mask(String data) {
static String mask(String data) {
if (data == null || data.isEmpty() || data.length() < 5) {
return "*****";
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package io.scalecube.security.tokens.jwt;

import static io.scalecube.security.tokens.jwt.Utils.toRsaPublicKey;

import java.security.Key;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
Expand All @@ -20,7 +16,7 @@ class JwtTokenResolverTests {

@Test
void testTokenResolver() throws Exception {
TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties");
JwtTokenWithKey tokenWithKey = new JwtTokenWithKey("token-and-pubkey.properties");

JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class);
Mockito.when(tokenParser.parseToken())
Expand Down Expand Up @@ -51,9 +47,9 @@ void testTokenResolver() throws Exception {

@Test
void testTokenResolverWithRotatingKey() throws Exception {
TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties");
TokenWithKey tokenWithKeyAfterRotation =
new TokenWithKey("token-and-pubkey.after-rotation.properties");
JwtTokenWithKey tokenWithKey = new JwtTokenWithKey("token-and-pubkey.properties");
JwtTokenWithKey tokenWithKeyAfterRotation =
new JwtTokenWithKey("token-and-pubkey.after-rotation.properties");

JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class);
Mockito.when(tokenParser.parseToken())
Expand Down Expand Up @@ -98,7 +94,7 @@ void testTokenResolverWithRotatingKey() throws Exception {

@Test
void testTokenResolverWithWrongKey() throws Exception {
TokenWithKey tokenWithWrongKey = new TokenWithKey("token-and-wrong-pubkey.properties");
JwtTokenWithKey tokenWithWrongKey = new JwtTokenWithKey("token-and-wrong-pubkey.properties");

JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class);
Mockito.when(tokenParser.parseToken())
Expand Down Expand Up @@ -128,7 +124,7 @@ void testTokenResolverWithWrongKey() throws Exception {

@Test
void testTokenResolverWhenKeyProviderFailing() throws Exception {
TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties");
JwtTokenWithKey tokenWithKey = new JwtTokenWithKey("token-and-pubkey.properties");

JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class);
Mockito.when(tokenParser.parseToken())
Expand All @@ -153,20 +149,4 @@ void testTokenResolverWhenKeyProviderFailing() throws Exception {
// failed resolution not stored => keyProvider must have been called 2 times
Mockito.verify(keyProvider, Mockito.times(2)).findKey(tokenWithKey.kid);
}

static class TokenWithKey {

final String token;
final Key key;
final String kid;

TokenWithKey(String s) throws Exception {
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Properties props = new Properties();
props.load(classLoader.getResourceAsStream(s));
this.token = props.getProperty("token");
this.kid = props.getProperty("kid");
this.key = toRsaPublicKey(props.getProperty("n"), props.getProperty("e"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.scalecube.security.tokens.jwt;

import static io.scalecube.security.tokens.jwt.Utils.toRsaPublicKey;

import java.security.Key;
import java.util.Properties;

class JwtTokenWithKey {

final String token;
final Key key;
final String kid;

JwtTokenWithKey(String s) throws Exception {
Properties props = new Properties();
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
props.load(classLoader.getResourceAsStream(s));
this.token = props.getProperty("token");
this.kid = props.getProperty("kid");
this.key = toRsaPublicKey(props.getProperty("n"), props.getProperty("e"));
}
}
Loading