diff --git a/sap/adt/core.py b/sap/adt/core.py index f5b8bb29..502345c8 100644 --- a/sap/adt/core.py +++ b/sap/adt/core.py @@ -101,7 +101,8 @@ class Connection: """ # pylint: disable=too-many-arguments - def __init__(self, host, client, user, password, port=None, ssl=True, verify=True, ssl_server_cert=None): + def __init__(self, host, client, user, password, port=None, ssl=True, verify=True, ssl_server_cert=None, + token_url=None, client_id=None, client_secret=None): """Parameters: - host: string host name - client: string SAP client @@ -124,10 +125,13 @@ def __init__(self, host, client, user, password, port=None, ssl=True, verify=Tru port=port, user=user, password=password, - saml2=False, + saml2=None if (token_url and client_id and client_secret) else False, client=client, verify=verify, ssl_server_cert=ssl_server_cert, + token_url=token_url, + client_id=client_id, + client_secret=client_secret, # This must be the default login path because newer ABAP systems # did not return cookies and CSRF token with the old default login # path (GET /sap/bc/adt/discovery) and thus did not work with diff --git a/sap/cli/__init__.py b/sap/cli/__init__.py index 2f0e8106..ef21e733 100644 --- a/sap/cli/__init__.py +++ b/sap/cli/__init__.py @@ -129,7 +129,10 @@ def adt_connection_from_args(args): return sap.adt.Connection( args.ashost, args.client, args.user, args.password, port=args.port, ssl=args.ssl, verify=args.verify, - ssl_server_cert=args.ssl_server_cert) + ssl_server_cert=args.ssl_server_cert, + token_url=getattr(args, 'token_url', None), + client_id=getattr(args, 'client_id', None), + client_secret=getattr(args, 'client_secret', None)) def rfc_connection_from_args(args): @@ -202,6 +205,9 @@ def build_empty_connection_values(): ssl_server_cert=None, user=None, password=None, + token_url=None, + client_id=None, + client_secret=None, ) @@ -285,6 +291,15 @@ def resolve_default_connection_values(args): if not args.password: args.password = os.getenv('SAP_PASSWORD') or config_values.get('password') + if not getattr(args, 'token_url', None): + args.token_url = os.getenv('SAP_TOKEN_URL') or config_values.get('token_url') + + if not getattr(args, 'client_id', None): + args.client_id = os.getenv('SAP_CLIENT_ID') or config_values.get('client_id') + + if not getattr(args, 'client_secret', None): + args.client_secret = os.getenv('SAP_CLIENT_SECRET') or config_values.get('client_secret') + if hasattr(args, 'corrnr') and args.corrnr is None: args.corrnr = os.getenv('SAP_CORRNR') diff --git a/sap/cli/_entry.py b/sap/cli/_entry.py index 5958b76c..57e94a37 100644 --- a/sap/cli/_entry.py +++ b/sap/cli/_entry.py @@ -16,6 +16,7 @@ import sap.rfc from sap.config import ConfigFile from sap.http import TimedOutRequestError as HttpTimedOutRequestError +from sap.http.oauth import get_cached_token, get_cached_refresh_token from sap.odata.errors import TimedOutRequestError as ODataTimedOutRequestError # pylint: disable=invalid-name @@ -157,7 +158,15 @@ def parse_command_line(argv): if not args.user: args.user = input('Login:') - if not args.password: + token_url = getattr(args, 'token_url', None) + client_id = getattr(args, 'client_id', None) + has_valid_token = ( + token_url and client_id and ( + get_cached_token(token_url, client_id) or + get_cached_refresh_token(token_url, client_id) + ) + ) + if not args.password and not has_valid_token: args.password = getpass.getpass() return args diff --git a/sap/config.py b/sap/config.py index 4943f7d7..0b8c7658 100644 --- a/sap/config.py +++ b/sap/config.py @@ -23,6 +23,7 @@ class SAPCliConfigError(SAPCliError): 'ashost', 'sysnr', 'client', 'port', 'ssl', 'ssl_verify', 'ssl_server_cert', 'mshost', 'msserv', 'sysid', 'group', 'snc_qop', 'snc_myname', 'snc_partnername', 'snc_lib', + 'token_url', 'client_id', 'client_secret', ) USER_FIELDS = ( diff --git a/sap/http/client.py b/sap/http/client.py index e521579c..a9da20b3 100644 --- a/sap/http/client.py +++ b/sap/http/client.py @@ -2,7 +2,7 @@ import requests import requests.exceptions -from requests.auth import HTTPBasicAuth +from requests.auth import AuthBase, HTTPBasicAuth from sap import get_logger, config_get from sap.http.errors import ( @@ -10,6 +10,18 @@ UnauthorizedError, TimedOutRequestError, ) +from sap.http.oauth import get_token + + +class BearerAuth(AuthBase): + """Requests auth handler that injects an OAuth 2.0 Bearer token.""" + + def __init__(self, token): + self._token = token + + def __call__(self, r): + r.headers['Authorization'] = f'Bearer {self._token}' + return r def build_query_args(client=None, saml2=None): @@ -63,7 +75,10 @@ def __init__(self, verify=None, ssl_server_cert=None, login_path='', - login_method='HEAD' + login_method='HEAD', + token_url=None, + client_id=None, + client_secret=None, ): self.ssl = ssl @@ -91,7 +106,11 @@ def __init__(self, self.timeout = config_get('http_timeout') - self._auth = HTTPBasicAuth(user, password) + if token_url and client_id and client_secret: + token = get_token(token_url, client_id, client_secret, user=user, password=password) + self._auth = BearerAuth(token) + else: + self._auth = HTTPBasicAuth(user, password) self.error_handlers = [default_http_error_handler] self._connection_error_handler = None diff --git a/sap/http/oauth.py b/sap/http/oauth.py new file mode 100644 index 00000000..caf79f03 --- /dev/null +++ b/sap/http/oauth.py @@ -0,0 +1,129 @@ +"""OAuth 2.0 password grant flow with token caching for BTP Steampunk.""" + +import json +import os +import time +from pathlib import Path + +import requests + +TOKEN_CACHE_PATH = Path('~/.sapcli/tokens.json').expanduser() +REFRESH_MARGIN = 60 + + +# --------------------------------------------------------------------------- +# Token cache +# --------------------------------------------------------------------------- + +def _load_token_cache(): + try: + with open(TOKEN_CACHE_PATH, 'r', encoding='utf-8') as f: + return json.load(f) + except (OSError, json.JSONDecodeError): + return {} + + +def _save_token_cache(cache): + TOKEN_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True) + fd = os.open(TOKEN_CACHE_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(fd, 'w', encoding='utf-8') as f: + json.dump(cache, f, indent=2) + + +def _cache_key(token_url, client_id): + return f'{token_url}|{client_id}' + + +def get_cached_token(token_url, client_id): + cache = _load_token_cache() + entry = cache.get(_cache_key(token_url, client_id)) + if not entry: + return None + if time.time() > entry.get('expires_at', 0) - REFRESH_MARGIN: + return None + return entry['access_token'] + + +def get_cached_refresh_token(token_url, client_id): + cache = _load_token_cache() + entry = cache.get(_cache_key(token_url, client_id)) + if not entry: + return None + return entry.get('refresh_token') + + +def save_token_response(token_url, client_id, token_response): + cache = _load_token_cache() + expires_in = token_response.get('expires_in', 3600) + cache[_cache_key(token_url, client_id)] = { + 'access_token': token_response['access_token'], + 'refresh_token': token_response.get('refresh_token'), + 'expires_at': time.time() + expires_in, + } + _save_token_cache(cache) + + +# --------------------------------------------------------------------------- +# Token refresh +# --------------------------------------------------------------------------- + +def refresh_access_token(token_url, client_id, client_secret, refresh_token): + response = requests.post( + token_url.rstrip('/') + '/oauth/token', + auth=(client_id, client_secret), + data={'grant_type': 'refresh_token', 'refresh_token': refresh_token}, + timeout=30, + ) + if not response.ok: + return None + token_data = response.json() + save_token_response(token_url, client_id, token_data) + return token_data['access_token'] + + +# --------------------------------------------------------------------------- +# Interactive password grant +# --------------------------------------------------------------------------- + +def fetch_token_with_credentials(token_url, client_id, client_secret, user, password): + """Obtain a Bearer token via OAuth 2.0 password grant using provided credentials.""" + + response = requests.post( + token_url.rstrip('/') + '/oauth/token', + auth=(client_id, client_secret), + data={ + 'grant_type': 'password', + 'username': user, + 'password': password, + }, + timeout=30, + ) + + if not response.ok: + raise RuntimeError( + f'OAuth login failed ({response.status_code}): {response.text}' + ) + + token_data = response.json() + save_token_response(token_url, client_id, token_data) + return token_data['access_token'] + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def get_token(token_url, client_id, client_secret, user=None, password=None): + """Return a valid Bearer token — from cache, refresh, or credentials grant.""" + + token = get_cached_token(token_url, client_id) + if token: + return token + + refresh_token = get_cached_refresh_token(token_url, client_id) + if refresh_token: + token = refresh_access_token(token_url, client_id, client_secret, refresh_token) + if token: + return token + + return fetch_token_with_credentials(token_url, client_id, client_secret, user, password) diff --git a/test/unit/test_sap_http_oauth.py b/test/unit/test_sap_http_oauth.py new file mode 100644 index 00000000..a24440cc --- /dev/null +++ b/test/unit/test_sap_http_oauth.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 + +import time +import unittest +from unittest.mock import Mock, patch + +from sap.http.client import BearerAuth, HTTPClient +from sap.http.oauth import ( + _cache_key, + fetch_token_with_credentials, + get_cached_token, + get_cached_refresh_token, + get_token, + refresh_access_token, + save_token_response, +) + + +# --------------------------------------------------------------------------- +# BearerAuth +# --------------------------------------------------------------------------- + +class TestBearerAuth(unittest.TestCase): + + def test_adds_authorization_header(self): + auth = BearerAuth('my-token') + request = Mock() + request.headers = {} + + result = auth(request) + + self.assertEqual(request.headers['Authorization'], 'Bearer my-token') + self.assertIs(result, request) + + def test_overwrites_existing_authorization_header(self): + auth = BearerAuth('new-token') + request = Mock() + request.headers = {'Authorization': 'Basic old'} + + auth(request) + + self.assertEqual(request.headers['Authorization'], 'Bearer new-token') + + +# --------------------------------------------------------------------------- +# HTTPClient OAuth init +# --------------------------------------------------------------------------- + +class TestHTTPClientOAuthInit(unittest.TestCase): + + @patch('sap.http.client.get_token', return_value='bearer-token-123') + def test_uses_bearer_auth_when_oauth_params_provided(self, mock_get_token): + client = HTTPClient( + host='btp.example.com', + user='user@sap.com', + password='secret', + token_url='https://auth.example.com', + client_id='my-client-id', + client_secret='my-client-secret', + ) + + mock_get_token.assert_called_once_with( + 'https://auth.example.com', 'my-client-id', 'my-client-secret', + user='user@sap.com', password='secret' + ) + self.assertIsInstance(client._auth, BearerAuth) + + @patch('sap.http.client.get_token') + def test_uses_basic_auth_when_no_token_url(self, mock_get_token): + from requests.auth import HTTPBasicAuth + client = HTTPClient(host='c50.example.com', user='ELBEZI', password='pass') + + mock_get_token.assert_not_called() + self.assertIsInstance(client._auth, HTTPBasicAuth) + + @patch('sap.http.client.get_token') + def test_uses_basic_auth_when_token_url_missing(self, mock_get_token): + from requests.auth import HTTPBasicAuth + client = HTTPClient( + host='c50.example.com', + user='ELBEZI', + password='pass', + client_id='some-id', + client_secret='some-secret', + # token_url deliberately omitted + ) + + mock_get_token.assert_not_called() + self.assertIsInstance(client._auth, HTTPBasicAuth) + + +# --------------------------------------------------------------------------- +# Token cache +# --------------------------------------------------------------------------- + +class TestTokenCache(unittest.TestCase): + + def _make_token_response(self, access_token='access-123', refresh_token='refresh-456', expires_in=3600): + return { + 'access_token': access_token, + 'refresh_token': refresh_token, + 'expires_in': expires_in, + } + + @patch('sap.http.oauth._save_token_cache') + @patch('sap.http.oauth._load_token_cache', return_value={}) + def test_save_token_response_stores_entry(self, mock_load, mock_save): + save_token_response('https://auth.example.com', 'client-id', self._make_token_response()) + + saved = mock_save.call_args[0][0] + key = _cache_key('https://auth.example.com', 'client-id') + self.assertIn(key, saved) + self.assertEqual(saved[key]['access_token'], 'access-123') + self.assertEqual(saved[key]['refresh_token'], 'refresh-456') + self.assertAlmostEqual(saved[key]['expires_at'], time.time() + 3600, delta=5) + + @patch('sap.http.oauth._load_token_cache') + def test_get_cached_token_returns_valid_token(self, mock_load): + key = _cache_key('https://auth.example.com', 'client-id') + mock_load.return_value = { + key: { + 'access_token': 'valid-token', + 'expires_at': time.time() + 3600, + } + } + + token = get_cached_token('https://auth.example.com', 'client-id') + + self.assertEqual(token, 'valid-token') + + @patch('sap.http.oauth._load_token_cache') + def test_get_cached_token_returns_none_when_expired(self, mock_load): + key = _cache_key('https://auth.example.com', 'client-id') + mock_load.return_value = { + key: { + 'access_token': 'expired-token', + 'expires_at': time.time() - 10, # already expired + } + } + + token = get_cached_token('https://auth.example.com', 'client-id') + + self.assertIsNone(token) + + @patch('sap.http.oauth._load_token_cache', return_value={}) + def test_get_cached_token_returns_none_when_missing(self, mock_load): + token = get_cached_token('https://auth.example.com', 'client-id') + self.assertIsNone(token) + + @patch('sap.http.oauth._load_token_cache') + def test_get_cached_refresh_token_returns_value(self, mock_load): + key = _cache_key('https://auth.example.com', 'client-id') + mock_load.return_value = { + key: {'refresh_token': 'my-refresh-token', 'expires_at': time.time() - 1} + } + + refresh = get_cached_refresh_token('https://auth.example.com', 'client-id') + + self.assertEqual(refresh, 'my-refresh-token') + + @patch('sap.http.oauth._load_token_cache', return_value={}) + def test_get_cached_refresh_token_returns_none_when_missing(self, mock_load): + refresh = get_cached_refresh_token('https://auth.example.com', 'client-id') + self.assertIsNone(refresh) + + +# --------------------------------------------------------------------------- +# Token refresh +# --------------------------------------------------------------------------- + +class TestRefreshAccessToken(unittest.TestCase): + + @patch('sap.http.oauth.save_token_response') + @patch('sap.http.oauth.requests.post') + def test_refresh_success(self, mock_post, mock_save): + mock_post.return_value = Mock( + ok=True, + json=lambda: {'access_token': 'new-token', 'expires_in': 3600} + ) + + token = refresh_access_token( + 'https://auth.example.com', 'client-id', 'client-secret', 'old-refresh' + ) + + self.assertEqual(token, 'new-token') + mock_post.assert_called_once_with( + 'https://auth.example.com/oauth/token', + auth=('client-id', 'client-secret'), + data={'grant_type': 'refresh_token', 'refresh_token': 'old-refresh'}, + timeout=30, + ) + mock_save.assert_called_once() + + @patch('sap.http.oauth.requests.post') + def test_refresh_failure_returns_none(self, mock_post): + mock_post.return_value = Mock(ok=False, status_code=401, text='invalid') + + token = refresh_access_token( + 'https://auth.example.com', 'client-id', 'client-secret', 'bad-refresh' + ) + + self.assertIsNone(token) + + +# --------------------------------------------------------------------------- +# Interactive password grant +# --------------------------------------------------------------------------- + +class TestFetchTokenWithCredentials(unittest.TestCase): + + @patch('sap.http.oauth.save_token_response') + @patch('sap.http.oauth.requests.post') + def test_password_grant_success(self, mock_post, mock_save): + mock_post.return_value = Mock( + ok=True, + json=lambda: {'access_token': 'user-token', 'refresh_token': 'r-token', 'expires_in': 43200} + ) + + token = fetch_token_with_credentials( + 'https://auth.example.com', 'client-id', 'client-secret', + 'user@sap.com', 'mypassword' + ) + + self.assertEqual(token, 'user-token') + mock_post.assert_called_once_with( + 'https://auth.example.com/oauth/token', + auth=('client-id', 'client-secret'), + data={ + 'grant_type': 'password', + 'username': 'user@sap.com', + 'password': 'mypassword', + }, + timeout=30, + ) + mock_save.assert_called_once() + + @patch('sap.http.oauth.requests.post') + def test_password_grant_failure_raises(self, mock_post): + mock_post.return_value = Mock(ok=False, status_code=401, text='invalid_grant') + + with self.assertRaises(RuntimeError) as cm: + fetch_token_with_credentials( + 'https://auth.example.com', 'client-id', 'client-secret', + 'user@sap.com', 'wrongpass' + ) + + self.assertIn('401', str(cm.exception)) + + +# --------------------------------------------------------------------------- +# get_token — orchestration +# --------------------------------------------------------------------------- + +class TestGetToken(unittest.TestCase): + + @patch('sap.http.oauth.get_cached_token', return_value='cached-token') + def test_returns_cached_token_without_refresh_or_login(self, mock_cached): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'cached-token') + + @patch('sap.http.oauth.fetch_token_with_credentials') + @patch('sap.http.oauth.refresh_access_token', return_value='refreshed-token') + @patch('sap.http.oauth.get_cached_refresh_token', return_value='old-refresh') + @patch('sap.http.oauth.get_cached_token', return_value=None) + def test_uses_refresh_token_when_access_token_expired( + self, mock_cached, mock_refresh_tok, mock_refresh, mock_password): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'refreshed-token') + mock_password.assert_not_called() + + @patch('sap.http.oauth.fetch_token_with_credentials', return_value='new-login-token') + @patch('sap.http.oauth.refresh_access_token', return_value=None) + @patch('sap.http.oauth.get_cached_refresh_token', return_value='stale-refresh') + @patch('sap.http.oauth.get_cached_token', return_value=None) + def test_falls_back_to_password_when_refresh_fails( + self, mock_cached, mock_refresh_tok, mock_refresh, mock_password): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'new-login-token') + + @patch('sap.http.oauth.fetch_token_with_credentials', return_value='login-token') + @patch('sap.http.oauth.get_cached_refresh_token', return_value=None) + @patch('sap.http.oauth.get_cached_token', return_value=None) + def test_prompts_login_when_no_cache_at_all( + self, mock_cached, mock_refresh_tok, mock_password): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'login-token') + + +if __name__ == '__main__': + unittest.main()