Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

126 airflow cognito integration #226

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
134 changes: 134 additions & 0 deletions airflow/config/webserver_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import base64
import json
import logging
import os
from typing import Any

from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride
from authlib.common.urls import url_decode
from authlib.integrations.base_client.sync_openid import OpenIDMixin
from authlib.oauth2.client import OAuth2Client
from flask_appbuilder.security.manager import AUTH_OAUTH
from tornado.httpclient import HTTPClient, HTTPRequest
from tornado.httputil import url_concat

# Logging
# Messages from this module are emitted under the FAB security-views logger name.
log = logging.getLogger("flask_appbuilder.security.views")

# Cognito integration data.
# All four variables are required: a missing one raises KeyError while this
# configuration module is being loaded.
COGNITO_BASE_URL = os.environ["COGNITO_BASE_URL"]
COGNITO_CLIENT_ID = os.environ["COGNITO_CLIENT_ID"]
COGNITO_CLIENT_SECRET = os.environ["COGNITO_CLIENT_SECRET"]
COGNITO_USER_POOL_ID = os.environ["COGNITO_USER_POOL_ID"]

# AWS region hosting the user pool, used to build the JWKS URL below.
# Resolution order: optional COGNITO_REGION environment variable, then the
# "<region>_" prefix of the user pool id, then the previous hard-coded default.
# NOTE(review): assumes pool ids look like "<region>_<suffix>" - confirm.
_pool_region, _pool_sep, _pool_suffix = COGNITO_USER_POOL_ID.partition("_")
COGNITO_REGION = os.environ.get("COGNITO_REGION") or (_pool_region if _pool_sep else "us-west-2")

# Authentication constants
AUTH_TYPE = AUTH_OAUTH
AUTH_USER_REGISTRATION = True  # allow users not in the FAB DB
# NOTE(review): every self-registered user receives this role; "Admin" looks
# overly permissive for a default - confirm before merging.
AUTH_USER_REGISTRATION_ROLE = "Admin"  # role given in addition to AUTH_ROLES
AUTH_ROLES_SYNC_AT_LOGIN = True  # replace all user's roles each login
# NOTE(review): map_roles() in this file reads each value as a single role
# name; FAB's own documentation shows lists of role names here - confirm which
# shape the installed FAB version expects.
AUTH_ROLES_MAPPING = {  # mapping of Cognito groups to FAB roles
    "Unity_Viewer": "User",
    "Unity_Admin": "Admin",
}

# Cognito provider data
OAUTH_PROVIDERS = [
    {
        "name": "Cognito",
        "icon": "fa-amazon",
        "token_key": "access_token",
        "remote_app": {
            "client_id": COGNITO_CLIENT_ID,
            "client_secret": COGNITO_CLIENT_SECRET,
            "api_base_url": f"{COGNITO_BASE_URL}/",
            "client_kwargs": {"scope": "email openid profile"},
            "access_token_url": f"{COGNITO_BASE_URL}/token",
            "authorize_url": f"{COGNITO_BASE_URL}/authorize",
            "jwks_uri": (
                f"https://cognito-idp.{COGNITO_REGION}.amazonaws.com/"
                f"{COGNITO_USER_POOL_ID}/.well-known/jwks.json"
            ),
        },
    }
]


def fetch_token(self, url, body="", headers=None, auth=None, method="POST", state=None, **kwargs):
    """Fetch Cognito token data with a blocking POST to the token endpoint.

    Installed at the bottom of this module as a replacement for
    ``OAuth2Client._fetch_token``.  Only ``url``, ``body`` and ``auth`` are
    used; ``headers``, ``method``, ``state`` and ``kwargs`` are accepted for
    signature compatibility and ignored.

    Args:
        url: Token endpoint URL.
        body: URL-encoded form string; must contain ``code`` and
            ``redirect_uri``.
        auth: Object exposing ``client_id`` and ``client_secret`` strings.

    Returns:
        The decoded JSON response of the token endpoint.

    Raises:
        KeyError: If ``body`` lacks ``code`` or ``redirect_uri``.
        Exception: Any error raised by ``HTTPClient.fetch`` propagates.
    """
    # Encode "client_id:client_secret" for the HTTP Basic Authorization header
    credentials = ":".join((auth.client_id, auth.client_secret))
    base64_auth = base64.b64encode(credentials.encode("ascii")).decode("ascii")

    # Build URL with parameters (the grant parameters travel in the query
    # string; the POST body itself is sent empty)
    body_dict = dict(url_decode(body))
    params = {
        "client_id": auth.client_id,
        "code": body_dict["code"],
        "grant_type": "authorization_code",
        "redirect_uri": body_dict["redirect_uri"],
    }
    req = HTTPRequest(
        url_concat(url, params),
        method="POST",
        headers={
            "Accept": "application/json",
            "Authorization": "Basic " + base64_auth,
            "Content-Type": "application/x-www-form-urlencoded",
        },
        body="",
    )

    # POST request to Cognito for token data; always release the blocking
    # client, even when the request fails
    http_client = HTTPClient()
    try:
        resp = http_client.fetch(req)
    finally:
        http_client.close()

    return json.loads(resp.body.decode("utf8", "replace"))


def fetch_jwk(self, force=False):
    """Fetch JWK public data from the provider's ``jwks_uri``.

    Installed at the bottom of this module as a replacement for
    ``OpenIDMixin.fetch_jwk_set``.  The key set is downloaded on every call;
    ``force`` is accepted for signature compatibility and ignored.

    Returns:
        The decoded JSON document served at ``jwks_uri``.

    Raises:
        RuntimeError: If the server metadata has no ``jwks_uri``.
        Exception: Any error raised by ``HTTPClient.fetch`` propagates.
    """
    metadata = self.load_server_metadata()
    jwks_uri = metadata.get("jwks_uri")
    if not jwks_uri:
        # Fail with a clear message instead of handing None to HTTPRequest
        raise RuntimeError('Missing "jwks_uri" in metadata')
    log.debug("jwks_uri: %s", jwks_uri)

    req = HTTPRequest(jwks_uri, method="GET")
    http_client = HTTPClient()
    try:
        jwks_response = http_client.fetch(req)
    finally:
        # Always release the blocking client, even when the request fails
        http_client.close()

    jwks_json = json.loads(jwks_response.body.decode("utf8", "replace"))
    log.debug("jwks_json: %s", jwks_json)
    return jwks_json


def map_roles(roles):
    """Map Cognito roles to Airflow roles.

    Each entry of ``roles`` is looked up in ``AUTH_ROLES_MAPPING``; entries
    without a mapping become ``"Public"``.  Duplicates are removed and the
    result is sorted so the output is deterministic.

    Args:
        roles: Iterable of Cognito group names, or ``None``/empty when the
            user belongs to no group.

    Returns:
        A sorted list of unique role names (empty when ``roles`` is empty).
    """
    if not roles:
        return []
    return sorted({AUTH_ROLES_MAPPING.get(role, "Public") for role in roles})


# Security manager override
class CognitoAuthorizer(FabAirflowSecurityManagerOverride):
    """Security manager that builds the FAB user record from Cognito claims."""

    def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, Any]:
        """Override method to login with Cognito specific data.

        Args:
            provider: Name of the OAuth provider that handled the login.
            resp: OAuth response; for Cognito it must contain ``userinfo``.

        Returns:
            For ``"Cognito"``, a dict with ``username``, ``email`` and
            ``role_keys``.  Any other provider is delegated to the parent
            implementation instead of implicitly returning ``None``.

        Raises:
            KeyError: If ``userinfo``, ``cognito:username`` or ``email`` is
                missing from a Cognito response.
        """
        if provider != "Cognito":
            return super().get_oauth_user_info(provider, resp)

        user_info = resp["userinfo"]
        log.debug("user_info: %s", user_info)

        # The groups claim may be absent (e.g. a user assigned to no group);
        # treat that as "no groups" rather than failing the login.
        # NOTE(review): FAB may apply AUTH_ROLES_MAPPING to "role_keys" again
        # after this returns - confirm pre-mapping here is intended.
        roles = map_roles(user_info.get("cognito:groups") or [])
        log.debug("roles: %s", roles)

        return {
            "username": user_info["cognito:username"],
            "email": user_info["email"],
            "role_keys": roles,
        }


# Overrides
# Use the Cognito-aware security manager defined in this module.
SECURITY_MANAGER_CLASS = CognitoAuthorizer
# Monkey-patch authlib at class level: from here on, every OAuth2Client uses
# fetch_token() and every OpenIDMixin uses fetch_jwk() from this module.
OAuth2Client._fetch_token = fetch_token
OpenIDMixin.fetch_jwk_set = fetch_jwk
11 changes: 11 additions & 0 deletions airflow/helm/values.tmpl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ webserverSecretKeySecretName: ${webserver_secret_name}
webserver:
replicas: 3

webserverConfig: |-
${webserver_config}

startupProbe:
timeoutSeconds: 20
failureThreshold: 60 # Number of tries before giving up (10 minutes with periodSeconds of 10)
Expand Down Expand Up @@ -275,6 +278,14 @@ env:
value: "${karpenter_node_pools}"
- name: "AIRFLOW_VAR_ECR_URI"
value: "${cwl_dag_ecr_uri}"
- name: "COGNITO_CLIENT_ID"
value: "${cognito_client_id}"
- name: "COGNITO_CLIENT_SECRET"
value: "${cognito_client_secret}"
- name: "COGNITO_BASE_URL"
value: "${cognito_base_url}"
- name: "COGNITO_USER_POOL_ID"
value: "${cognito_user_pool_id}"

# https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/security/api.html
extraEnv: |
Expand Down
26 changes: 13 additions & 13 deletions terraform-unity/.terraform.lock.hcl

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions terraform-unity/modules/terraform-unity-sps-airflow/data.tf
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,27 @@ data "aws_secretsmanager_secret_version" "db" {
data "aws_efs_file_system" "efs" {
file_system_id = var.efs_file_system_id
}

# SSM parameter read from the path "/unity/account/network/ssl"; its value is
# exposed as data.aws_ssm_parameter.ssl_cert_arn.value.
data "aws_ssm_parameter" "ssl_cert_arn" {
name = "/unity/account/network/ssl"
}

# SSM parameter whose value is interpolated as the account-id segment of the
# parameter ARNs in the Cognito lookups below.
data "aws_ssm_parameter" "ss_acct_num" {
name = "/unity/shared-services/aws/account"
}

# Cognito base URL, addressed by full ARN (region fixed to us-west-2, account
# taken from ss_acct_num).
data "aws_ssm_parameter" "cognito_base_url" {
name = "arn:aws:ssm:us-west-2:${data.aws_ssm_parameter.ss_acct_num.value}:parameter/unity/shared-services/cognito/base-url"
}

# Client id registered for the Airflow UI, addressed by full ARN.
data "aws_ssm_parameter" "cognito_client_id" {
name = "arn:aws:ssm:us-west-2:${data.aws_ssm_parameter.ss_acct_num.value}:parameter/unity/shared-services/cognito/airflow-ui-client-id"
}

# Client secret registered for the Airflow UI, addressed by full ARN.
# NOTE(review): sensitive value - confirm how it is handled downstream.
data "aws_ssm_parameter" "cognito_client_secret" {
name = "arn:aws:ssm:us-west-2:${data.aws_ssm_parameter.ss_acct_num.value}:parameter/unity/shared-services/cognito/airflow-ui-client-secret"
}

# Cognito user pool id, addressed by full ARN.
data "aws_ssm_parameter" "cognito_user_pool_id" {
name = "arn:aws:ssm:us-west-2:${data.aws_ssm_parameter.ss_acct_num.value}:parameter/unity/shared-services/cognito/user-pool-id"
}
25 changes: 17 additions & 8 deletions terraform-unity/modules/terraform-unity-sps-airflow/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,11 @@ resource "helm_release" "airflow" {
unity_cluster_name = data.aws_eks_cluster.cluster.name
karpenter_node_pools = join(",", var.karpenter_node_pools)
cwl_dag_ecr_uri = "${data.aws_caller_identity.current.account_id}.dkr.ecr.us-west-2.amazonaws.com"
webserver_config = indent(4, file("${path.module}/../../../airflow/config/webserver_config.py"))
cognito_client_id = data.aws_ssm_parameter.cognito_client_id.value
cognito_client_secret = data.aws_ssm_parameter.cognito_client_secret.value
cognito_base_url = data.aws_ssm_parameter.cognito_base_url.value
cognito_user_pool_id = data.aws_ssm_parameter.cognito_user_pool_id.value
})
]
set_sensitive {
Expand Down Expand Up @@ -464,10 +469,12 @@ resource "kubernetes_ingress_v1" "airflow_ingress" {
"alb.ingress.kubernetes.io/scheme" = "internet-facing"
"alb.ingress.kubernetes.io/target-type" = "ip"
"alb.ingress.kubernetes.io/subnets" = join(",", jsondecode(data.aws_ssm_parameter.subnet_ids.value)["public"])
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTP\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTPS\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/security-groups" = aws_security_group.airflow_ingress_sg.id
"alb.ingress.kubernetes.io/manage-backend-security-group-rules" = "true"
"alb.ingress.kubernetes.io/healthcheck-path" = "/health"
"alb.ingress.kubernetes.io/certificate-arn" = data.aws_ssm_parameter.ssl_cert_arn.value
"alb.ingress.kubernetes.io/ssl-policy" = "ELBSecurityPolicy-TLS13-1-2-2021-06"
}
}
spec {
Expand Down Expand Up @@ -501,10 +508,12 @@ resource "kubernetes_ingress_v1" "airflow_ingress_internal" {
"alb.ingress.kubernetes.io/scheme" = "internal"
"alb.ingress.kubernetes.io/target-type" = "ip"
"alb.ingress.kubernetes.io/subnets" = join(",", jsondecode(data.aws_ssm_parameter.subnet_ids.value)["private"])
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTP\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTPS\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/security-groups" = aws_security_group.airflow_ingress_sg_internal.id
"alb.ingress.kubernetes.io/manage-backend-security-group-rules" = "true"
"alb.ingress.kubernetes.io/healthcheck-path" = "/health"
"alb.ingress.kubernetes.io/certificate-arn" = data.aws_ssm_parameter.ssl_cert_arn.value
"alb.ingress.kubernetes.io/ssl-policy" = "ELBSecurityPolicy-TLS13-1-2-2021-06"
}
}
spec {
Expand Down Expand Up @@ -534,7 +543,7 @@ resource "aws_ssm_parameter" "airflow_ui_url" {
name = format("/%s", join("/", compact(["", var.project, var.venue, var.service_area, "processing", "airflow", "ui_url"])))
description = "The URL of the Airflow UI."
type = "String"
value = "http://${data.kubernetes_ingress_v1.airflow_ingress.status[0].load_balancer[0].ingress[0].hostname}:5000"
value = "https://${data.kubernetes_ingress_v1.airflow_ingress.status[0].load_balancer[0].ingress[0].hostname}:5000"
tags = merge(local.common_tags, {
Name = format(local.resource_name_prefix, "endpoints-airflow_ui")
Component = "SSM"
Expand All @@ -548,8 +557,8 @@ resource "aws_ssm_parameter" "airflow_ui_health_check_endpoint" {
type = "String"
value = jsonencode({
"componentName" : "Airflow UI"
"healthCheckUrl" : "http://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000/health"
"landingPageUrl" : "http://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000"
"healthCheckUrl" : "https://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000/health"
"landingPageUrl" : "https://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000"
})
tags = merge(local.common_tags, {
Name = format(local.resource_name_prefix, "health-check-endpoints-airflow_ui")
Expand All @@ -565,7 +574,7 @@ resource "aws_ssm_parameter" "airflow_api_url" {
name = format("/%s", join("/", compact(["", var.project, var.venue, var.service_area, "processing", "airflow", "api_url"])))
description = "The URL of the Airflow REST API."
type = "String"
value = "http://${data.kubernetes_ingress_v1.airflow_ingress.status[0].load_balancer[0].ingress[0].hostname}:5000/api/v1"
value = "https://${data.kubernetes_ingress_v1.airflow_ingress.status[0].load_balancer[0].ingress[0].hostname}:5000/api/v1"
tags = merge(local.common_tags, {
Name = format(local.resource_name_prefix, "endpoints-airflow_api")
Component = "SSM"
Expand All @@ -579,8 +588,8 @@ resource "aws_ssm_parameter" "airflow_api_health_check_endpoint" {
type = "String"
value = jsonencode({
"componentName" : "Airflow API"
"healthCheckUrl" : "http://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000/api/v1/health"
"landingPageUrl" : "http://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000/api/v1"
"healthCheckUrl" : "https://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000/api/v1/health"
"landingPageUrl" : "https://${data.kubernetes_ingress_v1.airflow_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5000/api/v1"
})
tags = merge(local.common_tags, {
Name = format(local.resource_name_prefix, "health-check-endpoints-airflow_api")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ data "kubernetes_ingress_v1" "ogc_processes_api_ingress_internal" {
namespace = data.kubernetes_namespace.service_area.metadata[0].name
}
}

# SSM parameter read from the path "/unity/account/network/ssl"; its value is
# exposed as data.aws_ssm_parameter.ssl_cert_arn.value.
data "aws_ssm_parameter" "ssl_cert_arn" {
name = "/unity/account/network/ssl"
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,12 @@ resource "kubernetes_ingress_v1" "ogc_processes_api_ingress" {
"alb.ingress.kubernetes.io/scheme" = "internet-facing"
"alb.ingress.kubernetes.io/target-type" = "ip"
"alb.ingress.kubernetes.io/subnets" = join(",", jsondecode(data.aws_ssm_parameter.subnet_ids.value)["public"])
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTP\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTPS\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/security-groups" = aws_security_group.ogc_ingress_sg.id
"alb.ingress.kubernetes.io/manage-backend-security-group-rules" = "true"
"alb.ingress.kubernetes.io/healthcheck-path" = "/health"
"alb.ingress.kubernetes.io/certificate-arn" = data.aws_ssm_parameter.ssl_cert_arn.value
"alb.ingress.kubernetes.io/ssl-policy" = "ELBSecurityPolicy-TLS13-1-2-2021-06"
}
}
spec {
Expand Down Expand Up @@ -303,10 +305,12 @@ resource "kubernetes_ingress_v1" "ogc_processes_api_ingress_internal" {
"alb.ingress.kubernetes.io/scheme" = "internal"
"alb.ingress.kubernetes.io/target-type" = "ip"
"alb.ingress.kubernetes.io/subnets" = join(",", jsondecode(data.aws_ssm_parameter.subnet_ids.value)["private"])
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTP\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/listen-ports" = "[{\"HTTPS\": ${local.load_balancer_port}}]"
"alb.ingress.kubernetes.io/security-groups" = aws_security_group.ogc_ingress_sg_internal.id
"alb.ingress.kubernetes.io/manage-backend-security-group-rules" = "true"
"alb.ingress.kubernetes.io/healthcheck-path" = "/health"
"alb.ingress.kubernetes.io/certificate-arn" = data.aws_ssm_parameter.ssl_cert_arn.value
"alb.ingress.kubernetes.io/ssl-policy" = "ELBSecurityPolicy-TLS13-1-2-2021-06"
}
}
spec {
Expand Down Expand Up @@ -335,7 +339,7 @@ resource "aws_ssm_parameter" "ogc_processes_ui_url" {
name = format("/%s", join("/", compact(["", var.project, var.venue, var.service_area, "processing", "ogc_processes", "ui_url"])))
description = "The URL of the OGC Proccesses API Docs UI."
type = "String"
value = "http://${data.kubernetes_ingress_v1.ogc_processes_api_ingress.status[0].load_balancer[0].ingress[0].hostname}:5001/redoc"
value = "https://${data.kubernetes_ingress_v1.ogc_processes_api_ingress.status[0].load_balancer[0].ingress[0].hostname}:5001/redoc"
tags = merge(local.common_tags, {
Name = format(local.resource_name_prefix, "endpoints-ogc_processes_ui")
Component = "SSM"
Expand All @@ -347,7 +351,7 @@ resource "aws_ssm_parameter" "ogc_processes_api_url" {
name = format("/%s", join("/", compact(["", var.project, var.venue, var.service_area, "processing", "ogc_processes", "api_url"])))
description = "The URL of the OGC Processes REST API."
type = "String"
value = "http://${data.kubernetes_ingress_v1.ogc_processes_api_ingress.status[0].load_balancer[0].ingress[0].hostname}:5001"
value = "https://${data.kubernetes_ingress_v1.ogc_processes_api_ingress.status[0].load_balancer[0].ingress[0].hostname}:5001"
tags = merge(local.common_tags, {
Name = format(local.resource_name_prefix, "endpoints-ogc_processes_api")
Component = "SSM"
Expand All @@ -361,8 +365,8 @@ resource "aws_ssm_parameter" "ogc_processes_api_health_check_endpoint" {
type = "String"
value = jsonencode({
"componentName" : "OGC API"
"healthCheckUrl" : "http://${data.kubernetes_ingress_v1.ogc_processes_api_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5001/health"
"landingPageUrl" : "http://${data.kubernetes_ingress_v1.ogc_processes_api_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5001"
"healthCheckUrl" : "https://${data.kubernetes_ingress_v1.ogc_processes_api_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5001/health"
"landingPageUrl" : "https://${data.kubernetes_ingress_v1.ogc_processes_api_ingress_internal.status[0].load_balancer[0].ingress[0].hostname}:5001"
})
tags = merge(local.common_tags, {
Name = format(local.resource_name_prefix, "health-check-endpoints-ogc_processes_api")
Expand Down
Loading