diff --git a/charms/keystone-k8s/.sunbeam-build.yaml b/charms/keystone-k8s/.sunbeam-build.yaml index afe8091a..296931fa 100644 --- a/charms/keystone-k8s/.sunbeam-build.yaml +++ b/charms/keystone-k8s/.sunbeam-build.yaml @@ -7,6 +7,7 @@ external-libraries: - charms.tempo_k8s.v2.tracing - charms.tempo_k8s.v1.charm_tracing - charms.hydra.v0.oauth + - charms.kratos_external_idp_integrator.v0.kratos_external_provider internal-libraries: - charms.horizon_k8s.v0.trusted_dashboard templates: diff --git a/charms/keystone-k8s/charmcraft.yaml b/charms/keystone-k8s/charmcraft.yaml index 4dec3a08..96a0899a 100644 --- a/charms/keystone-k8s/charmcraft.yaml +++ b/charms/keystone-k8s/charmcraft.yaml @@ -168,6 +168,9 @@ requires: trusted-dashboard: interface: trusted-dashboard optional: true + external-idp: + interface: external_provider + optional: true provides: identity-service: diff --git a/charms/keystone-k8s/src/charm.py b/charms/keystone-k8s/src/charm.py index 087ce71c..5c3421cc 100755 --- a/charms/keystone-k8s/src/charm.py +++ b/charms/keystone-k8s/src/charm.py @@ -52,6 +52,7 @@ import charms.keystone_k8s.v0.domain_config as sunbeam_dc_svc import charms.keystone_k8s.v0.identity_credentials as sunbeam_cc_svc import charms.keystone_k8s.v0.identity_resource as sunbeam_ops_svc import charms.keystone_k8s.v1.identity_service as sunbeam_id_svc +import charms.kratos_external_idp_integrator.v0.kratos_external_provider as external_idp import jinja2 import keystoneauth1.exceptions import ops @@ -173,6 +174,18 @@ class KeystoneConfigAdapter(sunbeam_contexts.ConfigContext): } +@sunbeam_tracing.trace_type +class MergedFederatedIdentityConfigContext(sunbeam_contexts.ConfigContext): + """Config adapter that merges oauth and external_idp into one context.""" + + def context(self): + """Configuration context.""" + ctx = self.charm.merged_fid_contexts() + if not ctx or not ctx.get("oidc_providers", []): + return {} + return ctx + + @sunbeam_tracing.trace_type class IdentityServiceProvidesHandler(sunbeam_rhandlers.RelationHandler): """Handler for identity service relation.""" @@ -293,7 +306,60 @@ class IdentityResourceProvidesHandler(sunbeam_rhandlers.RelationHandler): return True -class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler): +class _BaseIDPHandler(sunbeam_rhandlers.RelationHandler): + + def _get_idp_file_name_from_issuer_url(self, issuer_url): + """Generate a sanitized file name from the issuer URL. + + The openidc apache + module expects that the provider metadata exist in a folder defined by + OIDCMetadataDir and the file name must be a url encoded string, without + the schema and the trailing slash. + For example, if the issuer URL of the IDP is: + + https://172.16.1.207/iam-hydra + + then the generated file names will be: + + * 172.16.1.207%2Fiam-hydra.client - client ID and client secret + * 172.16.1.207%2Fiam-hydra.provider - the provider metadata file. + + The contents of the .provider file can be fetched from: + + https://172.16.1.207/iam-hydra/.well-known/openid-configuration + """ + sanitized = issuer_url.lstrip("https://").lstrip("http://").rstrip("/") + return quote(sanitized, safe="") + + def _get_oidc_metadata( + self, metadata_url, additional_chain: List[str] = [] + ): + try: + chains = self.charm._get_all_ca_bundles(additional_chain) + with tempfile.NamedTemporaryFile() as fd: + fd.write(chains.encode()) + fd.flush() + metadata = requests.get(metadata_url, verify=fd.name) + metadata.raise_for_status() + return metadata.json() + except Exception as err: + logger.error( + f"failed to retrieve idp metadata from {metadata_url}: {err}" + ) + raise sunbeam_guard.BlockedExceptionError( + f"failed to retrieve idp metadata from {metadata_url}" + ) from err + + @property + def oidc_redirect_uri(self): + """Generate the OIDC redirect URI.""" + return urljoin( + self.charm.public_endpoint.rstrip("/") + "/", + "OS-FEDERATION/protocols/openid/redirect_uri", + ) + + +class OAuthRequiresHandler(_BaseIDPHandler): """Handler for oauth relation.""" def setup_event_handler(self) -> ops.framework.Object: @@ -309,15 +375,12 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler): self.charm.on.oauth_relation_changed, self._oauth_relation_changed, ) - return oauth - @property - def oidc_redirect_uri(self): - """Generate the OIDC redirect URI.""" - return urljoin( - self.charm.public_endpoint.rstrip("/") + "/", - "OS-FEDERATION/protocols/openid/redirect_uri", + self.framework.observe( + oauth.on.oauth_info_removed, + self._oauth_info_removed, ) + return oauth def _oauth_relation_changed(self, event): if not self.charm.unit.is_leader(): @@ -339,6 +402,9 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler): return self.callback_f(event) + def _oauth_info_removed(self, event): + self.callback_f(event) + def get_oidc_providers(self): """Get all OIDC providers.""" data = [] @@ -370,53 +436,15 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler): "name": relation.app.name, "protocol": "openid", "encoded_issuer_url": quote(provider_info.issuer_url, safe=""), - "info": provider_info, + "jwks_endpoint": provider_info.jwks_endpoint, + "issuer_url": provider_info.issuer_url, + "client_id": provider_info.client_id, + "client_secret": provider_info.client_secret, + "ca_chain": provider_info.ca_chain, } info.append(provider) return info - def _get_oidc_metadata( - self, metadata_url, additional_chain: List[str] = [] - ): - try: - chains = self.charm._get_all_ca_bundles(additional_chain) - with tempfile.NamedTemporaryFile() as fd: - fd.write(chains.encode()) - fd.flush() - metadata = requests.get(metadata_url, verify=fd.name) - metadata.raise_for_status() - return metadata.json() - except Exception as err: - logger.error( - f"failed to retrieve idp metadata from {metadata_url}: {err}" - ) - raise sunbeam_guard.BlockedExceptionError( - f"failed to retrieve idp metadata from {metadata_url}" - ) from err - - def _get_idp_file_name_from_issuer_url(self, issuer_url): - """Generate a sanitized file name from the issuer URL. - - The openidc apache - module expects that the provider metadata exist in a folder defined by - OIDCMetadataDir and the file name must be a url encoded string, without - the schema and the trailing slash. - For example, if the issuer URL of the IDP is: - - https://172.16.1.207/iam-hydra - - then the generated file names will be: - - * 172.16.1.207%2Fiam-hydra.client - client ID and client secret - * 172.16.1.207%2Fiam-hydra.provider - the provider metadata file. - - The contents of the .provider file can be fetched from: - - https://172.16.1.207/iam-hydra/.well-known/openid-configuration - """ - sanitized = issuer_url.lstrip("https://").lstrip("http://").rstrip("/") - return quote(sanitized, safe="") - def get_provider_metadata_files(self) -> Mapping[str, str]: """Get OIDC metadata files.""" all_providers = self.get_all_provider_info() @@ -424,20 +452,17 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler): return {} files = {} for provider in all_providers: - provider_info = provider.get("info", None) - if not provider_info: - continue provider_metadata_url = urljoin( - provider_info.issuer_url.rstrip("/") + "/", + provider["issuer_url"].rstrip("/") + "/", ".well-known/openid-configuration", ) metadata = self._get_oidc_metadata( - provider_metadata_url, provider_info.ca_chain or [] + provider_metadata_url, provider["ca_chain"] or [] ) if not metadata: continue base_file_name = self._get_idp_file_name_from_issuer_url( - provider_info.issuer_url + provider["issuer_url"] ) provider_file = f"{base_file_name}.provider" client_file = f"{base_file_name}.client" @@ -445,8 +470,8 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler): files[provider_file] = json.dumps(metadata) files[client_file] = json.dumps( { - "client_id": provider_info.client_id, - "client_secret": provider_info.client_secret, + "client_id": provider["client_id"], + "client_secret": provider["client_secret"], }, ) return files @@ -476,6 +501,148 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler): return False +class ExternalIDPRequiresHandler(_BaseIDPHandler): + """Handler for external-idp relation.""" + + def setup_event_handler(self) -> ops.framework.Object: + """Configure event handlers for external-idp relation.""" + logger.debug("Setting up the external-idp event handler") + + idp = external_idp.ExternalIdpRequirer( + self.charm, + self.relation_name, + ) + self.framework.observe( + idp.on.client_config_changed, + self._set_redirect_uri, + ) + + self.framework.observe( + idp.on.client_config_removed, + self._on_config_changed, + ) + + self.framework.observe( + self.charm.on.external_idp_relation_changed, + self._on_config_changed, + ) + + self.framework.observe( + self.charm.on.external_idp_relation_broken, + self._on_config_changed, + ) + return idp + + def get_oidc_providers(self): + """Get all OIDC providers.""" + providers = self.interface.get_providers() + + data = [] + for provider in providers: + data.append( + { + "name": provider.id, + "protocol": "openid", + "description": provider.label, + } + ) + if not data: + return {} + return {"federated-providers": data} + + def get_all_provider_info(self): + """Get all OIDC provider configs.""" + providers = self.interface.get_providers() + if not providers: + return {} + info = [] + for provider in providers: + if not provider.client_id or not provider.client_secret: + continue + + if not provider.issuer_url: + continue + + provider_metadata_url = urljoin( + provider.issuer_url.rstrip("/") + "/", + ".well-known/openid-configuration", + ) + metadata = self._get_oidc_metadata(provider_metadata_url, []) + if not metadata: + continue + + provider = { + "name": provider.id, + "protocol": "openid", + "encoded_issuer_url": quote(provider.issuer_url, safe=""), + "jwks_endpoint": metadata["jwks_uri"], + "issuer_url": provider.issuer_url, + "client_id": provider.client_id, + "client_secret": provider.client_secret, + "ca_chain": [], + "oidc_metadata": metadata, + } + info.append(provider) + return info + + def get_provider_metadata_files(self) -> Mapping[str, str]: + """Get OIDC metadata files.""" + providers = self.get_all_provider_info() + if not providers: + return {} + files = {} + for provider in providers: + if not provider["oidc_metadata"]: + continue + + base_file_name = self._get_idp_file_name_from_issuer_url( + provider["issuer_url"] + ) + provider_file = f"{base_file_name}.provider" + client_file = f"{base_file_name}.client" + + files[provider_file] = json.dumps(provider["oidc_metadata"]) + files[client_file] = json.dumps( + { + "client_id": provider["client_id"], + "client_secret": provider["client_secret"], + }, + ) + return files + + def _set_redirect_uri(self, event): + self.interface.set_relation_registered_provider( + self.oidc_redirect_uri, + event.provider_id, + event.relation_id, + ) + self.callback_f(event) + + def _on_config_changed(self, event): + self.callback_f(event) + + def ready(self): + """Check if handler is ready.""" + return True + + def context(self): + """Configuration context.""" + providers = self.get_all_provider_info() + if not providers: + return {} + + oidc_secret = self.charm.get_oidc_secret() + if not oidc_secret: + return {} + return { + "oidc_providers": providers, + "oidc_crypto_passphrase": oidc_secret, + "redirect_uri": self.oidc_redirect_uri, + "redirect_uri_path": urlparse(self.oidc_redirect_uri).path, + "public_url_path": urlparse(self.charm.public_endpoint).path, + } + + class TrustedDashboardRequiresHandler(sunbeam_rhandlers.RelationHandler): """Handler for trusted-dashboard relation.""" @@ -573,6 +740,7 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm): SEND_CA_CERT_RELATION_NAME = "send-ca-cert" RECEIVE_CA_CERT_RELATION_NAME = "receive-ca-cert" TRUSTED_DASHBOARD = "trusted-dashboard" + EXTERNAL_IDP = "external-idp" def __init__(self, framework): super().__init__(framework) @@ -622,6 +790,22 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm): self._list_ca_certs_action, ) + def merged_fid_contexts(self): + """Create a merged context from oauth and external_idp.""" + oidc_ctx = self.oauth.context() + external_idp_ctx = self.external_idp.context() + ctx = { + "oidc_providers": [], + } + if oidc_ctx: + ctx.update(oidc_ctx) + if external_idp_ctx: + providers = external_idp_ctx.pop("oidc_providers", []) + if providers: + ctx["oidc_providers"].extend(providers) + ctx.update(external_idp_ctx) + return ctx + def _handle_trusted_dashboard_changed(self, event: RelationChangedEvent): self._handle_update_trusted_dashboard(event) if not self.trusted_dashboard.ready(): @@ -633,11 +817,23 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm): if not self.unit.is_leader(): logger.debug("Not leader, skipping trusted-dashboard info update") return - relation_data = self.oauth.get_oidc_providers() - if not relation_data: + oauth_providers = self.oauth.get_oidc_providers() + external_providers = self.external_idp.get_oidc_providers() + if not oauth_providers and not external_providers: logger.debug("No OAuth relations found, skipping update") return - self.trusted_dashboard.set_requirer_info(relation_data) + data = {"federated-providers": []} + if oauth_providers: + data["federated-providers"].extend( + oauth_providers.get("federated-providers", []) + ) + if external_providers: + data["federated-providers"].extend( + external_providers.get("federated-providers", []) + ) + if not data["federated-providers"]: + return + self.trusted_dashboard.set_requirer_info(data) def _handle_oauth_info_changed(self, event: RelationChangedEvent): """Handle OAuth info changed event.""" @@ -739,11 +935,16 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm): def sync_oidc_providers(self): """Sync OIDC provider metadata to apache2 OIDCMetadataDir.""" - files = self.oauth.get_provider_metadata_files() + files = {} + oauth_files = self.oauth.get_provider_metadata_files() + external_idp_files = self.external_idp.get_provider_metadata_files() + if oauth_files: + files.update(oauth_files) + if external_idp_files: + files.update(external_idp_files) logger.info("Writing oidc metadata files") self.keystone_manager.setup_oidc_metadata_folder() self.keystone_manager.write_oidc_metadata(files) - return True def get_oidc_secret(self): """Get the OIDC secret from the peers relation.""" @@ -1289,6 +1490,14 @@ export OS_AUTH_VERSION=3 self._handle_oauth_info_changed, ) handlers.append(self.oauth) + if self.can_add_handler(self.EXTERNAL_IDP, handlers): + self.external_idp = ExternalIDPRequiresHandler( + self, + self.EXTERNAL_IDP, + self._handle_oauth_info_changed, + ) + handlers.append(self.external_idp) + return super().get_relation_handlers(handlers) @property @@ -1299,6 +1508,7 @@ export OS_AUTH_VERSION=3 [ KeystoneConfigAdapter(self, "ks_config"), KeystoneLoggingAdapter(self, "ks_logging"), + MergedFederatedIdentityConfigContext(self, "fid"), ] ) return contexts diff --git a/charms/keystone-k8s/src/templates/apache2-oidc-params b/charms/keystone-k8s/src/templates/apache2-oidc-params index 509fc3c1..3d5ac921 100644 --- a/charms/keystone-k8s/src/templates/apache2-oidc-params +++ b/charms/keystone-k8s/src/templates/apache2-oidc-params @@ -1,4 +1,4 @@ -{% if oauth and oauth.oidc_providers -%} +{% if fid and fid.oidc_providers %} OIDCClaimPrefix "OIDC-" OIDCClaimDelimiter ";" OIDCResponseType code @@ -6,34 +6,37 @@ OIDCStateInputHeaders none OIDCXForwardedHeaders X-Forwarded-Proto X-Forwarded-Host X-Forwarded-Port OIDCSessionType client-cookie:persistent - OIDCCryptoPassphrase {{ oauth.oidc_crypto_passphrase }} + OIDCCryptoPassphrase {{ fid.oidc_crypto_passphrase }} OIDCMetadataDir /etc/apache2/oidc-metadata - OIDCRedirectURI {{ oauth.redirect_uri }} + OIDCRedirectURI {{ fid.redirect_uri }} - - - AuthType auth-openidc - Require valid-user - - -{% for provider in oauth.oidc_providers -%} - - + AuthType auth-openidc Require valid-user - - OAuth2TokenVerify jwks_uri {{provider.info.jwks_endpoint}} - Require claim iss:{{provider.info.issuer_url}} - SetEnv HTTP_OIDC_ISS {{provider.info.issuer_url}} - - Require valid-user +{% for provider in fid.oidc_providers %} + + AuthType auth-openidc + + Require valid-user + Require oauth2_claim iss:{{provider.issuer_url}} + + + OAuth2TokenVerify jwks_uri {{provider.jwks_endpoint}} + SetEnv HTTP_OIDC_ISS {{provider.issuer_url}} + + + AuthType openid-connect + + Require valid-user + Require claim iss:{{provider.issuer_url}} + - OIDCDiscoverURL {{ oauth.redirect_uri }}?iss={{provider.encoded_issuer_url}} - Require claim iss:{{provider.info.issuer_url}} + OIDCDiscoverURL {{ fid.redirect_uri }}?iss={{provider.encoded_issuer_url}} + OIDCUnAuthAction auth true + OIDCUnAutzAction auth true - -{% endfor -%} -{% endif -%} +{% endfor %} +{% endif %} diff --git a/charms/keystone-k8s/src/templates/section-federation b/charms/keystone-k8s/src/templates/section-federation index 5e9d9111..d8082257 100644 --- a/charms/keystone-k8s/src/templates/section-federation +++ b/charms/keystone-k8s/src/templates/section-federation @@ -1,11 +1,11 @@ -{% if trusted_dashboard and trusted_dashboard.dashboards %} +{% if fid and fid.oidc_providers and trusted_dashboard and trusted_dashboard.dashboards %} [federation] {% for dashboard_url in trusted_dashboard.dashboards -%} trusted_dashboard = {{ dashboard_url }} {% endfor -%} {% endif %} -{% if oauth and oauth.oidc_providers -%} +{% if fid and fid.oidc_providers -%} [openid] remote_id_attribute = HTTP_OIDC_ISS {% endif -%} diff --git a/libs/external/lib/charms/kratos_external_idp_integrator/v0/kratos_external_provider.py b/libs/external/lib/charms/kratos_external_idp_integrator/v0/kratos_external_provider.py new file mode 100644 index 00000000..ca629906 --- /dev/null +++ b/libs/external/lib/charms/kratos_external_idp_integrator/v0/kratos_external_provider.py @@ -0,0 +1,822 @@ +#!/usr/bin/env python3 +# Copyright 2022 Canonical Ltd. +# See LICENSE file for licensing details. + +"""# Interface library for Kratos external OIDC providers. + +This library wraps relation endpoints using the `kratos-external-idp` interface +and provides a Python API for both requesting Kratos to register the client credentials +and for communicating with an external provider. + +## Getting Started + +To get started using the library, you just need to fetch the library using `charmcraft`. + +```shell +cd some-charm +charmcraft fetch-lib charms.kratos_external_idp_integrator.v0.kratos_external_provider +``` + +To use the library from the provider side (KratosExternalIdpIntegrator): + +In the `metadata.yaml` of the charm, add the following: +```yaml +provides: + kratos-external-idp: + interface: external_provider + limit: 1 +``` + +Then, to initialise the library: + +```python +from charms.kratos_external_idp_integrator.v0.kratos_external_provider import ( + ExternalIdpProvider, InvalidConfigError +) +from ops.model import BlockedStatus + +class SomeCharm(CharmBase): + def __init__(self, *args): + # ... + self.external_idp_provider = ExternalIdpProvider(self, self.config) + + self.framework.observe(self.on.config_changed, self._on_config_changed) + self.framework.observe(self.external_idp_provider.on.ready, self._on_ready) + self.framework.observe( + self.external_idp_provider.on.redirect_uri_changed, self._on_redirect_uri_changed + ) + + def _on_config_changed(self, event): + # ... + try: + self.external_idp_provider.validate_provider_config(self.config) + except InvalidConfigError as e: + self.unit.status = BlockedStatus(f"Invalid configuration: {e.args[0]}") + + # ... + + def _on_redirect_uri_changed(self, event): + logger.info(f"The client's redirect_uri changed to {event.redirect_uri}") + self._stored.redirect_uri = event.redirect_uri + self._on_update_status(event) + + def _on_ready(self, event): + if not isinstance(self.unit.status, BlockedStatus): + self.external_idp_provider.create_provider(self.config) +``` + +To use the library from the requirer side (Kratos): + +In the `metadata.yaml` of the charm, add the following: +```yaml +requires: + kratos-external-idp: + interface: external_provider +``` + +Then, to initialise the library: + +```python +from charms.kratos_external_idp_integrator.v0.kratos_external_provider import ( + ExternalIdpRequirer +) + +class KratosCharm(CharmBase): + def __init__(self, *args): + # ... + self.external_idp_requirer = ExternalIdpRequirer(self) + + self.framework.observe( + self.external_idp_provider.on.client_config_changed, self._on_client_config_changed + ) + + def _on_client_config_changed(self, event): + self._configure(event) + + self.external_provider.set_relation_registered_provider( + some_redirect_uri, event.provider_id, event.relation_id + ) +``` +""" + +import base64 +import hashlib +import inspect +import json +import logging +from dataclasses import dataclass +from typing import Dict, List, Mapping, Optional, Type + +import jsonschema +from ops.charm import ( + CharmBase, + RelationChangedEvent, + RelationDepartedEvent, + RelationEvent, + RelationJoinedEvent, +) +from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents +from ops.model import Relation, TooManyRelatedAppsError + +# The unique Charmhub library identifier, never change it +LIBID = "33040051de7f43a8bb43349f2b037dfc" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 11 + +PYDEPS = ["jsonschema"] + +DEFAULT_RELATION_NAME = "kratos-external-idp" +logger = logging.getLogger(__name__) + +PROVIDER_PROVIDERS_JSON_SCHEMA = { + "type": "array", + "items": { + "anyOf": [ + { + "type": "object", + "properties": { + "provider": { + "type": "string", + "enum": [ + "generic", + "google", + "facebook", + "microsoft", + "github", + "apple", + "gitlab", + "auth0", + "slack", + "spotify", + "discord", + "twitch", + "netid", + "yander", + "vk", + "dingtalk", + ], + }, + "client_id": {"type": "string"}, + "client_secret": {"type": "string"}, + "secret_backend": {"type": "string"}, + "issuer_url": {"type": "string"}, + "tenant_id": {"type": "string"}, + "private_key": {"type": "string"}, + "private_key_id": {"type": "string"}, + "scope": {"type": "string"}, + "team_id": {"type": "string"}, + "provider_id": {"type": "string"}, + "label": {"type": "string"}, + "jsonnet_mapper": {"type": "string"}, + }, + "additionalProperties": True, + }, + ], + }, +} + +PROVIDER_JSON_SCHEMA = { + "type": "object", + "properties": { + "providers": PROVIDER_PROVIDERS_JSON_SCHEMA, + }, +} + +REQUIRER_PROVIDERS_JSON_SCHEMA = { + "type": "array", + "items": { + "anyOf": [ + { + "type": "object", + "properties": { + "provider_id": {"type": "string"}, + "redirect_uri": {"type": "string"}, + }, + "additionalProperties": True, + }, + ] + }, +} + +REQUIRER_JSON_SCHEMA = { + "type": "object", + "properties": { + "providers": REQUIRER_PROVIDERS_JSON_SCHEMA, + }, +} + + +class InvalidConfigError(Exception): + """Internal exception that is raised if the charm config is not valid.""" + + +class DataValidationError(RuntimeError): + """Raised when data validation fails on relation data.""" + + +def _load_data(data: Dict, schema: Dict) -> Dict: + """Parses nested fields and checks whether `data` matches `schema`.""" + if "providers" not in data: + return {"providers": []} + + data = dict(data) + try: + data["providers"] = json.loads(data["providers"]) + except json.JSONDecodeError as e: + raise DataValidationError(f"Failed to decode relation json: {e}") + + _validate_data(data, schema) + return data + + +def _dump_data(data: Dict, schema: Dict) -> Dict: + _validate_data(data, schema) + + data = dict(data) + try: + data["providers"] = json.dumps(data["providers"]) + except json.JSONDecodeError as e: + raise DataValidationError(f"Failed to encode relation json: {e}") + return data + + +def _validate_data(data: Dict, schema: Dict) -> None: + """Checks whether `data` matches `schema`. + + Will raise DataValidationError if the data is not valid, else return None. + """ + try: + jsonschema.validate(instance=data, schema=schema) + except jsonschema.ValidationError as e: + raise DataValidationError(data, schema) from e + + +class BaseProviderConfigHandler: + """The base class for parsing a provider's config.""" + + mandatory_fields = {"provider", "client_id", "secret_backend"} + optional_fields = {"provider_id", "jsonnet_mapper", "label", "scope"} + excluded_fields = {"enabled"} + default_scope = "profile email address phone" + providers: List[str] = [] + + @classmethod + def validate_config(cls, config: Mapping) -> Dict: + """Validate and sanitize the user provided config.""" + config_keys = set(config.keys()) + provider = config["provider"] + if provider not in cls.providers: + raise ValueError(f"Invalid provider, allowed providers are: {cls.providers}") + + for key in cls.mandatory_fields: + if not config.get(key, None): + raise InvalidConfigError( + f"Missing required configuration '{key}' for provider '{config['provider']}'" + ) + config_keys.remove(key) + + for key in cls.optional_fields: + config_keys.discard(key) + + if config["secret_backend"] not in ["relation", "secret", "vault"]: + raise InvalidConfigError( + f"Invalid value {config['secret_backend']} for `secret_backend` " + "allowed values are: ['relation', 'secret', 'vault']" + ) + + for key in config_keys: + if key not in cls.excluded_fields: + logger.warn(f"Invalid config '{key}' for provider '{provider}' will be ignored") + + return {key: value for key, value in config.items() if key not in config_keys} + + @classmethod + def handle_config(cls, config: Mapping) -> List: + """Validate the config and transform it in the relation databag expected format.""" + config = cls.validate_config(config) + return cls.parse_config(config) + + @classmethod + def parse_config(cls, config: Dict) -> List: + """Parse the user provided config into the relation databag expected format.""" + ret = { + "client_id": config["client_id"], + "provider": config["provider"], + "secret_backend": config["secret_backend"], + "scope": config.get("scope", cls.default_scope), + } + ret.update({k: config[k] for k in cls.optional_fields if k in config}) + ret.update(cls._parse_provider_config(config)) + return [ret] + + @classmethod + def _parse_provider_config(cls, config: Dict) -> Dict: + """Create the provider specific config.""" + raise NotImplementedError() + + +class GenericConfigHandler(BaseProviderConfigHandler): + """The class for parsing a 'generic' provider's config.""" + + mandatory_fields = BaseProviderConfigHandler.mandatory_fields | {"client_secret", "issuer_url"} + providers = ["generic", "auth0"] + + @classmethod + def _parse_provider_config(cls, config: Dict) -> Dict: + return { + "client_secret": config["client_secret"], + "issuer_url": config["issuer_url"], + } + + +class SocialConfigHandler(BaseProviderConfigHandler): + """The class for parsing a social provider's config.""" + + mandatory_fields = BaseProviderConfigHandler.mandatory_fields | {"client_secret"} + providers = [ + "google", + "facebook", + "gitlab", + "slack", + "spotify", + "discord", + "twitch", + "netid", + "yander", + "vk", + "dingtalk", + ] + + @classmethod + def _parse_provider_config(cls, config: Dict) -> Dict: + return { + "client_secret": config["client_secret"], + } + + +class MicrosoftConfigHandler(SocialConfigHandler): + """The class for parsing a 'microsoft' provider's config.""" + + mandatory_fields = SocialConfigHandler.mandatory_fields | { + "microsoft_tenant_id", + } + providers = ["microsoft"] + + @classmethod + def _parse_provider_config(cls, config: Dict) -> Dict: + return { + "client_secret": config["client_secret"], + "tenant_id": config["microsoft_tenant_id"], + } + + @classmethod + def _parse_relation_data(cls, data: Dict) -> Dict: + return { + "client_secret": data["client_secret"], + "tenant_id": data["tenant_id"], + } + + +class GithubConfigHandler(SocialConfigHandler): + """The class for parsing a 'github' provider's config.""" + + default_scope = "user:email" + providers = ["github"] + + +class AppleConfigHandler(BaseProviderConfigHandler): + """The class for parsing an 'apple' provider's config.""" + + mandatory_fields = BaseProviderConfigHandler.mandatory_fields | { + "apple_team_id", + "apple_private_key_id", + "apple_private_key", + } + _secret_fields = ["private_key"] + providers = ["apple"] + + @classmethod + def _parse_provider_config(cls, config: Dict) -> Dict: + return { + "team_id": config["apple_team_id"], + "private_key_id": config["apple_private_key_id"], + "private_key": config["apple_private_key"], + } + + +_config_handlers = [ + GenericConfigHandler, + SocialConfigHandler, + MicrosoftConfigHandler, + GithubConfigHandler, + AppleConfigHandler, +] +allowed_providers = { + provider: handler for handler in _config_handlers for provider in handler.providers +} + + +def get_provider_config_handler(config: Mapping) -> Type[BaseProviderConfigHandler]: + """Get the config handler for this provider.""" + provider = config.get("provider") + if provider not in allowed_providers: + raise InvalidConfigError( + "Required configuration 'provider' MUST be one of the following: " + + ", ".join(allowed_providers) + ) + return allowed_providers[provider] + + +class RelationReadyEvent(EventBase): + """Event to notify the charm that the relation is ready.""" + + def snapshot(self) -> Dict: + """Save event.""" + return {} + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + pass + + +class RedirectURIChangedEvent(EventBase): + """Event to notify the charm that the redirect_uri changed.""" + + def __init__(self, handle: Handle, redirect_uri: str) -> None: + super().__init__(handle) + self.redirect_uri = redirect_uri + + def snapshot(self) -> Dict: + """Save redirect_uri.""" + return {"redirect_uri": self.redirect_uri} + + def restore(self, snapshot: Dict) -> None: + """Restore redirect_uri.""" + self.redirect_uri = snapshot["redirect_uri"] + + +class ExternalIdpProviderEvents(ObjectEvents): + """Event descriptor for events raised by `ExternalIdpProvider`.""" + + ready = EventSource(RelationReadyEvent) + redirect_uri_changed = EventSource(RedirectURIChangedEvent) + + +class ExternalIdpProvider(Object): + """Forward client configurations to Identity Broker.""" + + on = ExternalIdpProviderEvents() + + def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None: + super().__init__(charm, relation_name) + self._charm = charm + self._relation_name = relation_name + + events = self._charm.on[relation_name] + self.framework.observe(events.relation_joined, self._on_provider_endpoint_relation_joined) + self.framework.observe( + events.relation_changed, self._on_provider_endpoint_relation_changed + ) + self.framework.observe( + events.relation_departed, self._on_provider_endpoint_relation_departed + ) + + def _on_provider_endpoint_relation_joined(self, event: RelationJoinedEvent) -> None: + self.on.ready.emit() + + def _on_provider_endpoint_relation_changed(self, event: RelationChangedEvent) -> None: + if not event.app: + return + data = event.relation.data[event.app] + data = _load_data(data, REQUIRER_JSON_SCHEMA) + providers = data["providers"] + + if len(providers) == 0: + return + redirect_uri = providers[0].get("redirect_uri") + self.on.redirect_uri_changed.emit(redirect_uri=redirect_uri) + + def _on_provider_endpoint_relation_departed(self, event: RelationDepartedEvent) -> None: + self.on.redirect_uri_changed.emit(redirect_uri="") + + def is_ready(self) -> bool: + """Checks if the relation is ready.""" + return self._charm.model.get_relation(self._relation_name) is not None + + def create_provider(self, config: Mapping) -> None: + """Use the configuration to create the relation databag.""" + if not self._charm.unit.is_leader(): + return + + config = self._handle_config(config) + return self._set_provider_data(config) + + def remove_provider(self) -> None: + """Remove the provider config to the relation databag.""" + if not self._charm.unit.is_leader(): + return + + # Do we need to iterate on the relations? There should never be more + # than one + for relation in self._charm.model.relations[self._relation_name]: + relation.data[self._charm.app].clear() + + def get_redirect_uri(self, relation_id: Optional[int] = None) -> Optional[str]: + """Get the kratos client's redirect_uri.""" + if not self.model.unit.is_leader(): + return None + + try: + relation = self.model.get_relation( + relation_name=self._relation_name, relation_id=relation_id + ) + except TooManyRelatedAppsError: + raise RuntimeError("More than one relations are defined. Please provide a relation_id") + + if not relation or not relation.app: + return None + + data = relation.data[relation.app] + data = _load_data(data, REQUIRER_JSON_SCHEMA) + providers = data["providers"] + + if len(providers) == 0: + return None + + return providers[0].get("redirect_uri") + + def validate_provider_config(self, config: Mapping) -> None: + """Validate the provider config. + + Raises InvalidConfigError if config is invalid. + """ + self._validate_config(config) + + def _handle_config(self, config: Mapping) -> List: + handler = get_provider_config_handler(config) + return handler.handle_config(config) + + def _validate_config(self, config: Mapping) -> None: + handler = get_provider_config_handler(config) + handler.validate_config(config) + + def _set_provider_data(self, provider_config: List) -> None: + self._create_secrets(provider_config) + # Do we need to iterate on the relations? There should never be more + # than one + for relation in self._charm.model.relations[self._relation_name]: + relation.data[self._charm.app]["providers"] = json.dumps(provider_config) + + def _create_secrets(self, provider_config: List) -> None: + for conf in provider_config: + backend = conf["secret_backend"] + + if backend == "relation": + pass + elif backend == "secret": + raise NotImplementedError() + elif backend == "vault": + raise NotImplementedError() + else: + raise ValueError(f"Invalid backend: {backend}") + + +@dataclass +class Provider: + """Class for describing an external provider.""" + + client_id: str + provider: str + relation_id: Optional[str] = None + scope: str = "profile email address phone" + label: Optional[str] = None + client_secret: Optional[str] = None + issuer_url: Optional[str] = None + tenant_id: Optional[str] = None + microsoft_tenant: Optional[str] = None + team_id: Optional[str] = None + private_key_id: Optional[str] = None + private_key: Optional[str] = None + jsonnet_mapper: Optional[str] = None + id: Optional[str] = None + + @property + def provider_id(self) -> str: + """Returns a unique ID for the client credentials of the provider.""" + if self.id: + return self.id + + if self.issuer_url: + id = hashlib.sha1(f"{self.client_id}_{self.issuer_url}".encode()).hexdigest() + elif self.get_microsoft_tenant(): + id = hashlib.sha1(f"{self.client_id}_{self.tenant_id}".encode()).hexdigest() + else: + id = hashlib.sha1(self.client_id.encode()).hexdigest() + return f"{self.provider}_{id}" + + @provider_id.setter + def provider_id(self, val) -> None: + self.id = val + + def get_scope(self) -> list: + if isinstance(self.scope, str): + return self.scope.split(" ") + elif isinstance(self.scope, list): + return self.scope + else: + raise ValueError(f"scope must be `list` or `str`, but `{type(self.scope)}` provided") + + def get_microsoft_tenant(self) -> str: + return self.tenant_id or self.microsoft_tenant + + def config(self) -> Dict: + """Generate Kratos config for this provider.""" + ret = { + "id": self.provider_id, + "client_id": self.client_id, + "provider": self.provider, + "label": self.label or self.provider, + "client_secret": self.client_secret, + "issuer_url": self.issuer_url, + "scope": self.get_scope(), + "mapper_url": ( + f"base64://{base64.b64encode(self.jsonnet_mapper.encode()).decode()}" + if self.jsonnet_mapper + else None + ), + "microsoft_tenant": self.get_microsoft_tenant(), + "apple_team_id": self.team_id, + "apple_private_key_id": self.private_key_id, + "apple_private_key": self.private_key, + } + return {k: v for k, v in ret.items() if v} + + @classmethod + def from_dict(cls, dic: Dict) -> "Provider": + """Generate Provider instance from dict.""" + if provider_id := dic.get("provider_id"): + dic["id"] = provider_id + return cls(**{k: v for k, v in dic.items() if k in inspect.signature(cls).parameters}) + + +class ClientConfigChangedEvent(EventBase): + """Event to notify the charm that a provider's client config changed.""" + + def __init__(self, handle: Handle, provider: Provider) -> None: + super().__init__(handle) + self.client_id = provider.client_id + self.provider = provider.provider + self.provider_id = provider.provider_id + self.relation_id = provider.relation_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "client_id": self.client_id, + "provider": self.provider, + "provider_id": self.provider_id, + "relation_id": self.relation_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.client_id = snapshot["client_id"] + self.provider = snapshot["provider"] + self.provider_id = snapshot["provider_id"] + self.relation_id = snapshot["relation_id"] + + +class ClientConfigRemovedEvent(EventBase): + """Event to notify the charm that a provider's client config was removed.""" + + def __init__(self, handle: Handle, relation_id: str) -> None: + super().__init__(handle) + self.relation_id = relation_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "relation_id": self.relation_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.relation_id = snapshot["relation_id"] + + +class ExternalIdpRequirerEvents(ObjectEvents): + """Event descriptor for events raised by `ExternalIdpRequirerEvents`.""" + + client_config_changed = EventSource(ClientConfigChangedEvent) + client_config_removed = EventSource(ClientConfigRemovedEvent) + + +class ExternalIdpRequirer(Object): + """Receive the External Idp configurations for Kratos.""" + + on = ExternalIdpRequirerEvents() + + def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None: + super().__init__(charm, relation_name) + self._charm = charm + self._relation_name = relation_name + + events = self._charm.on[relation_name] + self.framework.observe( + events.relation_changed, self._on_provider_endpoint_relation_changed + ) + self.framework.observe( + events.relation_departed, self._on_provider_endpoint_relation_changed + ) + + def _on_provider_endpoint_relation_changed(self, event: RelationEvent) -> None: + if not event.app: + return + + data = event.relation.data[event.app] + data = _load_data(data, PROVIDER_JSON_SCHEMA) + providers = data["providers"] + + if len(providers) == 0: + self.on.client_config_removed.emit(event.relation.id) + return + + p = self._get_provider(providers[0], event.relation) + self.on.client_config_changed.emit(p) + + def set_relation_registered_provider( + self, redirect_uri: str, provider_id: str, relation_id: int + ) -> None: + """Update the relation databag.""" + if not self._charm.unit.is_leader(): + return + + data = { + "providers": [ + { + "redirect_uri": redirect_uri, + "provider_id": provider_id, + } + ] + } + + data = _dump_data(data, REQUIRER_JSON_SCHEMA) + + relation = self.model.get_relation( + relation_name=self._relation_name, relation_id=relation_id + ) + if not relation: + return + relation.data[self.model.app].update(data) + + def remove_relation_registered_provider(self, relation_id: int) -> None: + """Delete the provider info from the databag.""" + if not self._charm.unit.is_leader(): + return + + relation = self.model.get_relation( + relation_name=self._relation_name, relation_id=relation_id + ) + if not relation: + return + relation.data[self.model.app].clear() + + def get_providers(self) -> List: + """Iterate over the relations and fetch all providers.""" + providers = [] + # For each relation get the client credentials and compile them into a + # single object + for relation in self.model.relations[self._relation_name]: + if not relation.app: + continue + data = relation.data[relation.app] + data = _load_data(data, PROVIDER_JSON_SCHEMA) + for p in data["providers"]: + provider = self._get_provider(p, relation) + providers.append(provider) + + return providers + + def _get_provider(self, provider: Dict, relation: Relation) -> Provider: + provider = self._extract_secrets(provider) + provider["relation_id"] = relation.id + provider = Provider.from_dict(provider) + return provider + + def _extract_secrets(self, data: Dict) -> Dict: + backend = data["secret_backend"] + + if backend == "relation": + pass + elif backend == "secret": + raise NotImplementedError() + elif backend == "vault": + raise NotImplementedError() + else: + raise ValueError(f"Invalid backend: {backend}") + return data