From 8bb9c7586dd7d978a8f7cc12f04d650e11d20944 Mon Sep 17 00:00:00 2001 From: Aijing Zeng Date: Sat, 8 Feb 2025 14:54:05 +0800 Subject: [PATCH] Add disable-instance-discovery option in interactive pop mode --- docs/book/src/cli/get-token.md | 1 + pkg/internal/pop/msal_public.go | 43 +++++++++++---------- pkg/internal/pop/msal_public_test.go | 52 ++++++++++++++------------ pkg/internal/token/interactive.go | 40 +++++++++++--------- pkg/internal/token/interactive_test.go | 1 + pkg/internal/token/options.go | 2 + pkg/internal/token/provider.go | 4 +- pkg/internal/token/ropc.go | 40 +++++++++++--------- pkg/internal/token/ropc_test.go | 1 + 9 files changed, 106 insertions(+), 78 deletions(-) diff --git a/docs/book/src/cli/get-token.md b/docs/book/src/cli/get-token.md index 4b94bdc8..cc5379bb 100644 --- a/docs/book/src/cli/get-token.md +++ b/docs/book/src/cli/get-token.md @@ -20,6 +20,7 @@ ZURE_CLIENT_CERTIFICATE_PASSWORD environment variable --client-id string AAD client application ID. It may be specified in AAD_SERVICE_PRINCIPAL_CLIENT_ID or AZURE_CLIENT_ID environment variable --client-secret string AAD client application secret. Used in spn login. It may be specified in AAD_SERVICE_PRINCIPAL_CLIENT_SECRET or AZURE_CLIENT_S ECRET environment variable + --disable-instance-discovery set to true to disable instance discovery in environments with their own Identity Provider (not Entra ID/AAD) that does not have instance metadata discovery endpoint. -e, --environment string Azure environment name (default "AzurePublicCloud") --federated-token-file string Workload Identity federated token file. It may be specified in AZURE_FEDERATED_TOKEN_FILE environment variable -h, --help help for get-token diff --git a/pkg/internal/pop/msal_public.go b/pkg/internal/pop/msal_public.go index 5f9c8366..c3e242c8 100644 --- a/pkg/internal/pop/msal_public.go +++ b/pkg/internal/pop/msal_public.go @@ -9,19 +9,24 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) +type PublicClientOptions struct { + Authority string + ClientID string + DisableInstanceDiscovery bool + Options *azcore.ClientOptions +} + // AcquirePoPTokenInteractive acquires a PoP token using MSAL's interactive login flow. // Requires user to authenticate via browser func AcquirePoPTokenInteractive( context context.Context, popClaims map[string]string, scopes []string, - authority, - clientID string, - options *azcore.ClientOptions, + pcOptions *PublicClientOptions, ) (string, int64, error) { var client *public.Client var err error - client, err = getPublicClient(authority, clientID, options) + client, err = getPublicClient(pcOptions) if err != nil { return "", -1, err } @@ -53,13 +58,11 @@ func AcquirePoPTokenByUsernamePassword( context context.Context, popClaims map[string]string, scopes []string, - authority, - clientID, username, password string, - options *azcore.ClientOptions, + pcOptions *PublicClientOptions, ) (string, int64, error) { - client, err := getPublicClient(authority, clientID, options) + client, err := getPublicClient(pcOptions) if err != nil { return "", -1, err } @@ -88,23 +91,25 @@ func AcquirePoPTokenByUsernamePassword( } // getPublicClient returns an instance of the msal `public` client based on the provided options -func getPublicClient( - authority, - clientID string, - options *azcore.ClientOptions, -) (*public.Client, error) { +// The instance discovery will be disable on private cloud +func getPublicClient(pcOptions *PublicClientOptions) (*public.Client, error) { var client public.Client var err error - if options != nil && options.Transport != nil { + if pcOptions == nil { + return nil, fmt.Errorf("unable to create public client: publicClientOptions is empty") + } + if pcOptions.Options != nil && pcOptions.Options.Transport != nil { client, err = public.New( - clientID, - public.WithAuthority(authority), - public.WithHTTPClient(options.Transport.(*http.Client)), + pcOptions.ClientID, + public.WithAuthority(pcOptions.Authority), + public.WithHTTPClient(pcOptions.Options.Transport.(*http.Client)), + public.WithInstanceDiscovery(!pcOptions.DisableInstanceDiscovery), ) } else { client, err = public.New( - clientID, - public.WithAuthority(authority), + pcOptions.ClientID, + public.WithAuthority(pcOptions.Authority), + public.WithInstanceDiscovery(!pcOptions.DisableInstanceDiscovery), ) } if err != nil { diff --git a/pkg/internal/pop/msal_public_test.go b/pkg/internal/pop/msal_public_test.go index 3524ba95..4d6354e9 100644 --- a/pkg/internal/pop/msal_public_test.go +++ b/pkg/internal/pop/msal_public_test.go @@ -118,11 +118,13 @@ func TestAcquirePoPTokenByUsernamePassword(t *testing.T) { ctx, tc.p.popClaims, scopes, - authority, - tc.p.clientID, tc.p.username, tc.p.password, - &clientOpts, + &PublicClientOptions{ + Authority: authority, + ClientID: tc.p.clientID, + Options: &clientOpts, + }, ) defer vcrRecorder.Stop() if tc.expectedError != nil { @@ -156,35 +158,43 @@ func TestGetPublicClient(t *testing.T) { testCase := []struct { testName string - authority string - options *azcore.ClientOptions + pcOptions *PublicClientOptions expectedError error }{ { // Test using custom HTTP transport - testName: "TestGetPublicClientWithCustomTransport", - authority: authority, - options: &azcore.ClientOptions{ - Cloud: cloud.AzurePublic, - Transport: httpClient, + testName: "TestGetPublicClientWithCustomTransport", + pcOptions: &PublicClientOptions{ + Authority: authority, + ClientID: testutils.ClientID, + Options: &azcore.ClientOptions{ + Cloud: cloud.AzurePublic, + Transport: httpClient, + }, }, expectedError: nil, }, { // Test using default HTTP transport - testName: "TestGetPublicClientWithDefaultTransport", - authority: authority, - options: &azcore.ClientOptions{ - Cloud: cloud.AzurePublic, + testName: "TestGetPublicClientWithDefaultTransport", + pcOptions: &PublicClientOptions{ + Authority: authority, + ClientID: testutils.ClientID, + Options: &azcore.ClientOptions{ + Cloud: cloud.AzurePublic, + }, }, expectedError: nil, }, { // Test using incorrectly formatted authority - testName: "TestGetPublicClientWithBadAuthority", - authority: "login.microsoft.com", - options: &azcore.ClientOptions{ - Cloud: cloud.AzurePublic, + testName: "TestGetPublicClientWithBadAuthority", + pcOptions: &PublicClientOptions{ + Authority: "login.microsoft.com", + ClientID: testutils.ClientID, + Options: &azcore.ClientOptions{ + Cloud: cloud.AzurePublic, + }, }, expectedError: fmt.Errorf("unable to create public client"), }, @@ -195,11 +205,7 @@ func TestGetPublicClient(t *testing.T) { for _, tc := range testCase { t.Run(tc.testName, func(t *testing.T) { - client, err = getPublicClient( - tc.authority, - testutils.ClientID, - tc.options, - ) + client, err = getPublicClient(tc.pcOptions) if tc.expectedError != nil { if !testutils.ErrorContains(err, tc.expectedError.Error()) { diff --git a/pkg/internal/token/interactive.go b/pkg/internal/token/interactive.go index a3d2bebd..a20b71c3 100644 --- a/pkg/internal/token/interactive.go +++ b/pkg/internal/token/interactive.go @@ -16,16 +16,17 @@ import ( ) type InteractiveToken struct { - clientID string - resourceID string - tenantID string - oAuthConfig adal.OAuthConfig - popClaims map[string]string + clientID string + resourceID string + tenantID string + oAuthConfig adal.OAuthConfig + popClaims map[string]string + disableInstanceDiscovery bool } // newInteractiveTokenProvider returns a TokenProvider that will fetch a token for the user currently logged into the Interactive. // Required arguments include an oAuthConfiguration object and the resourceID (which is used as the scope) -func newInteractiveTokenProvider(oAuthConfig adal.OAuthConfig, clientID, resourceID, tenantID string, popClaims map[string]string) (TokenProvider, error) { +func newInteractiveTokenProvider(oAuthConfig adal.OAuthConfig, clientID, resourceID, tenantID string, popClaims map[string]string, disableInstanceDiscovery bool) (TokenProvider, error) { if clientID == "" { return nil, errors.New("clientID cannot be empty") } @@ -37,11 +38,12 @@ func newInteractiveTokenProvider(oAuthConfig adal.OAuthConfig, clientID, resourc } return &InteractiveToken{ - clientID: clientID, - resourceID: resourceID, - tenantID: tenantID, - oAuthConfig: oAuthConfig, - popClaims: popClaims, + clientID: clientID, + resourceID: resourceID, + tenantID: tenantID, + oAuthConfig: oAuthConfig, + popClaims: popClaims, + disableInstanceDiscovery: disableInstanceDiscovery, }, nil } @@ -73,18 +75,22 @@ func (p *InteractiveToken) TokenWithOptions(ctx context.Context, options *azcore ctx, p.popClaims, scopes, - authorityFromConfig.String(), - p.clientID, - &clientOpts, + &pop.PublicClientOptions{ + Authority: authorityFromConfig.String(), + ClientID: p.clientID, + DisableInstanceDiscovery: p.disableInstanceDiscovery, + Options: &clientOpts, + }, ) if err != nil { return emptyToken, fmt.Errorf("failed to create PoP token using interactive login: %w", err) } } else { cred, err := azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ - ClientOptions: clientOpts, - TenantID: p.tenantID, - ClientID: p.clientID, + ClientOptions: clientOpts, + TenantID: p.tenantID, + ClientID: p.clientID, + DisableInstanceDiscovery: p.disableInstanceDiscovery, }) if err != nil { return emptyToken, fmt.Errorf("unable to create credential. Received: %w", err) diff --git a/pkg/internal/token/interactive_test.go b/pkg/internal/token/interactive_test.go index aaf2f32c..ec5c871b 100644 --- a/pkg/internal/token/interactive_test.go +++ b/pkg/internal/token/interactive_test.go @@ -59,6 +59,7 @@ func TestNewInteractiveToken(t *testing.T) { tc.resourceID, tc.tenantID, tc.popClaims, + false, ) if tc.expectedError != "" { diff --git a/pkg/internal/token/options.go b/pkg/internal/token/options.go index 872c25c5..98d86c18 100644 --- a/pkg/internal/token/options.go +++ b/pkg/internal/token/options.go @@ -35,6 +35,7 @@ type Options struct { IsPoPTokenEnabled bool PoPTokenClaims string DisableEnvironmentOverride bool + DisableInstanceDiscovery bool } const ( @@ -110,6 +111,7 @@ func (o *Options) AddFlags(fs *pflag.FlagSet) { fmt.Sprintf("Timeout duration for Azure CLI token requests. It may be specified in %s environment variable", "AZURE_CLI_TIMEOUT")) fs.StringVar(&o.PoPTokenClaims, "pop-claims", o.PoPTokenClaims, "contains a comma-separated list of claims to attach to the pop token in the format `key=val,key2=val2`. At minimum, specify the ARM ID of the cluster as `u=ARM_ID`") fs.BoolVar(&o.DisableEnvironmentOverride, "disable-environment-override", o.DisableEnvironmentOverride, "Enable or disable the use of env-variables. Default false") + fs.BoolVar(&o.DisableInstanceDiscovery, "disable-instance-discovery", o.DisableInstanceDiscovery, "set to true to disable instance discovery in environments with their own simple Identity Provider (not AAD) that do not have instance metadata discovery endpoint. Default false") } func (o *Options) Validate() error { diff --git a/pkg/internal/token/provider.go b/pkg/internal/token/provider.go index 1f1b257b..23adb5f6 100644 --- a/pkg/internal/token/provider.go +++ b/pkg/internal/token/provider.go @@ -34,14 +34,14 @@ func NewTokenProvider(o *Options) (TokenProvider, error) { case DeviceCodeLogin: return newDeviceCodeTokenProvider(*oAuthConfig, o.ClientID, o.ServerID, o.TenantID) case InteractiveLogin: - return newInteractiveTokenProvider(*oAuthConfig, o.ClientID, o.ServerID, o.TenantID, popClaimsMap) + return newInteractiveTokenProvider(*oAuthConfig, o.ClientID, o.ServerID, o.TenantID, popClaimsMap, o.DisableInstanceDiscovery) case ServicePrincipalLogin: if o.IsLegacy { return newLegacyServicePrincipalToken(*oAuthConfig, o.ClientID, o.ClientSecret, o.ClientCert, o.ClientCertPassword, o.ServerID, o.TenantID) } return newServicePrincipalTokenProvider(cloudConfiguration, o.ClientID, o.ClientSecret, o.ClientCert, o.ClientCertPassword, o.ServerID, o.TenantID, false, popClaimsMap) case ROPCLogin: - return newResourceOwnerTokenProvider(*oAuthConfig, o.ClientID, o.Username, o.Password, o.ServerID, o.TenantID, popClaimsMap) + return newResourceOwnerTokenProvider(*oAuthConfig, o.ClientID, o.Username, o.Password, o.ServerID, o.TenantID, popClaimsMap, o.DisableInstanceDiscovery) case MSILogin: return newManagedIdentityToken(o.ClientID, o.IdentityResourceID, o.ServerID) case AzureCLILogin: diff --git a/pkg/internal/token/ropc.go b/pkg/internal/token/ropc.go index e719ec1f..e4462109 100644 --- a/pkg/internal/token/ropc.go +++ b/pkg/internal/token/ropc.go @@ -14,13 +14,14 @@ import ( ) type resourceOwnerToken struct { - clientID string - username string - password string - resourceID string - tenantID string - oAuthConfig adal.OAuthConfig - popClaims map[string]string + clientID string + username string + password string + resourceID string + tenantID string + oAuthConfig adal.OAuthConfig + popClaims map[string]string + disableInstanceDiscovery bool } func newResourceOwnerTokenProvider( @@ -31,6 +32,7 @@ func newResourceOwnerTokenProvider( resourceID, tenantID string, popClaims map[string]string, + disableInstanceDiscovery bool, ) (TokenProvider, error) { if clientID == "" { return nil, errors.New("clientID cannot be empty") @@ -49,13 +51,14 @@ func newResourceOwnerTokenProvider( } return &resourceOwnerToken{ - clientID: clientID, - username: username, - password: password, - resourceID: resourceID, - tenantID: tenantID, - oAuthConfig: oAuthConfig, - popClaims: popClaims, + clientID: clientID, + username: username, + password: password, + resourceID: resourceID, + tenantID: tenantID, + oAuthConfig: oAuthConfig, + popClaims: popClaims, + disableInstanceDiscovery: disableInstanceDiscovery, }, nil } @@ -82,11 +85,14 @@ func (p *resourceOwnerToken) tokenWithOptions(ctx context.Context, options *azco ctx, p.popClaims, scopes, - authorityFromConfig.String(), - p.clientID, p.username, p.password, - &clientOpts, + &pop.PublicClientOptions{ + Authority: authorityFromConfig.String(), + ClientID: p.clientID, + DisableInstanceDiscovery: p.disableInstanceDiscovery, + Options: &clientOpts, + }, ) if err != nil { return emptyToken, fmt.Errorf("failed to create PoP token using resource owner flow: %w", err) diff --git a/pkg/internal/token/ropc_test.go b/pkg/internal/token/ropc_test.go index 5978cad7..e7a741a8 100644 --- a/pkg/internal/token/ropc_test.go +++ b/pkg/internal/token/ropc_test.go @@ -91,6 +91,7 @@ func TestNewResourceOwnerTokenProvider(t *testing.T) { tc.resourceID, tc.tenantID, tc.popClaims, + false, ) if tc.expectedError != "" {