diff --git a/mistralclient/api/httpclient.py b/mistralclient/api/httpclient.py index 9492c98a..00fc3964 100644 --- a/mistralclient/api/httpclient.py +++ b/mistralclient/api/httpclient.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import copy import os @@ -21,8 +22,21 @@ import requests import logging -osprofiler_web = importutils.try_import("osprofiler.web") +AUTH_TOKEN = 'auth_token' +CACERT = 'cacert' +CERT_FILE = 'cert' +CERT_KEY = 'key' +INSECURE = 'insecure' +PROJECT_ID = 'project_id' +TARGET_AUTH_TOKEN = 'target_auth_token' +TARGET_AUTH_URI = 'target_auth_url' +TARGET_PROJECT_ID = 'target_project_id' +TARGET_USER_ID = 'target_user_id' +TARGET_SERVICE_CATALOG = 'target_service_catalog' +USER_ID = 'user_id' + +osprofiler_web = importutils.try_import("osprofiler.web") LOG = logging.getLogger(__name__) @@ -39,13 +53,16 @@ def log_request(func): class HTTPClient(object): def __init__(self, base_url, **kwargs): self.base_url = base_url - self.auth_token = kwargs.get('auth_token', None) - self.project_id = kwargs.get('project_id', None) - self.user_id = kwargs.get('user_id', None) - self.target_auth_token = kwargs.get('target_auth_token', None) - self.target_auth_url = kwargs.get('target_auth_url', None) - self.cacert = kwargs.get('cacert', None) - self.insecure = kwargs.get('insecure', False) + self.auth_token = kwargs.get(AUTH_TOKEN) + self.project_id = kwargs.get(PROJECT_ID) + self.user_id = kwargs.get(USER_ID) + self.target_auth_token = kwargs.get(TARGET_AUTH_TOKEN) + self.target_auth_uri = kwargs.get(TARGET_AUTH_URI) + self.target_user_id = kwargs.get(TARGET_USER_ID) + self.target_project_id = kwargs.get(TARGET_PROJECT_ID) + self.target_service_catalog = kwargs.get(TARGET_SERVICE_CATALOG) + self.cacert = kwargs.get(CACERT) + self.insecure = kwargs.get(INSECURE, False) self.ssl_options = {} if self.base_url.startswith('https'): @@ -65,7 +82,10 @@ class HTTPClient(object): else: self.ssl_options['verify'] = True - self.ssl_options['cert'] = (kwargs.get('cert'), kwargs.get('key')) + self.ssl_options['cert'] = ( + kwargs.get(CERT_FILE), + kwargs.get(CERT_KEY) + ) @log_request def get(self, url, headers=None): @@ -107,30 +127,31 @@ class HTTPClient(object): if not headers: headers = {} - auth_token = headers.get('x-auth-token', self.auth_token) - if auth_token: - headers['x-auth-token'] = auth_token + if self.auth_token: + headers['x-auth-token'] = self.auth_token - project_id = headers.get('X-Project-Id', self.project_id) - if project_id: - headers['X-Project-Id'] = project_id + if self.project_id: + headers['X-Project-Id'] = self.project_id - user_id = headers.get('X-User-Id', self.user_id) - if user_id: - headers['X-User-Id'] = user_id + if self.user_id: + headers['X-User-Id'] = self.user_id - target_auth_token = headers.get( - 'X-Target-Auth-Token', - self.target_auth_token - ) + if self.target_auth_token: + headers['X-Target-Auth-Token'] = self.target_auth_token - if target_auth_token: - headers['X-Target-Auth-Token'] = target_auth_token + if self.target_auth_uri: + headers['X-Target-Auth-Uri'] = self.target_auth_uri - target_auth_url = headers.get('X-Target-Auth-Uri', - self.target_auth_url) - if target_auth_url: - headers['X-Target-Auth-Uri'] = target_auth_url + if self.target_project_id: + headers['X-Target-Project-Id'] = self.target_project_id + + if self.target_user_id: + headers['X-Target-User-Id'] = self.target_user_id + + if self.target_service_catalog: + headers['X-Target-Service-Catalog'] = base64.b64encode( + self.target_service_catalog.encode('utf-8') + ) if osprofiler_web: # Add headers for osprofiler. diff --git a/mistralclient/api/v2/client.py b/mistralclient/api/v2/client.py index 062bf878..67589097 100644 --- a/mistralclient/api/v2/client.py +++ b/mistralclient/api/v2/client.py @@ -41,22 +41,17 @@ class Client(object): def __init__(self, auth_type='keystone', **kwargs): req = copy.deepcopy(kwargs) mistral_url = req.get('mistral_url') - auth_url = req.get('auth_url') - auth_token = req.get('auth_token') - project_id = req.get('project_id') - user_id = req.get('user_id') profile = req.get('profile') if mistral_url and not isinstance(mistral_url, six.string_types): raise RuntimeError('Mistral url should be a string.') - if auth_url and not auth_token: - auth_handler = auth.get_auth_handler(auth_type) - auth_response = auth_handler.authenticate(req) or {} - mistral_url = auth_response.get('mistral_url') or mistral_url - req['auth_token'] = auth_response.get('token') - req['project_id'] = auth_response.get('project_id') or project_id - req['user_id'] = auth_response.get('user_id') or user_id + auth_handler = auth.get_auth_handler(auth_type) + auth_response = auth_handler.authenticate(req) or {} + + req.update(auth_response) + + mistral_url = auth_response.get('mistral_url') or mistral_url if not mistral_url: mistral_url = _DEFAULT_MISTRAL_URL diff --git a/mistralclient/auth/keystone.py b/mistralclient/auth/keystone.py index 6c8bb968..313d7c4c 100644 --- a/mistralclient/auth/keystone.py +++ b/mistralclient/auth/keystone.py @@ -13,6 +13,9 @@ # limitations under the License. from mistralclient import auth +from oslo_serialization import jsonutils + +import mistralclient.api.httpclient as api def _get_keystone_client(auth_url): @@ -68,6 +71,8 @@ class KeystoneAuthHandler(auth.AuthHandler): 'Only user name or user id should be set' ) + auth_response = {} + if auth_url: keystone_client = _get_keystone_client(auth_url) @@ -85,9 +90,23 @@ class KeystoneAuthHandler(auth.AuthHandler): ) keystone.authenticate() - auth_token = keystone.auth_token - user_id = keystone.user_id - project_id = keystone.project_id + + auth_response.update({ + api.AUTH_TOKEN: keystone.auth_token, + api.PROJECT_ID: keystone.project_id, + api.USER_ID: keystone.user_id, + }) + + if not mistral_url: + try: + mistral_url = keystone.service_catalog.url_for( + service_type=service_type, + endpoint_type=endpoint_type + ) + except Exception: + mistral_url = None + + auth_response['mistral_url'] = mistral_url if target_auth_url: target_keystone_client = _get_keystone_client(target_auth_url) @@ -107,20 +126,14 @@ class KeystoneAuthHandler(auth.AuthHandler): target_keystone.authenticate() - if not mistral_url: - try: - mistral_url = keystone.service_catalog.url_for( - service_type=service_type, - endpoint_type=endpoint_type + auth_response.update({ + api.TARGET_AUTH_TOKEN: target_keystone.auth_token, + api.TARGET_PROJECT_ID: target_keystone.project_id, + api.TARGET_USER_ID: target_keystone.user_id, + api.TARGET_AUTH_URI: target_auth_url, + api.TARGET_SERVICE_CATALOG: jsonutils.dumps( + target_keystone.auth_ref ) - except Exception: - mistral_url = None + }) - return { - 'mistral_url': mistral_url, - 'token': auth_token, - 'project_id': target_project_id if target_auth_url else project_id, - 'user_id': target_user_id if target_auth_url else user_id, - 'target_auth_token': target_auth_token, - 'target_auth_url': target_auth_url - } + return auth_response diff --git a/mistralclient/tests/unit/test_client.py b/mistralclient/tests/unit/test_client.py index a9ea5c64..6ee4b497 100644 --- a/mistralclient/tests/unit/test_client.py +++ b/mistralclient/tests/unit/test_client.py @@ -116,6 +116,44 @@ class BaseClientTests(base.BaseTestCase): keystone_client_instance.user_id, kwargs['user_id'] ) + @mock.patch('keystoneclient.v3.client.Client') + @mock.patch('mistralclient.api.httpclient.HTTPClient') + def test_target_parameters_processed( + self, + http_client_mock, + keystone_client_mock + ): + keystone_client_instance = self.setup_keystone_mock( + keystone_client_mock + ) + + url_for = mock.Mock(return_value='http://mistral_host:8989/v2') + keystone_client_instance.service_catalog.url_for = url_for + + client.client( + target_username='tmistral', + target_project_name='tmistralp', + target_auth_url=AUTH_HTTP_URL_v3 + ) + + self.assertTrue(http_client_mock.called) + mistral_url_for_http = http_client_mock.call_args[0][0] + kwargs = http_client_mock.call_args[1] + self.assertEqual(MISTRAL_HTTP_URL, mistral_url_for_http) + + expected_values = { + 'target_project_id': keystone_client_instance.project_id, + 'target_auth_token': keystone_client_instance.auth_token, + 'target_user_id': keystone_client_instance.user_id, + 'target_auth_url': AUTH_HTTP_URL_v3, + 'target_project_name': 'tmistralp', + 'target_username': 'tmistral', + 'target_service_catalog': '"{}"' + } + + for key in expected_values: + self.assertEqual(expected_values[key], kwargs[key]) + @mock.patch('keystoneclient.v3.client.Client') @mock.patch('mistralclient.api.httpclient.HTTPClient') def test_mistral_url_https_insecure(self, http_client_mock, diff --git a/mistralclient/tests/unit/test_httpclient.py b/mistralclient/tests/unit/test_httpclient.py index bbd94cc5..5cc1f7bc 100644 --- a/mistralclient/tests/unit/test_httpclient.py +++ b/mistralclient/tests/unit/test_httpclient.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import copy import uuid @@ -135,6 +136,9 @@ class HTTPClientTest(base.BaseTestCase): def test_get_request_options_with_headers_for_get(self): target_auth_url = str(uuid.uuid4()) target_auth_token = str(uuid.uuid4()) + target_user_id = 'target_user' + target_project_id = 'target_project' + target_service_catalog = 'this should be there' target_client = httpclient.HTTPClient( API_BASE_URL, @@ -142,7 +146,10 @@ class HTTPClientTest(base.BaseTestCase): project_id=PROJECT_ID, user_id=USER_ID, target_auth_url=target_auth_url, - target_auth_token=target_auth_token + target_auth_token=target_auth_token, + target_project_id=target_project_id, + target_user_id=target_user_id, + target_service_catalog=target_service_catalog ) target_client.get(API_URL) @@ -150,6 +157,10 @@ class HTTPClientTest(base.BaseTestCase): expected_options = copy.deepcopy(EXPECTED_REQ_OPTIONS) expected_options["headers"]["X-Target-Auth-Uri"] = target_auth_url expected_options["headers"]["X-Target-Auth-Token"] = target_auth_token + expected_options["headers"]["X-Target-User-Id"] = target_user_id + expected_options["headers"]["X-Target-Project-Id"] = target_project_id + catalog = base64.b64encode(target_service_catalog.encode('utf-8')) + expected_options["headers"]["X-Target-Service-Catalog"] = catalog requests.get.assert_called_with( EXPECTED_URL,