Add the external IDP relation

This change adds the ability to integrate external OIDC IDPs such
as google, okta, auth0, etc with keystone. This is done through the
external_provider interface which already has a provider implementation
via the kratos-external-idp-integrator charm.

Change-Id: Ib651976852f709a8879844027d8f4c4e594048ed
Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira
2025-06-20 22:52:33 +03:00
parent 09d6d4ab55
commit 51b83a30b7
6 changed files with 1128 additions and 89 deletions

View File

@@ -7,6 +7,7 @@ external-libraries:
- charms.tempo_k8s.v2.tracing
- charms.tempo_k8s.v1.charm_tracing
- charms.hydra.v0.oauth
- charms.kratos_external_idp_integrator.v0.kratos_external_provider
internal-libraries:
- charms.horizon_k8s.v0.trusted_dashboard
templates:

View File

@@ -168,6 +168,9 @@ requires:
trusted-dashboard:
interface: trusted-dashboard
optional: true
external-idp:
interface: external_provider
optional: true
provides:
identity-service:

View File

@@ -52,6 +52,7 @@ import charms.keystone_k8s.v0.domain_config as sunbeam_dc_svc
import charms.keystone_k8s.v0.identity_credentials as sunbeam_cc_svc
import charms.keystone_k8s.v0.identity_resource as sunbeam_ops_svc
import charms.keystone_k8s.v1.identity_service as sunbeam_id_svc
import charms.kratos_external_idp_integrator.v0.kratos_external_provider as external_idp
import jinja2
import keystoneauth1.exceptions
import ops
@@ -173,6 +174,18 @@ class KeystoneConfigAdapter(sunbeam_contexts.ConfigContext):
}
@sunbeam_tracing.trace_type
class MergedFederatedIdentityConfigContext(sunbeam_contexts.ConfigContext):
"""Config adapter that merges oauth and external_idp into one context."""
def context(self):
"""Configuration context."""
ctx = self.charm.merged_fid_contexts()
if not ctx or not ctx.get("oidc_providers", []):
return {}
return ctx
@sunbeam_tracing.trace_type
class IdentityServiceProvidesHandler(sunbeam_rhandlers.RelationHandler):
"""Handler for identity service relation."""
@@ -293,7 +306,60 @@ class IdentityResourceProvidesHandler(sunbeam_rhandlers.RelationHandler):
return True
class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler):
class _BaseIDPHandler(sunbeam_rhandlers.RelationHandler):
def _get_idp_file_name_from_issuer_url(self, issuer_url):
"""Generate a sanitized file name from the issuer URL.
The openidc apache
module expects that the provider metadata exist in a folder defined by
OIDCMetadataDir and the file name must be a url encoded string, without
the schema and the trailing slash.
For example, if the issuer URL of the IDP is:
https://172.16.1.207/iam-hydra
then the generated file names will be:
* 172.16.1.207%2Fiam-hydra.client - client ID and client secret
* 172.16.1.207%2Fiam-hydra.provider - the provider metadata file.
The contents of the .provider file can be fetched from:
https://172.16.1.207/iam-hydra/.well-known/openid-configuration
"""
sanitized = issuer_url.lstrip("https://").lstrip("http://").rstrip("/")
return quote(sanitized, safe="")
def _get_oidc_metadata(
self, metadata_url, additional_chain: List[str] = []
):
try:
chains = self.charm._get_all_ca_bundles(additional_chain)
with tempfile.NamedTemporaryFile() as fd:
fd.write(chains.encode())
fd.flush()
metadata = requests.get(metadata_url, verify=fd.name)
metadata.raise_for_status()
return metadata.json()
except Exception as err:
logger.error(
f"failed to retrieve idp metadata from {metadata_url}: {err}"
)
raise sunbeam_guard.BlockedExceptionError(
f"failed to retrieve idp metadata from {metadata_url}"
) from err
@property
def oidc_redirect_uri(self):
"""Generate the OIDC redirect URI."""
return urljoin(
self.charm.public_endpoint.rstrip("/") + "/",
"OS-FEDERATION/protocols/openid/redirect_uri",
)
class OAuthRequiresHandler(_BaseIDPHandler):
"""Handler for oauth relation."""
def setup_event_handler(self) -> ops.framework.Object:
@@ -309,15 +375,12 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler):
self.charm.on.oauth_relation_changed,
self._oauth_relation_changed,
)
return oauth
@property
def oidc_redirect_uri(self):
"""Generate the OIDC redirect URI."""
return urljoin(
self.charm.public_endpoint.rstrip("/") + "/",
"OS-FEDERATION/protocols/openid/redirect_uri",
self.framework.observe(
oauth.on.oauth_info_removed,
self._oauth_info_removed,
)
return oauth
def _oauth_relation_changed(self, event):
if not self.charm.unit.is_leader():
@@ -339,6 +402,9 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler):
return
self.callback_f(event)
def _oauth_info_removed(self, event):
self.callback_f(event)
def get_oidc_providers(self):
"""Get all OIDC providers."""
data = []
@@ -370,53 +436,15 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler):
"name": relation.app.name,
"protocol": "openid",
"encoded_issuer_url": quote(provider_info.issuer_url, safe=""),
"info": provider_info,
"jwks_endpoint": provider_info.jwks_endpoint,
"issuer_url": provider_info.issuer_url,
"client_id": provider_info.client_id,
"client_secret": provider_info.client_secret,
"ca_chain": provider_info.ca_chain,
}
info.append(provider)
return info
def _get_oidc_metadata(
self, metadata_url, additional_chain: List[str] = []
):
try:
chains = self.charm._get_all_ca_bundles(additional_chain)
with tempfile.NamedTemporaryFile() as fd:
fd.write(chains.encode())
fd.flush()
metadata = requests.get(metadata_url, verify=fd.name)
metadata.raise_for_status()
return metadata.json()
except Exception as err:
logger.error(
f"failed to retrieve idp metadata from {metadata_url}: {err}"
)
raise sunbeam_guard.BlockedExceptionError(
f"failed to retrieve idp metadata from {metadata_url}"
) from err
def _get_idp_file_name_from_issuer_url(self, issuer_url):
"""Generate a sanitized file name from the issuer URL.
The openidc apache
module expects that the provider metadata exist in a folder defined by
OIDCMetadataDir and the file name must be a url encoded string, without
the schema and the trailing slash.
For example, if the issuer URL of the IDP is:
https://172.16.1.207/iam-hydra
then the generated file names will be:
* 172.16.1.207%2Fiam-hydra.client - client ID and client secret
* 172.16.1.207%2Fiam-hydra.provider - the provider metadata file.
The contents of the .provider file can be fetched from:
https://172.16.1.207/iam-hydra/.well-known/openid-configuration
"""
sanitized = issuer_url.lstrip("https://").lstrip("http://").rstrip("/")
return quote(sanitized, safe="")
def get_provider_metadata_files(self) -> Mapping[str, str]:
"""Get OIDC metadata files."""
all_providers = self.get_all_provider_info()
@@ -424,20 +452,17 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler):
return {}
files = {}
for provider in all_providers:
provider_info = provider.get("info", None)
if not provider_info:
continue
provider_metadata_url = urljoin(
provider_info.issuer_url.rstrip("/") + "/",
provider["issuer_url"].rstrip("/") + "/",
".well-known/openid-configuration",
)
metadata = self._get_oidc_metadata(
provider_metadata_url, provider_info.ca_chain or []
provider_metadata_url, provider["ca_chain"] or []
)
if not metadata:
continue
base_file_name = self._get_idp_file_name_from_issuer_url(
provider_info.issuer_url
provider["issuer_url"]
)
provider_file = f"{base_file_name}.provider"
client_file = f"{base_file_name}.client"
@@ -445,8 +470,8 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler):
files[provider_file] = json.dumps(metadata)
files[client_file] = json.dumps(
{
"client_id": provider_info.client_id,
"client_secret": provider_info.client_secret,
"client_id": provider["client_id"],
"client_secret": provider["client_secret"],
},
)
return files
@@ -476,6 +501,148 @@ class OAuthRequiresHandler(sunbeam_rhandlers.RelationHandler):
return False
class ExternalIDPRequiresHandler(_BaseIDPHandler):
"""Handler for external-idp relation."""
def setup_event_handler(self) -> ops.framework.Object:
"""Configure event handlers for external-idp relation."""
logger.debug("Setting up the external-idp event handler")
idp = external_idp.ExternalIdpRequirer(
self.charm,
self.relation_name,
)
self.framework.observe(
idp.on.client_config_changed,
self._set_redirect_uri,
)
self.framework.observe(
idp.on.client_config_removed,
self._on_config_changed,
)
self.framework.observe(
self.charm.on.external_idp_relation_changed,
self._on_config_changed,
)
self.framework.observe(
self.charm.on.external_idp_relation_broken,
self._on_config_changed,
)
return idp
def get_oidc_providers(self):
"""Get all OIDC providers."""
providers = self.interface.get_providers()
data = []
for provider in providers:
data.append(
{
"name": provider.id,
"protocol": "openid",
"description": provider.label,
}
)
if not data:
return {}
return {"federated-providers": data}
def get_all_provider_info(self):
"""Get all OIDC provider configs."""
providers = self.interface.get_providers()
if not providers:
return {}
info = []
for provider in providers:
if not provider.client_id or not provider.client_secret:
continue
if not provider.issuer_url:
continue
provider_metadata_url = urljoin(
provider.issuer_url.rstrip("/") + "/",
".well-known/openid-configuration",
)
metadata = self._get_oidc_metadata(provider_metadata_url, [])
if not metadata:
continue
provider = {
"name": provider.id,
"protocol": "openid",
"encoded_issuer_url": quote(provider.issuer_url, safe=""),
"jwks_endpoint": metadata["jwks_uri"],
"issuer_url": provider.issuer_url,
"client_id": provider.client_id,
"client_secret": provider.client_secret,
"ca_chain": [],
"oidc_metadata": metadata,
}
info.append(provider)
return info
def get_provider_metadata_files(self) -> Mapping[str, str]:
"""Get OIDC metadata files."""
providers = self.get_all_provider_info()
if not providers:
return {}
files = {}
for provider in providers:
if not provider["oidc_metadata"]:
continue
base_file_name = self._get_idp_file_name_from_issuer_url(
provider["issuer_url"]
)
provider_file = f"{base_file_name}.provider"
client_file = f"{base_file_name}.client"
files[provider_file] = json.dumps(provider["oidc_metadata"])
files[client_file] = json.dumps(
{
"client_id": provider["client_id"],
"client_secret": provider["client_secret"],
},
)
return files
def _set_redirect_uri(self, event):
self.interface.set_relation_registered_provider(
self.oidc_redirect_uri,
event.provider_id,
event.relation_id,
)
self.callback_f(event)
def _on_config_changed(self, event):
self.callback_f(event)
def ready(self):
"""Check if handler is ready."""
return True
def context(self):
"""Configuration context."""
providers = self.get_all_provider_info()
if not providers:
return {}
oidc_secret = self.charm.get_oidc_secret()
if not oidc_secret:
return {}
return {
"oidc_providers": providers,
"oidc_crypto_passphrase": oidc_secret,
"redirect_uri": self.oidc_redirect_uri,
"redirect_uri_path": urlparse(self.oidc_redirect_uri).path,
"public_url_path": urlparse(self.charm.public_endpoint).path,
}
class TrustedDashboardRequiresHandler(sunbeam_rhandlers.RelationHandler):
"""Handler for trusted-dashboard relation."""
@@ -573,6 +740,7 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm):
SEND_CA_CERT_RELATION_NAME = "send-ca-cert"
RECEIVE_CA_CERT_RELATION_NAME = "receive-ca-cert"
TRUSTED_DASHBOARD = "trusted-dashboard"
EXTERNAL_IDP = "external-idp"
def __init__(self, framework):
super().__init__(framework)
@@ -622,6 +790,22 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm):
self._list_ca_certs_action,
)
def merged_fid_contexts(self):
"""Create a merged context from oauth and external_idp."""
oidc_ctx = self.oauth.context()
external_idp_ctx = self.external_idp.context()
ctx = {
"oidc_providers": [],
}
if oidc_ctx:
ctx.update(oidc_ctx)
if external_idp_ctx:
providers = external_idp_ctx.pop("oidc_providers", [])
if providers:
ctx["oidc_providers"].extend(providers)
ctx.update(external_idp_ctx)
return ctx
def _handle_trusted_dashboard_changed(self, event: RelationChangedEvent):
self._handle_update_trusted_dashboard(event)
if not self.trusted_dashboard.ready():
@@ -633,11 +817,23 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm):
if not self.unit.is_leader():
logger.debug("Not leader, skipping trusted-dashboard info update")
return
relation_data = self.oauth.get_oidc_providers()
if not relation_data:
oauth_providers = self.oauth.get_oidc_providers()
external_providers = self.external_idp.get_oidc_providers()
if not oauth_providers and not external_providers:
logger.debug("No OAuth relations found, skipping update")
return
self.trusted_dashboard.set_requirer_info(relation_data)
data = {"federated-providers": []}
if oauth_providers:
data["federated-providers"].extend(
oauth_providers.get("federated-providers", [])
)
if external_providers:
data["federated-providers"].extend(
external_providers.get("federated-providers", [])
)
if not data["federated-providers"]:
return
self.trusted_dashboard.set_requirer_info(data)
def _handle_oauth_info_changed(self, event: RelationChangedEvent):
"""Handle OAuth info changed event."""
@@ -739,11 +935,16 @@ class KeystoneOperatorCharm(sunbeam_charm.OSBaseOperatorAPICharm):
def sync_oidc_providers(self):
"""Sync OIDC provider metadata to apache2 OIDCMetadataDir."""
files = self.oauth.get_provider_metadata_files()
files = {}
oauth_files = self.oauth.get_provider_metadata_files()
external_idp_files = self.external_idp.get_provider_metadata_files()
if oauth_files:
files.update(oauth_files)
if external_idp_files:
files.update(external_idp_files)
logger.info("Writing oidc metadata files")
self.keystone_manager.setup_oidc_metadata_folder()
self.keystone_manager.write_oidc_metadata(files)
return True
def get_oidc_secret(self):
"""Get the OIDC secret from the peers relation."""
@@ -1289,6 +1490,14 @@ export OS_AUTH_VERSION=3
self._handle_oauth_info_changed,
)
handlers.append(self.oauth)
if self.can_add_handler(self.EXTERNAL_IDP, handlers):
self.external_idp = ExternalIDPRequiresHandler(
self,
self.EXTERNAL_IDP,
self._handle_oauth_info_changed,
)
handlers.append(self.external_idp)
return super().get_relation_handlers(handlers)
@property
@@ -1299,6 +1508,7 @@ export OS_AUTH_VERSION=3
[
KeystoneConfigAdapter(self, "ks_config"),
KeystoneLoggingAdapter(self, "ks_logging"),
MergedFederatedIdentityConfigContext(self, "fid"),
]
)
return contexts

View File

@@ -1,4 +1,4 @@
{% if oauth and oauth.oidc_providers -%}
{% if fid and fid.oidc_providers %}
OIDCClaimPrefix "OIDC-"
OIDCClaimDelimiter ";"
OIDCResponseType code
@@ -6,34 +6,37 @@
OIDCStateInputHeaders none
OIDCXForwardedHeaders X-Forwarded-Proto X-Forwarded-Host X-Forwarded-Port
OIDCSessionType client-cookie:persistent
OIDCCryptoPassphrase {{ oauth.oidc_crypto_passphrase }}
OIDCCryptoPassphrase {{ fid.oidc_crypto_passphrase }}
OIDCMetadataDir /etc/apache2/oidc-metadata
OIDCRedirectURI {{ oauth.redirect_uri }}
OIDCRedirectURI {{ fid.redirect_uri }}
<Location {{ oauth.redirect_uri_path }}>
AuthType auth-openidc
Require valid-user
</Location>
{% for provider in oauth.oidc_providers -%}
<Location {{oauth.public_url_path}}/OS-FEDERATION/identity_providers/{{ provider.name }}/protocols/{{ provider.protocol }}/auth>
<Location {{ fid.redirect_uri_path }}>
AuthType auth-openidc
Require valid-user
OAuth2TokenVerify jwks_uri {{provider.info.jwks_endpoint}}
Require claim iss:{{provider.info.issuer_url}}
SetEnv HTTP_OIDC_ISS {{provider.info.issuer_url}}
</Location>
<Location {{oauth.public_url_path}}/auth/OS-FEDERATION/identity_providers/{{ provider.name }}/protocols/{{ provider.protocol }}/websso>
Require valid-user
{% for provider in fid.oidc_providers %}
<Location {{fid.public_url_path}}/OS-FEDERATION/identity_providers/{{ provider.name }}/protocols/{{ provider.protocol }}/auth>
AuthType auth-openidc
<RequireAll>
Require valid-user
Require oauth2_claim iss:{{provider.issuer_url}}
</RequireAll>
OAuth2TokenVerify jwks_uri {{provider.jwks_endpoint}}
SetEnv HTTP_OIDC_ISS {{provider.issuer_url}}
</Location>
<Location {{fid.public_url_path}}/auth/OS-FEDERATION/identity_providers/{{ provider.name }}/protocols/{{ provider.protocol }}/websso>
AuthType openid-connect
<RequireAll>
Require valid-user
Require claim iss:{{provider.issuer_url}}
</RequireAll>
OIDCDiscoverURL {{ oauth.redirect_uri }}?iss={{provider.encoded_issuer_url}}
Require claim iss:{{provider.info.issuer_url}}
OIDCDiscoverURL {{ fid.redirect_uri }}?iss={{provider.encoded_issuer_url}}
OIDCUnAuthAction auth true
OIDCUnAutzAction auth true
</Location>
{% endfor -%}
{% endif -%}
{% endfor %}
{% endif %}

View File

@@ -1,11 +1,11 @@
{% if trusted_dashboard and trusted_dashboard.dashboards %}
{% if fid and fid.oidc_providers and trusted_dashboard and trusted_dashboard.dashboards %}
[federation]
{% for dashboard_url in trusted_dashboard.dashboards -%}
trusted_dashboard = {{ dashboard_url }}
{% endfor -%}
{% endif %}
{% if oauth and oauth.oidc_providers -%}
{% if fid and fid.oidc_providers -%}
[openid]
remote_id_attribute = HTTP_OIDC_ISS
{% endif -%}

View File

@@ -0,0 +1,822 @@
#!/usr/bin/env python3
# Copyright 2022 Canonical Ltd.
# See LICENSE file for licensing details.
"""# Interface library for Kratos external OIDC providers.
This library wraps relation endpoints using the `kratos-external-idp` interface
and provides a Python API for both requesting Kratos to register the client credentials
and for communicating with an external provider.
## Getting Started
To get started using the library, you just need to fetch the library using `charmcraft`.
```shell
cd some-charm
charmcraft fetch-lib charms.kratos_external_idp_integrator.v0.kratos_external_provider
```
To use the library from the provider side (KratosExternalIdpIntegrator):
In the `metadata.yaml` of the charm, add the following:
```yaml
provides:
kratos-external-idp:
interface: external_provider
limit: 1
```
Then, to initialise the library:
```python
from charms.kratos_external_idp_integrator.v0.kratos_external_provider import (
ExternalIdpProvider, InvalidConfigError
)
from ops.model import BlockedStatus
class SomeCharm(CharmBase):
def __init__(self, *args):
# ...
self.external_idp_provider = ExternalIdpProvider(self, self.config)
self.framework.observe(self.on.config_changed, self._on_config_changed)
self.framework.observe(self.external_idp_provider.on.ready, self._on_ready)
self.framework.observe(
self.external_idp_provider.on.redirect_uri_changed, self._on_redirect_uri_changed
)
def _on_config_changed(self, event):
# ...
try:
self.external_idp_provider.validate_provider_config(self.config)
except InvalidConfigError as e:
self.unit.status = BlockedStatus(f"Invalid configuration: {e.args[0]}")
# ...
def _on_redirect_uri_changed(self, event):
logger.info(f"The client's redirect_uri changed to {event.redirect_uri}")
self._stored.redirect_uri = event.redirect_uri
self._on_update_status(event)
def _on_ready(self, event):
if not isinstance(self.unit.status, BlockedStatus):
self.external_idp_provider.create_provider(self.config)
```
To use the library from the requirer side (Kratos):
In the `metadata.yaml` of the charm, add the following:
```yaml
requires:
kratos-external-idp:
interface: external_provider
```
Then, to initialise the library:
```python
from charms.kratos_external_idp_integrator.v0.kratos_external_provider import (
ExternalIdpRequirer
)
class KratosCharm(CharmBase):
def __init__(self, *args):
# ...
self.external_idp_requirer = ExternalIdpRequirer(self)
self.framework.observe(
self.external_idp_provider.on.client_config_changed, self._on_client_config_changed
)
def _on_client_config_changed(self, event):
self._configure(event)
self.external_provider.set_relation_registered_provider(
some_redirect_uri, event.provider_id, event.relation_id
)
```
"""
import base64
import hashlib
import inspect
import json
import logging
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Type
import jsonschema
from ops.charm import (
CharmBase,
RelationChangedEvent,
RelationDepartedEvent,
RelationEvent,
RelationJoinedEvent,
)
from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents
from ops.model import Relation, TooManyRelatedAppsError
# The unique Charmhub library identifier, never change it
LIBID = "33040051de7f43a8bb43349f2b037dfc"
# Increment this major API version when introducing breaking changes
LIBAPI = 0
# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 11
PYDEPS = ["jsonschema"]
DEFAULT_RELATION_NAME = "kratos-external-idp"
logger = logging.getLogger(__name__)
PROVIDER_PROVIDERS_JSON_SCHEMA = {
"type": "array",
"items": {
"anyOf": [
{
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": [
"generic",
"google",
"facebook",
"microsoft",
"github",
"apple",
"gitlab",
"auth0",
"slack",
"spotify",
"discord",
"twitch",
"netid",
"yander",
"vk",
"dingtalk",
],
},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
"secret_backend": {"type": "string"},
"issuer_url": {"type": "string"},
"tenant_id": {"type": "string"},
"private_key": {"type": "string"},
"private_key_id": {"type": "string"},
"scope": {"type": "string"},
"team_id": {"type": "string"},
"provider_id": {"type": "string"},
"label": {"type": "string"},
"jsonnet_mapper": {"type": "string"},
},
"additionalProperties": True,
},
],
},
}
PROVIDER_JSON_SCHEMA = {
"type": "object",
"properties": {
"providers": PROVIDER_PROVIDERS_JSON_SCHEMA,
},
}
REQUIRER_PROVIDERS_JSON_SCHEMA = {
"type": "array",
"items": {
"anyOf": [
{
"type": "object",
"properties": {
"provider_id": {"type": "string"},
"redirect_uri": {"type": "string"},
},
"additionalProperties": True,
},
]
},
}
REQUIRER_JSON_SCHEMA = {
"type": "object",
"properties": {
"providers": REQUIRER_PROVIDERS_JSON_SCHEMA,
},
}
class InvalidConfigError(Exception):
"""Internal exception that is raised if the charm config is not valid."""
class DataValidationError(RuntimeError):
"""Raised when data validation fails on relation data."""
def _load_data(data: Dict, schema: Dict) -> Dict:
"""Parses nested fields and checks whether `data` matches `schema`."""
if "providers" not in data:
return {"providers": []}
data = dict(data)
try:
data["providers"] = json.loads(data["providers"])
except json.JSONDecodeError as e:
raise DataValidationError(f"Failed to decode relation json: {e}")
_validate_data(data, schema)
return data
def _dump_data(data: Dict, schema: Dict) -> Dict:
_validate_data(data, schema)
data = dict(data)
try:
data["providers"] = json.dumps(data["providers"])
except json.JSONDecodeError as e:
raise DataValidationError(f"Failed to encode relation json: {e}")
return data
def _validate_data(data: Dict, schema: Dict) -> None:
"""Checks whether `data` matches `schema`.
Will raise DataValidationError if the data is not valid, else return None.
"""
try:
jsonschema.validate(instance=data, schema=schema)
except jsonschema.ValidationError as e:
raise DataValidationError(data, schema) from e
class BaseProviderConfigHandler:
"""The base class for parsing a provider's config."""
mandatory_fields = {"provider", "client_id", "secret_backend"}
optional_fields = {"provider_id", "jsonnet_mapper", "label", "scope"}
excluded_fields = {"enabled"}
default_scope = "profile email address phone"
providers: List[str] = []
@classmethod
def validate_config(cls, config: Mapping) -> Dict:
"""Validate and sanitize the user provided config."""
config_keys = set(config.keys())
provider = config["provider"]
if provider not in cls.providers:
raise ValueError(f"Invalid provider, allowed providers are: {cls.providers}")
for key in cls.mandatory_fields:
if not config.get(key, None):
raise InvalidConfigError(
f"Missing required configuration '{key}' for provider '{config['provider']}'"
)
config_keys.remove(key)
for key in cls.optional_fields:
config_keys.discard(key)
if config["secret_backend"] not in ["relation", "secret", "vault"]:
raise InvalidConfigError(
f"Invalid value {config['secret_backend']} for `secret_backend` "
"allowed values are: ['relation', 'secret', 'vault']"
)
for key in config_keys:
if key not in cls.excluded_fields:
logger.warn(f"Invalid config '{key}' for provider '{provider}' will be ignored")
return {key: value for key, value in config.items() if key not in config_keys}
@classmethod
def handle_config(cls, config: Mapping) -> List:
"""Validate the config and transform it in the relation databag expected format."""
config = cls.validate_config(config)
return cls.parse_config(config)
@classmethod
def parse_config(cls, config: Dict) -> List:
"""Parse the user provided config into the relation databag expected format."""
ret = {
"client_id": config["client_id"],
"provider": config["provider"],
"secret_backend": config["secret_backend"],
"scope": config.get("scope", cls.default_scope),
}
ret.update({k: config[k] for k in cls.optional_fields if k in config})
ret.update(cls._parse_provider_config(config))
return [ret]
@classmethod
def _parse_provider_config(cls, config: Dict) -> Dict:
"""Create the provider specific config."""
raise NotImplementedError()
class GenericConfigHandler(BaseProviderConfigHandler):
"""The class for parsing a 'generic' provider's config."""
mandatory_fields = BaseProviderConfigHandler.mandatory_fields | {"client_secret", "issuer_url"}
providers = ["generic", "auth0"]
@classmethod
def _parse_provider_config(cls, config: Dict) -> Dict:
return {
"client_secret": config["client_secret"],
"issuer_url": config["issuer_url"],
}
class SocialConfigHandler(BaseProviderConfigHandler):
"""The class for parsing a social provider's config."""
mandatory_fields = BaseProviderConfigHandler.mandatory_fields | {"client_secret"}
providers = [
"google",
"facebook",
"gitlab",
"slack",
"spotify",
"discord",
"twitch",
"netid",
"yander",
"vk",
"dingtalk",
]
@classmethod
def _parse_provider_config(cls, config: Dict) -> Dict:
return {
"client_secret": config["client_secret"],
}
class MicrosoftConfigHandler(SocialConfigHandler):
"""The class for parsing a 'microsoft' provider's config."""
mandatory_fields = SocialConfigHandler.mandatory_fields | {
"microsoft_tenant_id",
}
providers = ["microsoft"]
@classmethod
def _parse_provider_config(cls, config: Dict) -> Dict:
return {
"client_secret": config["client_secret"],
"tenant_id": config["microsoft_tenant_id"],
}
@classmethod
def _parse_relation_data(cls, data: Dict) -> Dict:
return {
"client_secret": data["client_secret"],
"tenant_id": data["tenant_id"],
}
class GithubConfigHandler(SocialConfigHandler):
"""The class for parsing a 'github' provider's config."""
default_scope = "user:email"
providers = ["github"]
class AppleConfigHandler(BaseProviderConfigHandler):
"""The class for parsing an 'apple' provider's config."""
mandatory_fields = BaseProviderConfigHandler.mandatory_fields | {
"apple_team_id",
"apple_private_key_id",
"apple_private_key",
}
_secret_fields = ["private_key"]
providers = ["apple"]
@classmethod
def _parse_provider_config(cls, config: Dict) -> Dict:
return {
"team_id": config["apple_team_id"],
"private_key_id": config["apple_private_key_id"],
"private_key": config["apple_private_key"],
}
_config_handlers = [
GenericConfigHandler,
SocialConfigHandler,
MicrosoftConfigHandler,
GithubConfigHandler,
AppleConfigHandler,
]
allowed_providers = {
provider: handler for handler in _config_handlers for provider in handler.providers
}
def get_provider_config_handler(config: Mapping) -> Type[BaseProviderConfigHandler]:
"""Get the config handler for this provider."""
provider = config.get("provider")
if provider not in allowed_providers:
raise InvalidConfigError(
"Required configuration 'provider' MUST be one of the following: "
+ ", ".join(allowed_providers)
)
return allowed_providers[provider]
class RelationReadyEvent(EventBase):
"""Event to notify the charm that the relation is ready."""
def snapshot(self) -> Dict:
"""Save event."""
return {}
def restore(self, snapshot: Dict) -> None:
"""Restore event."""
pass
class RedirectURIChangedEvent(EventBase):
"""Event to notify the charm that the redirect_uri changed."""
def __init__(self, handle: Handle, redirect_uri: str) -> None:
super().__init__(handle)
self.redirect_uri = redirect_uri
def snapshot(self) -> Dict:
"""Save redirect_uri."""
return {"redirect_uri": self.redirect_uri}
def restore(self, snapshot: Dict) -> None:
"""Restore redirect_uri."""
self.redirect_uri = snapshot["redirect_uri"]
class ExternalIdpProviderEvents(ObjectEvents):
"""Event descriptor for events raised by `ExternalIdpProvider`."""
ready = EventSource(RelationReadyEvent)
redirect_uri_changed = EventSource(RedirectURIChangedEvent)
class ExternalIdpProvider(Object):
"""Forward client configurations to Identity Broker."""
on = ExternalIdpProviderEvents()
def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None:
super().__init__(charm, relation_name)
self._charm = charm
self._relation_name = relation_name
events = self._charm.on[relation_name]
self.framework.observe(events.relation_joined, self._on_provider_endpoint_relation_joined)
self.framework.observe(
events.relation_changed, self._on_provider_endpoint_relation_changed
)
self.framework.observe(
events.relation_departed, self._on_provider_endpoint_relation_departed
)
def _on_provider_endpoint_relation_joined(self, event: RelationJoinedEvent) -> None:
self.on.ready.emit()
def _on_provider_endpoint_relation_changed(self, event: RelationChangedEvent) -> None:
if not event.app:
return
data = event.relation.data[event.app]
data = _load_data(data, REQUIRER_JSON_SCHEMA)
providers = data["providers"]
if len(providers) == 0:
return
redirect_uri = providers[0].get("redirect_uri")
self.on.redirect_uri_changed.emit(redirect_uri=redirect_uri)
def _on_provider_endpoint_relation_departed(self, event: RelationDepartedEvent) -> None:
self.on.redirect_uri_changed.emit(redirect_uri="")
def is_ready(self) -> bool:
"""Checks if the relation is ready."""
return self._charm.model.get_relation(self._relation_name) is not None
def create_provider(self, config: Mapping) -> None:
"""Use the configuration to create the relation databag."""
if not self._charm.unit.is_leader():
return
config = self._handle_config(config)
return self._set_provider_data(config)
def remove_provider(self) -> None:
"""Remove the provider config to the relation databag."""
if not self._charm.unit.is_leader():
return
# Do we need to iterate on the relations? There should never be more
# than one
for relation in self._charm.model.relations[self._relation_name]:
relation.data[self._charm.app].clear()
def get_redirect_uri(self, relation_id: Optional[int] = None) -> Optional[str]:
"""Get the kratos client's redirect_uri."""
if not self.model.unit.is_leader():
return None
try:
relation = self.model.get_relation(
relation_name=self._relation_name, relation_id=relation_id
)
except TooManyRelatedAppsError:
raise RuntimeError("More than one relations are defined. Please provide a relation_id")
if not relation or not relation.app:
return None
data = relation.data[relation.app]
data = _load_data(data, REQUIRER_JSON_SCHEMA)
providers = data["providers"]
if len(providers) == 0:
return None
return providers[0].get("redirect_uri")
def validate_provider_config(self, config: Mapping) -> None:
"""Validate the provider config.
Raises InvalidConfigError if config is invalid.
"""
self._validate_config(config)
def _handle_config(self, config: Mapping) -> List:
handler = get_provider_config_handler(config)
return handler.handle_config(config)
def _validate_config(self, config: Mapping) -> None:
handler = get_provider_config_handler(config)
handler.validate_config(config)
def _set_provider_data(self, provider_config: List) -> None:
self._create_secrets(provider_config)
# Do we need to iterate on the relations? There should never be more
# than one
for relation in self._charm.model.relations[self._relation_name]:
relation.data[self._charm.app]["providers"] = json.dumps(provider_config)
def _create_secrets(self, provider_config: List) -> None:
for conf in provider_config:
backend = conf["secret_backend"]
if backend == "relation":
pass
elif backend == "secret":
raise NotImplementedError()
elif backend == "vault":
raise NotImplementedError()
else:
raise ValueError(f"Invalid backend: {backend}")
@dataclass
class Provider:
"""Class for describing an external provider."""
client_id: str
provider: str
relation_id: Optional[str] = None
scope: str = "profile email address phone"
label: Optional[str] = None
client_secret: Optional[str] = None
issuer_url: Optional[str] = None
tenant_id: Optional[str] = None
microsoft_tenant: Optional[str] = None
team_id: Optional[str] = None
private_key_id: Optional[str] = None
private_key: Optional[str] = None
jsonnet_mapper: Optional[str] = None
id: Optional[str] = None
@property
def provider_id(self) -> str:
"""Returns a unique ID for the client credentials of the provider."""
if self.id:
return self.id
if self.issuer_url:
id = hashlib.sha1(f"{self.client_id}_{self.issuer_url}".encode()).hexdigest()
elif self.get_microsoft_tenant():
id = hashlib.sha1(f"{self.client_id}_{self.tenant_id}".encode()).hexdigest()
else:
id = hashlib.sha1(self.client_id.encode()).hexdigest()
return f"{self.provider}_{id}"
@provider_id.setter
def provider_id(self, val) -> None:
self.id = val
def get_scope(self) -> list:
if isinstance(self.scope, str):
return self.scope.split(" ")
elif isinstance(self.scope, list):
return self.scope
else:
raise ValueError(f"scope must be `list` or `str`, but `{type(self.scope)}` provided")
def get_microsoft_tenant(self) -> str:
return self.tenant_id or self.microsoft_tenant
def config(self) -> Dict:
"""Generate Kratos config for this provider."""
ret = {
"id": self.provider_id,
"client_id": self.client_id,
"provider": self.provider,
"label": self.label or self.provider,
"client_secret": self.client_secret,
"issuer_url": self.issuer_url,
"scope": self.get_scope(),
"mapper_url": (
f"base64://{base64.b64encode(self.jsonnet_mapper.encode()).decode()}"
if self.jsonnet_mapper
else None
),
"microsoft_tenant": self.get_microsoft_tenant(),
"apple_team_id": self.team_id,
"apple_private_key_id": self.private_key_id,
"apple_private_key": self.private_key,
}
return {k: v for k, v in ret.items() if v}
@classmethod
def from_dict(cls, dic: Dict) -> "Provider":
"""Generate Provider instance from dict."""
if provider_id := dic.get("provider_id"):
dic["id"] = provider_id
return cls(**{k: v for k, v in dic.items() if k in inspect.signature(cls).parameters})
class ClientConfigChangedEvent(EventBase):
"""Event to notify the charm that a provider's client config changed."""
def __init__(self, handle: Handle, provider: Provider) -> None:
super().__init__(handle)
self.client_id = provider.client_id
self.provider = provider.provider
self.provider_id = provider.provider_id
self.relation_id = provider.relation_id
def snapshot(self) -> Dict:
"""Save event."""
return {
"client_id": self.client_id,
"provider": self.provider,
"provider_id": self.provider_id,
"relation_id": self.relation_id,
}
def restore(self, snapshot: Dict) -> None:
"""Restore event."""
self.client_id = snapshot["client_id"]
self.provider = snapshot["provider"]
self.provider_id = snapshot["provider_id"]
self.relation_id = snapshot["relation_id"]
class ClientConfigRemovedEvent(EventBase):
"""Event to notify the charm that a provider's client config was removed."""
def __init__(self, handle: Handle, relation_id: str) -> None:
super().__init__(handle)
self.relation_id = relation_id
def snapshot(self) -> Dict:
"""Save event."""
return {
"relation_id": self.relation_id,
}
def restore(self, snapshot: Dict) -> None:
"""Restore event."""
self.relation_id = snapshot["relation_id"]
class ExternalIdpRequirerEvents(ObjectEvents):
"""Event descriptor for events raised by `ExternalIdpRequirerEvents`."""
client_config_changed = EventSource(ClientConfigChangedEvent)
client_config_removed = EventSource(ClientConfigRemovedEvent)
class ExternalIdpRequirer(Object):
"""Receive the External Idp configurations for Kratos."""
on = ExternalIdpRequirerEvents()
def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None:
super().__init__(charm, relation_name)
self._charm = charm
self._relation_name = relation_name
events = self._charm.on[relation_name]
self.framework.observe(
events.relation_changed, self._on_provider_endpoint_relation_changed
)
self.framework.observe(
events.relation_departed, self._on_provider_endpoint_relation_changed
)
def _on_provider_endpoint_relation_changed(self, event: RelationEvent) -> None:
if not event.app:
return
data = event.relation.data[event.app]
data = _load_data(data, PROVIDER_JSON_SCHEMA)
providers = data["providers"]
if len(providers) == 0:
self.on.client_config_removed.emit(event.relation.id)
return
p = self._get_provider(providers[0], event.relation)
self.on.client_config_changed.emit(p)
def set_relation_registered_provider(
self, redirect_uri: str, provider_id: str, relation_id: int
) -> None:
"""Update the relation databag."""
if not self._charm.unit.is_leader():
return
data = {
"providers": [
{
"redirect_uri": redirect_uri,
"provider_id": provider_id,
}
]
}
data = _dump_data(data, REQUIRER_JSON_SCHEMA)
relation = self.model.get_relation(
relation_name=self._relation_name, relation_id=relation_id
)
if not relation:
return
relation.data[self.model.app].update(data)
def remove_relation_registered_provider(self, relation_id: int) -> None:
"""Delete the provider info from the databag."""
if not self._charm.unit.is_leader():
return
relation = self.model.get_relation(
relation_name=self._relation_name, relation_id=relation_id
)
if not relation:
return
relation.data[self.model.app].clear()
def get_providers(self) -> List:
"""Iterate over the relations and fetch all providers."""
providers = []
# For each relation get the client credentials and compile them into a
# single object
for relation in self.model.relations[self._relation_name]:
if not relation.app:
continue
data = relation.data[relation.app]
data = _load_data(data, PROVIDER_JSON_SCHEMA)
for p in data["providers"]:
provider = self._get_provider(p, relation)
providers.append(provider)
return providers
def _get_provider(self, provider: Dict, relation: Relation) -> Provider:
provider = self._extract_secrets(provider)
provider["relation_id"] = relation.id
provider = Provider.from_dict(provider)
return provider
def _extract_secrets(self, data: Dict) -> Dict:
backend = data["secret_backend"]
if backend == "relation":
pass
elif backend == "secret":
raise NotImplementedError()
elif backend == "vault":
raise NotImplementedError()
else:
raise ValueError(f"Invalid backend: {backend}")
return data