Skip to content

Add API to get managed identity source #823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ protected String getMessageFromErrorResponse(IHttpResponse response) {
managedIdentityErrorResponse.getError(), managedIdentityErrorResponse.getErrorDescription());
}

protected static IEnvironmentVariables getEnvironmentVariables(ManagedIdentityParameters parameters) {
return parameters.environmentVariables == null ? new EnvironmentVariables() : parameters.environmentVariables;
protected static IEnvironmentVariables getEnvironmentVariables() {
return ManagedIdentityApplication.environmentVariables == null ?
new EnvironmentVariables() : ManagedIdentityApplication.environmentVariables;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String msiSecret = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
{
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ public IMDSManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
super(msalRequest, serviceBundle, ManagedIdentitySourceType.IMDS);
ManagedIdentityParameters parameters = (ManagedIdentityParameters) msalRequest.requestContext().apiParameters();
IEnvironmentVariables environmentVariables = ((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()).environmentVariables == null ?
new EnvironmentVariables() :
parameters.environmentVariables;
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST))){
LOG.info(String.format("[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: %s", environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST)));
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

package com.microsoft.aad.msal4j;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CompletableFuture;
Expand All @@ -22,6 +24,16 @@ public class ManagedIdentityApplication extends AbstractApplicationBase implemen
@Getter
static TokenCache sharedTokenCache = new TokenCache();

@Getter(value = AccessLevel.PUBLIC)
static ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();

@Getter(value = AccessLevel.PACKAGE)
static IEnvironmentVariables environmentVariables;

static void setEnvironmentVariables(IEnvironmentVariables environmentVariables) {
ManagedIdentityApplication.environmentVariables = environmentVariables;
}

private ManagedIdentityApplication(Builder builder) {
super(builder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

package com.microsoft.aad.msal4j;

import lombok.AccessLevel;
import lombok.Getter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -12,6 +14,37 @@
class ManagedIdentityClient {
private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class);

private static ManagedIdentitySourceType managedIdentitySourceType;

protected static void resetManagedIdentitySourceType() {
managedIdentitySourceType = ManagedIdentitySourceType.NONE;
}

static ManagedIdentitySourceType getManagedIdentitySource() {
if (managedIdentitySourceType != null && managedIdentitySourceType != ManagedIdentitySourceType.NONE) {
return managedIdentitySourceType;
}

IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables();

if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER))) {
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.SERVICE_FABRIC;
} else
managedIdentitySourceType = ManagedIdentitySourceType.APP_SERVICE;
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.CLOUD_SHELL;
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.AZURE_ARC;
} else {
managedIdentitySourceType = ManagedIdentitySourceType.DEFAULT_TO_IMDS;
}

return managedIdentitySourceType;
}

AbstractManagedIdentitySource managedIdentitySource;

ManagedIdentityClient(MsalRequest msalRequest, ServiceBundle serviceBundle) {
Expand All @@ -38,16 +71,22 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
AbstractManagedIdentitySource managedIdentitySource;
if ((managedIdentitySource = ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else {
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);

if (managedIdentitySourceType == null || managedIdentitySourceType == ManagedIdentitySourceType.NONE) {
managedIdentitySourceType = getManagedIdentitySource();
}

switch (managedIdentitySourceType) {
case SERVICE_FABRIC:
return ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle);
case APP_SERVICE:
return AppServiceManagedIdentitySource.create(msalRequest, serviceBundle);
case CLOUD_SHELL:
return CloudShellManagedIdentitySource.create(msalRequest, serviceBundle);
case AZURE_ARC:
return AzureArcManagedIdentitySource.create(msalRequest, serviceBundle);
default:
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ public class ManagedIdentityParameters implements IAcquireTokenParameters {

boolean forceRefresh;

IEnvironmentVariables environmentVariables;

@Override
public Set<String> scopes() {
return null;
Expand All @@ -54,10 +52,6 @@ public Map<String, String> extraQueryParameters() {
return null;
}

void setEnvironmentVariablesConfig(IEnvironmentVariables environmentVariables) {
this.environmentVariables = environmentVariables;
}

private static ManagedIdentityParametersBuilder builder() {
return new ManagedIdentityParametersBuilder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
enum ManagedIdentitySourceType {
// Default.
NONE,
// The source to acquire token for managed identity is IMDS.
// The source used to acquire token for managed identity is IMDS.
IMDS,
// The source to acquire token for managed identity is App Service.
// The source used to acquire token for managed identity is App Service.
APP_SERVICE,
// The source to acquire token for managed identity is Azure Arc.
// The source used to acquire token for managed identity is Azure Arc.
AZURE_ARC,
// The source to acquire token for managed identity is Cloud Shell.
// The source used to acquire token for managed identity is Cloud Shell.
CLOUD_SHELL,
// The source to acquire token for managed identity is Service Fabric.
SERVICE_FABRIC
// The source used to acquire token for managed identity is Service Fabric.
SERVICE_FABRIC,
// The source to acquire token for managed identity is defaulted to IMDS when no environment variables are set.
DEFAULT_TO_IMDS
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,11 @@ public ManagedIdentityResponse getManagedIdentityResponse(

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT);


if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
{
LOG.info("[Managed Identity] Service fabric managed identity is unavailable.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,14 @@ public static Stream<Arguments> createDataError() {
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint));
}

public static Stream<Arguments> createDataGetSource() {
return Stream.of(
Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, ManagedIdentitySourceType.AZURE_ARC),
Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, ManagedIdentitySourceType.APP_SERVICE),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, ManagedIdentitySourceType.CLOUD_SHELL),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS),
Arguments.of(ManagedIdentitySourceType.IMDS, "", ManagedIdentitySourceType.DEFAULT_TO_IMDS),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, ManagedIdentitySourceType.SERVICE_FABRIC));
}
}
Loading