From 589110961a270733e4e3fb6add75117be124d984 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Wed, 5 Mar 2025 14:02:19 +0200
Subject: [PATCH] Isolate redis-entraid dependency for tests (#3521)

Index: tests/conftest.py
--- tests/conftest.py.orig
+++ tests/conftest.py
@@ -4,8 +4,7 @@ import os
 import random
 import time
 from datetime import datetime, timezone
-from enum import Enum
-from typing import Callable, TypeVar, Union
+from typing import Callable, TypeVar
 from unittest import mock
 from unittest.mock import Mock
 from urllib.parse import urlparse
@@ -17,7 +16,6 @@ from packaging.version import Version
 from redis import Sentinel
 from redis.auth.idp import IdentityProviderInterface
 from redis.auth.token import JWToken
-from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
 from redis.backoff import NoBackoff
 from redis.cache import (
     CacheConfig,
@@ -30,22 +28,6 @@ from redis.connection import Connection, ConnectionInt
 from redis.credentials import CredentialProvider
 from redis.exceptions import RedisClusterException
 from redis.retry import Retry
-from redis_entraid.cred_provider import (
-    DEFAULT_DELAY_IN_MS,
-    DEFAULT_EXPIRATION_REFRESH_RATIO,
-    DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
-    DEFAULT_MAX_ATTEMPTS,
-    DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
-    EntraIdCredentialsProvider,
-)
-from redis_entraid.identity_provider import (
-    ManagedIdentityIdType,
-    ManagedIdentityProviderConfig,
-    ManagedIdentityType,
-    ServicePrincipalIdentityProviderConfig,
-    _create_provider_from_managed_identity,
-    _create_provider_from_service_principal,
-)
 from tests.ssl_utils import get_tls_certificates
 
 REDIS_INFO = {}
@@ -61,11 +43,6 @@ _DecoratedTest = TypeVar("_DecoratedTest", bound="Call
 _TestDecorator = Callable[[_DecoratedTest], _DecoratedTest]
 
 
-class AuthType(Enum):
-    MANAGED_IDENTITY = "managed_identity"
-    SERVICE_PRINCIPAL = "service_principal"
-
-
 # Taken from python3.9
 class BooleanOptionalAction(argparse.Action):
     def __init__(
@@ -623,124 +600,18 @@ def mock_identity_provider() -> IdentityProviderInterf
     return mock_provider
 
 
-def identity_provider(request) -> IdentityProviderInterface:
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-    else:
-        kwargs = {}
-
-    if request.param.get("mock_idp", None) is not None:
-        return mock_identity_provider()
-
-    auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
-    config = get_identity_provider_config(request=request)
-
-    if auth_type == "MANAGED_IDENTITY":
-        return _create_provider_from_managed_identity(config)
-
-    return _create_provider_from_service_principal(config)
-
-
-def get_identity_provider_config(
-    request,
-) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-    else:
-        kwargs = {}
-
-    auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
-
-    if auth_type == AuthType.MANAGED_IDENTITY:
-        return _get_managed_identity_provider_config(request)
-
-    return _get_service_principal_provider_config(request)
-
-
-def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
-    resource = os.getenv("AZURE_RESOURCE")
-    id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
-
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-    else:
-        kwargs = {}
-
-    identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
-    id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
-
-    return ManagedIdentityProviderConfig(
-        identity_type=identity_type,
-        resource=resource,
-        id_type=id_type,
-        id_value=id_value,
-        kwargs=kwargs,
-    )
-
-
-def _get_service_principal_provider_config(
-    request,
-) -> ServicePrincipalIdentityProviderConfig:
-    client_id = os.getenv("AZURE_CLIENT_ID")
-    client_credential = os.getenv("AZURE_CLIENT_SECRET")
-    tenant_id = os.getenv("AZURE_TENANT_ID")
-    scopes = os.getenv("AZURE_REDIS_SCOPES", None)
-
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-        token_kwargs = request.param.get("token_kwargs", {})
-        timeout = request.param.get("timeout", None)
-    else:
-        kwargs = {}
-        token_kwargs = {}
-        timeout = None
-
-    if isinstance(scopes, str):
-        scopes = scopes.split(",")
-
-    return ServicePrincipalIdentityProviderConfig(
-        client_id=client_id,
-        client_credential=client_credential,
-        scopes=scopes,
-        timeout=timeout,
-        token_kwargs=token_kwargs,
-        tenant_id=tenant_id,
-        app_kwargs=kwargs,
-    )
-
-
 def get_credential_provider(request) -> CredentialProvider:
     cred_provider_class = request.param.get("cred_provider_class")
     cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})
 
-    if cred_provider_class != EntraIdCredentialsProvider:
+    # Since we can't import EntraIdCredentialsProvider in this module,
+    # we'll just check the class name.
+    if cred_provider_class.__name__ != "EntraIdCredentialsProvider":
         return cred_provider_class(**cred_provider_kwargs)
 
-    idp = identity_provider(request)
-    expiration_refresh_ratio = cred_provider_kwargs.get(
-        "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
-    )
-    lower_refresh_bound_millis = cred_provider_kwargs.get(
-        "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
-    )
-    max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
-    delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
+    from tests.entraid_utils import get_entra_id_credentials_provider
 
-    token_mgr_config = TokenManagerConfig(
-        expiration_refresh_ratio=expiration_refresh_ratio,
-        lower_refresh_bound_millis=lower_refresh_bound_millis,
-        token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,  # noqa
-        retry_policy=RetryPolicy(
-            max_attempts=max_attempts,
-            delay_in_ms=delay_in_ms,
-        ),
-    )
-
-    return EntraIdCredentialsProvider(
-        identity_provider=idp,
-        token_manager_config=token_mgr_config,
-        initial_delay_in_ms=delay_in_ms,
-    )
+    return get_entra_id_credentials_provider(request, cred_provider_kwargs)
 
 
 @pytest.fixture()
