From a7baa6f9fbe67ecbe94fd0bdbbabae13477e95f6 Mon Sep 17 00:00:00 2001 From: Leila Nefati Date: Wed, 29 Apr 2026 09:00:25 +0000 Subject: [PATCH 1/3] http: Add OAuth 2.0 support for BTP ABAP Environment (Steampunk) Co-Authored-By: Jakub Filak Co-Authored-By: Claude Opus 4.7 --- sap/cli/__init__.py | 43 +++- sap/cli/_entry.py | 11 +- sap/config.py | 1 + sap/http/oauth.py | 193 +++++++++++++++++ test/unit/test_sap_cli.py | 63 ++++++ test/unit/test_sap_http_oauth.py | 359 +++++++++++++++++++++++++++++++ 6 files changed, 668 insertions(+), 2 deletions(-) create mode 100644 sap/http/oauth.py create mode 100644 test/unit/test_sap_http_oauth.py diff --git a/sap/cli/__init__.py b/sap/cli/__init__.py index 05ad800d..82056404 100644 --- a/sap/cli/__init__.py +++ b/sap/cli/__init__.py @@ -128,10 +128,33 @@ def adt_connection_from_args(args): import sap.adt + session_initializer = _build_session_initializer(args) + return sap.adt.Connection( args.ashost, args.client, args.user, args.password, port=args.port, ssl=args.ssl, verify=args.verify, - ssl_server_cert=args.ssl_server_cert) + ssl_server_cert=args.ssl_server_cert, + session_initializer=session_initializer) + + +def _build_session_initializer(args): + """Build an OAuthHTTPSessionInitializer when args.token_url is set, + otherwise return None so HTTPClient falls back to BasicAuth. + """ + + token_url = getattr(args, 'token_url', None) + if not token_url: + return None + + from sap.http.oauth import OAuthHTTPSessionInitializer + + return OAuthHTTPSessionInitializer( + token_url, + getattr(args, 'client_id', None), + getattr(args, 'client_secret', None), + args.user, + args.password, + ) def rfc_connection_from_args(args): @@ -204,6 +227,9 @@ def build_empty_connection_values(): ssl_server_cert=None, user=None, password=None, + token_url=None, + client_id=None, + client_secret=None, ) @@ -287,6 +313,8 @@ def resolve_default_connection_values(args): if not args.password: args.password = os.getenv('SAP_PASSWORD') or config_values.get('password') + _resolve_oauth_defaults(args, config_values) + if hasattr(args, 'corrnr') and args.corrnr is None: args.corrnr = os.getenv('SAP_CORRNR') @@ -295,6 +323,19 @@ def resolve_default_connection_values(args): _apply_config_extra_params(args, config_values) +def _resolve_oauth_defaults(args, config_values): + """Resolve OAuth-specific connection defaults from env vars and config file.""" + + if not getattr(args, 'token_url', None): + args.token_url = os.getenv('SAP_TOKEN_URL') or config_values.get('token_url') + + if not getattr(args, 'client_id', None): + args.client_id = os.getenv('SAP_CLIENT_ID') or config_values.get('client_id') + + if not getattr(args, 'client_secret', None): + args.client_secret = os.getenv('SAP_CLIENT_SECRET') or config_values.get('client_secret') + + def _get_config_context_values(args): """Load config file and resolve the active context to a flat dict.""" diff --git a/sap/cli/_entry.py b/sap/cli/_entry.py index 5958b76c..0c6fd0bd 100644 --- a/sap/cli/_entry.py +++ b/sap/cli/_entry.py @@ -16,6 +16,7 @@ import sap.rfc from sap.config import ConfigFile from sap.http import TimedOutRequestError as HttpTimedOutRequestError +from sap.http.oauth import get_cached_token, get_cached_refresh_token from sap.odata.errors import TimedOutRequestError as ODataTimedOutRequestError # pylint: disable=invalid-name @@ -157,7 +158,15 @@ def parse_command_line(argv): if not args.user: args.user = input('Login:') - if not args.password: + token_url = getattr(args, 'token_url', None) + client_id = getattr(args, 'client_id', None) + has_valid_token = ( + token_url and client_id and ( + get_cached_token(token_url, client_id) + or get_cached_refresh_token(token_url, client_id) + ) + ) + if not args.password and not has_valid_token: args.password = getpass.getpass() return args diff --git a/sap/config.py b/sap/config.py index 4b4cc4c7..5631a53c 100644 --- a/sap/config.py +++ b/sap/config.py @@ -23,6 +23,7 @@ class SAPCliConfigError(SAPCliError): 'ashost', 'sysnr', 'client', 'port', 'ssl', 'ssl_verify', 'ssl_server_cert', 'mshost', 'msserv', 'sysid', 'group', 'snc_qop', 'snc_myname', 'snc_partnername', 'snc_lib', + 'token_url', 'client_id', 'client_secret', ) USER_FIELDS = ( diff --git a/sap/http/oauth.py b/sap/http/oauth.py new file mode 100644 index 00000000..0307d682 --- /dev/null +++ b/sap/http/oauth.py @@ -0,0 +1,193 @@ +"""OAuth 2.0 password grant flow with token caching for BTP Steampunk.""" + +import json +import os +import time +from pathlib import Path + +import requests +from requests.auth import AuthBase + +from sap.errors import SAPCliError +from sap.http.errors import UnauthorizedError + +TOKEN_CACHE_PATH = Path('~/.sapcli/tokens.json').expanduser() +REFRESH_MARGIN = 60 + + +class OAuthTokenError(SAPCliError): + """Raised when an OAuth token cannot be obtained from the auth server.""" + + +class BearerAuth(AuthBase): + """Requests auth handler that injects an OAuth 2.0 Bearer token.""" + + def __init__(self, token): + self._token = token + + def __call__(self, r): + r.headers['Authorization'] = f'Bearer {self._token}' + return r + + +# --------------------------------------------------------------------------- +# Token cache +# --------------------------------------------------------------------------- + +def _load_token_cache(): + try: + with open(TOKEN_CACHE_PATH, 'r', encoding='utf-8') as f: + return json.load(f) + except (OSError, json.JSONDecodeError): + # Missing or corrupt cache files are not fatal: we simply have no + # cached tokens and will fetch fresh ones. + return {} + + +def _save_token_cache(cache): + TOKEN_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True) + fd = os.open(TOKEN_CACHE_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(fd, 'w', encoding='utf-8') as f: + json.dump(cache, f, indent=2) + + +def _cache_key(token_url, client_id): + return f'{token_url}|{client_id}' + + +def get_cached_token(token_url, client_id): + """Return a non-expired cached access token, or None.""" + + cache = _load_token_cache() + entry = cache.get(_cache_key(token_url, client_id)) + if not entry: + return None + if time.time() > entry.get('expires_at', 0) - REFRESH_MARGIN: + return None + return entry['access_token'] + + +def get_cached_refresh_token(token_url, client_id): + """Return the cached refresh token, or None.""" + + cache = _load_token_cache() + entry = cache.get(_cache_key(token_url, client_id)) + if not entry: + return None + return entry.get('refresh_token') + + +def save_token_response(token_url, client_id, token_response): + """Persist an access/refresh token pair into the token cache.""" + + cache = _load_token_cache() + expires_in = token_response.get('expires_in', 3600) + cache[_cache_key(token_url, client_id)] = { + 'access_token': token_response['access_token'], + 'refresh_token': token_response.get('refresh_token'), + 'expires_at': time.time() + expires_in, + } + _save_token_cache(cache) + + +# --------------------------------------------------------------------------- +# Token refresh +# --------------------------------------------------------------------------- + +def refresh_access_token(token_url, client_id, client_secret, refresh_token): + """Try to swap a refresh token for a new access token. Returns None on failure.""" + + response = requests.post( + token_url.rstrip('/') + '/oauth/token', + auth=(client_id, client_secret), + data={'grant_type': 'refresh_token', 'refresh_token': refresh_token}, + timeout=30, + ) + if not response.ok: + return None + token_data = response.json() + save_token_response(token_url, client_id, token_data) + return token_data['access_token'] + + +# --------------------------------------------------------------------------- +# Interactive password grant +# --------------------------------------------------------------------------- + +def fetch_token_with_credentials(token_url, client_id, client_secret, user, password): + """Obtain a Bearer token via OAuth 2.0 password grant using provided credentials.""" + + response = requests.post( + token_url.rstrip('/') + '/oauth/token', + auth=(client_id, client_secret), + data={ + 'grant_type': 'password', + 'username': user, + 'password': password, + }, + timeout=30, + ) + + if not response.ok: + raise OAuthTokenError( + f'OAuth login failed ({response.status_code}): {response.text}' + ) + + token_data = response.json() + save_token_response(token_url, client_id, token_data) + return token_data['access_token'] + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def get_token(token_url, client_id, client_secret, user=None, password=None): + """Return a valid Bearer token - from cache, refresh, or credentials grant.""" + + token = get_cached_token(token_url, client_id) + if token: + return token + + refresh_token = get_cached_refresh_token(token_url, client_id) + if refresh_token: + token = refresh_access_token(token_url, client_id, client_secret, refresh_token) + if token: + return token + + return fetch_token_with_credentials(token_url, client_id, client_secret, user, password) + + +# --------------------------------------------------------------------------- +# Session initializer +# --------------------------------------------------------------------------- + +class OAuthHTTPSessionInitializer: + """HTTPSessionInitializer that authenticates the session via OAuth 2.0. + + The token is fetched lazily inside initialize_session; constructing the + initializer must not perform network I/O. + """ + + # pylint: disable=too-many-arguments + def __init__(self, token_url, client_id, client_secret, user, password): + self._token_url = token_url + self._client_id = client_id + self._client_secret = client_secret + self._user = user + self._password = password + + def initialize_session(self, session): + """Fetch a Bearer token and attach it to the session.""" + + token = get_token( + self._token_url, self._client_id, self._client_secret, + user=self._user, password=self._password, + ) + session.auth = BearerAuth(token) + return session + + def build_unauthorized_error(self, req, res): + """Build an UnauthorizedError carrying the configured user.""" + + return UnauthorizedError(req, res, self._user) diff --git a/test/unit/test_sap_cli.py b/test/unit/test_sap_cli.py index 705c0718..1d6de64b 100755 --- a/test/unit/test_sap_cli.py +++ b/test/unit/test_sap_cli.py @@ -727,5 +727,68 @@ def test_no_connection_returns_none(self): self.assertIsNone(result) +class TestAdtConnectionFromArgs(unittest.TestCase): + """adt_connection_from_args wires OAuth via session_initializer + when args.token_url is set, BasicAuth otherwise. + """ + + def _make_args(self, **overrides): + defaults = dict( + ashost='h.example.com', + client='100', + user='USR', + password='pwd', + port=443, + ssl=True, + verify=True, + ssl_server_cert=None, + token_url=None, + client_id=None, + client_secret=None, + ) + defaults.update(overrides) + return SimpleNamespace(**defaults) + + def test_basic_auth_when_no_token_url(self): + args = self._make_args() + + with patch('sap.adt.Connection') as mock_connection: + sap.cli.adt_connection_from_args(args) + + _, kwargs = mock_connection.call_args + self.assertIsNone(kwargs.get('session_initializer')) + + def test_oauth_initializer_when_token_url_present(self): + from sap.http.oauth import OAuthHTTPSessionInitializer + args = self._make_args( + token_url='https://auth.example.com', + client_id='cid', + client_secret='csec', + ) + + with patch('sap.adt.Connection') as mock_connection: + sap.cli.adt_connection_from_args(args) + + _, kwargs = mock_connection.call_args + initializer = kwargs.get('session_initializer') + self.assertIsInstance(initializer, OAuthHTTPSessionInitializer) + + def test_oauth_kwargs_not_passed_to_connection(self): + """token_url/client_id/client_secret must not appear on Connection ctor.""" + args = self._make_args( + token_url='https://auth.example.com', + client_id='cid', + client_secret='csec', + ) + + with patch('sap.adt.Connection') as mock_connection: + sap.cli.adt_connection_from_args(args) + + _, kwargs = mock_connection.call_args + self.assertNotIn('token_url', kwargs) + self.assertNotIn('client_id', kwargs) + self.assertNotIn('client_secret', kwargs) + + if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_sap_http_oauth.py b/test/unit/test_sap_http_oauth.py new file mode 100644 index 00000000..9169eb5b --- /dev/null +++ b/test/unit/test_sap_http_oauth.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 + +import time +import unittest +from unittest.mock import Mock, patch + +from requests.auth import HTTPBasicAuth + +from sap.http.client import HTTPClient +from sap.http.errors import UnauthorizedError +from sap.http.oauth import ( + BearerAuth, + OAuthHTTPSessionInitializer, + OAuthTokenError, + _cache_key, + fetch_token_with_credentials, + get_cached_token, + get_cached_refresh_token, + get_token, + refresh_access_token, + save_token_response, +) + + +# --------------------------------------------------------------------------- +# BearerAuth +# --------------------------------------------------------------------------- + +class TestBearerAuth(unittest.TestCase): + + def test_adds_authorization_header(self): + auth = BearerAuth('my-token') + request = Mock() + request.headers = {} + + result = auth(request) + + self.assertEqual(request.headers['Authorization'], 'Bearer my-token') + self.assertIs(result, request) + + def test_overwrites_existing_authorization_header(self): + auth = BearerAuth('new-token') + request = Mock() + request.headers = {'Authorization': 'Basic old'} + + auth(request) + + self.assertEqual(request.headers['Authorization'], 'Bearer new-token') + + +# --------------------------------------------------------------------------- +# OAuthHTTPSessionInitializer +# --------------------------------------------------------------------------- + +class TestOAuthHTTPSessionInitializer(unittest.TestCase): + + fixture_token_url = 'https://auth.example.com' + fixture_client_id = 'my-client-id' + fixture_client_secret = 'my-client-secret' + fixture_user = 'user@sap.com' + fixture_password = 'secret' + + def _make_initializer(self): + return OAuthHTTPSessionInitializer( + self.fixture_token_url, + self.fixture_client_id, + self.fixture_client_secret, + self.fixture_user, + self.fixture_password, + ) + + @patch('sap.http.oauth.get_token') + def test_does_not_fetch_token_at_construction(self, mock_get_token): + """Token fetch must be lazy and happen only inside initialize_session.""" + + self._make_initializer() + + mock_get_token.assert_not_called() + + @patch('sap.http.oauth.get_token', return_value='bearer-token-123') + def test_initialize_session_fetches_token_and_sets_bearer_auth(self, mock_get_token): + initializer = self._make_initializer() + session = Mock() + + returned = initializer.initialize_session(session) + + mock_get_token.assert_called_once_with( + self.fixture_token_url, + self.fixture_client_id, + self.fixture_client_secret, + user=self.fixture_user, + password=self.fixture_password, + ) + self.assertIs(returned, session) + self.assertIsInstance(session.auth, BearerAuth) + + @patch('sap.http.oauth.get_token', return_value='abc') + def test_initialize_session_passes_token_to_bearer_auth(self, mock_get_token): + initializer = self._make_initializer() + session = Mock() + + initializer.initialize_session(session) + + # Verify the BearerAuth carries the token returned by get_token + request = Mock() + request.headers = {} + session.auth(request) + self.assertEqual(request.headers['Authorization'], 'Bearer abc') + + def test_build_unauthorized_error_uses_user(self): + initializer = self._make_initializer() + req = Mock() + res = Mock() + + err = initializer.build_unauthorized_error(req, res) + + self.assertIsInstance(err, UnauthorizedError) + self.assertIs(err.request, req) + self.assertIs(err.response, res) + self.assertEqual(err.user, self.fixture_user) + + +# --------------------------------------------------------------------------- +# HTTPClient with OAuth initializer +# --------------------------------------------------------------------------- + +class TestHTTPClientWithOAuthInitializer(unittest.TestCase): + + def test_default_initializer_is_basic_auth(self): + client = HTTPClient(host='c50.example.com', user='ELBEZI', password='pass') + + # Default still BasicAuth — no OAuth knobs on HTTPClient ctor. + self.assertIsInstance(client._session_initializer.build_unauthorized_error(Mock(), Mock()), UnauthorizedError) + + def test_oauth_initializer_is_used_when_provided(self): + initializer = OAuthHTTPSessionInitializer( + 'https://auth.example.com', 'cid', 'csec', 'usr', 'pwd' + ) + + client = HTTPClient(host='h', user='usr', password='pwd', session_initializer=initializer) + + self.assertIs(client._session_initializer, initializer) + + +# --------------------------------------------------------------------------- +# Token cache +# --------------------------------------------------------------------------- + +class TestTokenCache(unittest.TestCase): + + def _make_token_response(self, access_token='access-123', refresh_token='refresh-456', expires_in=3600): + return { + 'access_token': access_token, + 'refresh_token': refresh_token, + 'expires_in': expires_in, + } + + @patch('sap.http.oauth._save_token_cache') + @patch('sap.http.oauth._load_token_cache', return_value={}) + def test_save_token_response_stores_entry(self, mock_load, mock_save): + save_token_response('https://auth.example.com', 'client-id', self._make_token_response()) + + saved = mock_save.call_args[0][0] + key = _cache_key('https://auth.example.com', 'client-id') + self.assertIn(key, saved) + self.assertEqual(saved[key]['access_token'], 'access-123') + self.assertEqual(saved[key]['refresh_token'], 'refresh-456') + self.assertAlmostEqual(saved[key]['expires_at'], time.time() + 3600, delta=5) + + @patch('sap.http.oauth._load_token_cache') + def test_get_cached_token_returns_valid_token(self, mock_load): + key = _cache_key('https://auth.example.com', 'client-id') + mock_load.return_value = { + key: { + 'access_token': 'valid-token', + 'expires_at': time.time() + 3600, + } + } + + token = get_cached_token('https://auth.example.com', 'client-id') + + self.assertEqual(token, 'valid-token') + + @patch('sap.http.oauth._load_token_cache') + def test_get_cached_token_returns_none_when_expired(self, mock_load): + key = _cache_key('https://auth.example.com', 'client-id') + mock_load.return_value = { + key: { + 'access_token': 'expired-token', + 'expires_at': time.time() - 10, # already expired + } + } + + token = get_cached_token('https://auth.example.com', 'client-id') + + self.assertIsNone(token) + + @patch('sap.http.oauth._load_token_cache', return_value={}) + def test_get_cached_token_returns_none_when_missing(self, mock_load): + token = get_cached_token('https://auth.example.com', 'client-id') + self.assertIsNone(token) + + @patch('sap.http.oauth._load_token_cache') + def test_get_cached_refresh_token_returns_value(self, mock_load): + key = _cache_key('https://auth.example.com', 'client-id') + mock_load.return_value = { + key: {'refresh_token': 'my-refresh-token', 'expires_at': time.time() - 1} + } + + refresh = get_cached_refresh_token('https://auth.example.com', 'client-id') + + self.assertEqual(refresh, 'my-refresh-token') + + @patch('sap.http.oauth._load_token_cache', return_value={}) + def test_get_cached_refresh_token_returns_none_when_missing(self, mock_load): + refresh = get_cached_refresh_token('https://auth.example.com', 'client-id') + self.assertIsNone(refresh) + + +# --------------------------------------------------------------------------- +# Token refresh +# --------------------------------------------------------------------------- + +class TestRefreshAccessToken(unittest.TestCase): + + @patch('sap.http.oauth.save_token_response') + @patch('sap.http.oauth.requests.post') + def test_refresh_success(self, mock_post, mock_save): + mock_post.return_value = Mock( + ok=True, + json=lambda: {'access_token': 'new-token', 'expires_in': 3600} + ) + + token = refresh_access_token( + 'https://auth.example.com', 'client-id', 'client-secret', 'old-refresh' + ) + + self.assertEqual(token, 'new-token') + mock_post.assert_called_once_with( + 'https://auth.example.com/oauth/token', + auth=('client-id', 'client-secret'), + data={'grant_type': 'refresh_token', 'refresh_token': 'old-refresh'}, + timeout=30, + ) + mock_save.assert_called_once() + + @patch('sap.http.oauth.requests.post') + def test_refresh_failure_returns_none(self, mock_post): + mock_post.return_value = Mock(ok=False, status_code=401, text='invalid') + + token = refresh_access_token( + 'https://auth.example.com', 'client-id', 'client-secret', 'bad-refresh' + ) + + self.assertIsNone(token) + + +# --------------------------------------------------------------------------- +# Interactive password grant +# --------------------------------------------------------------------------- + +class TestFetchTokenWithCredentials(unittest.TestCase): + + @patch('sap.http.oauth.save_token_response') + @patch('sap.http.oauth.requests.post') + def test_password_grant_success(self, mock_post, mock_save): + mock_post.return_value = Mock( + ok=True, + json=lambda: {'access_token': 'user-token', 'refresh_token': 'r-token', 'expires_in': 43200} + ) + + token = fetch_token_with_credentials( + 'https://auth.example.com', 'client-id', 'client-secret', + 'user@sap.com', 'mypassword' + ) + + self.assertEqual(token, 'user-token') + mock_post.assert_called_once_with( + 'https://auth.example.com/oauth/token', + auth=('client-id', 'client-secret'), + data={ + 'grant_type': 'password', + 'username': 'user@sap.com', + 'password': 'mypassword', + }, + timeout=30, + ) + mock_save.assert_called_once() + + @patch('sap.http.oauth.requests.post') + def test_password_grant_failure_raises(self, mock_post): + mock_post.return_value = Mock(ok=False, status_code=401, text='invalid_grant') + + with self.assertRaises(OAuthTokenError) as cm: + fetch_token_with_credentials( + 'https://auth.example.com', 'client-id', 'client-secret', + 'user@sap.com', 'wrongpass' + ) + + self.assertIn('401', str(cm.exception)) + self.assertIn('invalid_grant', str(cm.exception)) + + +# --------------------------------------------------------------------------- +# OAuthTokenError is a SAPCliError subclass +# --------------------------------------------------------------------------- + +class TestOAuthTokenError(unittest.TestCase): + + def test_is_sapcli_error(self): + from sap.errors import SAPCliError + self.assertTrue(issubclass(OAuthTokenError, SAPCliError)) + + +# --------------------------------------------------------------------------- +# get_token — orchestration +# --------------------------------------------------------------------------- + +class TestGetToken(unittest.TestCase): + + @patch('sap.http.oauth.get_cached_token', return_value='cached-token') + def test_returns_cached_token_without_refresh_or_login(self, mock_cached): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'cached-token') + + @patch('sap.http.oauth.fetch_token_with_credentials') + @patch('sap.http.oauth.refresh_access_token', return_value='refreshed-token') + @patch('sap.http.oauth.get_cached_refresh_token', return_value='old-refresh') + @patch('sap.http.oauth.get_cached_token', return_value=None) + def test_uses_refresh_token_when_access_token_expired( + self, mock_cached, mock_refresh_tok, mock_refresh, mock_password): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'refreshed-token') + mock_password.assert_not_called() + + @patch('sap.http.oauth.fetch_token_with_credentials', return_value='new-login-token') + @patch('sap.http.oauth.refresh_access_token', return_value=None) + @patch('sap.http.oauth.get_cached_refresh_token', return_value='stale-refresh') + @patch('sap.http.oauth.get_cached_token', return_value=None) + def test_falls_back_to_password_when_refresh_fails( + self, mock_cached, mock_refresh_tok, mock_refresh, mock_password): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'new-login-token') + + @patch('sap.http.oauth.fetch_token_with_credentials', return_value='login-token') + @patch('sap.http.oauth.get_cached_refresh_token', return_value=None) + @patch('sap.http.oauth.get_cached_token', return_value=None) + def test_prompts_login_when_no_cache_at_all( + self, mock_cached, mock_refresh_tok, mock_password): + token = get_token('https://auth.example.com', 'client-id', 'client-secret') + + self.assertEqual(token, 'login-token') + + +if __name__ == '__main__': + unittest.main() From 130397a52beaa272551eda45ea6c0a190d59db8b Mon Sep 17 00:00:00 2001 From: Jakub Filak Date: Wed, 29 Apr 2026 14:54:37 +0000 Subject: [PATCH 2/3] config: expose OAuth fields on set-connection and document OAuth MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds --token-url, --client-id, --client-secret to `sapcli config set-connection`, so users can store OAuth credentials on a connection without hand-editing ~/.sapcli/config.yml. The underlying CONNECTION_FIELDS list and the connection storage already handle these fields (introduced earlier with the OAuth seam); this change just makes them reachable from the CLI in the same way as every other connection field. Deliberately not exposed as global runtime flags. OAuth is a rare auth method in the ABAP world, and sapcli's config-file model (named connections / users / contexts, current-context, set-* commands) mirrors kubectl, where OIDC credentials likewise live only in kubeconfig — not as global flags on every command. doc/configuration.md gets: - a worked `sapcli config set-connection` example with OAuth flags in the Managing connections section, - the OAuth section reworded from "not exposed as command-line flags" to "not exposed as global flags" with an example showing how to set the values via the subcommand, - the explanatory paragraph clarifying that token_url / client_id / client_secret describe the OAuth application (per tenant, shared by team) and therefore live under connections:, not users:. Co-Authored-By: Claude Opus 4.7 --- doc/configuration.md | 105 +++++++++++++++++++++++++++++-- sap/cli/config.py | 6 ++ test/unit/test_sap_cli_config.py | 43 +++++++++++++ 3 files changed, 150 insertions(+), 4 deletions(-) diff --git a/doc/configuration.md b/doc/configuration.md index e6368bdf..f38ed001 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -7,7 +7,7 @@ a configuration file. The priority order from highest to lowest is: 2. **Environment variables** - override config file values 3. **Configuration file** (active context) - overrides defaults 4. **Built-in defaults** - used when nothing else is specified -5. **Interactive prompt** - fallback for mandatory values (user, password) when no SNC config is present +5. **Interactive prompt** - fallback for mandatory values (user, password) when no SNC config is present and no valid OAuth token is cached ## Parameters @@ -188,6 +188,23 @@ contexts: password: prod-secret # overrides user ``` +A connection that uses OAuth 2.0 instead of a password is defined the same +way, with three additional fields: + +```yaml +connections: + my-cloud-system: + ashost: my-tenant.abap.eu10.hana.ondemand.com + client: "100" + port: 443 + ssl: true + token_url: https://my-tenant.authentication.eu10.hana.ondemand.com + client_id: sb-abap!t12345 + client_secret: my-client-secret +``` + +See [OAuth 2.0 authentication](#oauth-20-authentication) below for details. + ### Field reference #### `connections.` @@ -209,9 +226,15 @@ contexts: | `snc_myname` | string | no | - | `SNC_MYNAME` | | `snc_partnername` | string | no | - | `SNC_PARTNERNAME` | | `snc_lib` | string | no | - | `SNC_LIB` | +| `token_url` | string | no | - | `SAP_TOKEN_URL` | +| `client_id` | string | no | - | `SAP_CLIENT_ID` | +| `client_secret` | string | no | - | `SAP_CLIENT_SECRET` | (*) Either `ashost` or `mshost` must be provided. +The `token_url`, `client_id`, and `client_secret` fields enable OAuth 2.0 +authentication. See [OAuth 2.0 authentication](#oauth-20-authentication) below. + #### `users.` | Field | Type | Required | Default | Env var equivalent | @@ -249,12 +272,77 @@ fields (e.g. hostname). Define one base connection and override per context. Storing passwords in plain text configuration files is a security concern. The recommended approaches, in order of preference: -1. **Omit the password from config** - sapcli will prompt interactively -2. **Use environment variables** - `SAP_PASSWORD` overrides the config file; suitable for CI/CD pipelines -3. **Store in config file** - acceptable for local development if the file has restrictive permissions (`chmod 600`) +1. **Use OAuth 2.0** - if your system supports it (e.g. SAP cloud systems), + prefer OAuth over a stored password. See + [OAuth 2.0 authentication](#oauth-20-authentication) below. +2. **Omit the password from config** - sapcli will prompt interactively +3. **Use environment variables** - `SAP_PASSWORD` overrides the config file; suitable for CI/CD pipelines +4. **Store in config file** - acceptable for local development if the file has restrictive permissions (`chmod 600`) sapcli will warn if the config file is world-readable and contains passwords. +The same caveats apply to `client_secret` when OAuth is used. + +### OAuth 2.0 authentication + +Some SAP systems — most notably SAP cloud systems such as SAP BTP ABAP +Environment ("Steampunk") — require OAuth 2.0 instead of a username/password +pair. sapcli can authenticate with OAuth and caches the obtained token +between commands so you do not need to log in every time. + +#### Enabling OAuth + +OAuth is enabled by setting three values on the connection definition, in +addition to your usual `--user`/`SAP_USER`: + +| Field on `connections.` | Env var | Description | +|---|---|---| +| `token_url` | `SAP_TOKEN_URL` | Base URL of the OAuth authorization server. sapcli appends `/oauth/token` automatically — provide the base, not the full endpoint. | +| `client_id` | `SAP_CLIENT_ID` | OAuth client ID issued by the system administrator. | +| `client_secret` | `SAP_CLIENT_SECRET` | OAuth client secret issued by the system administrator. | + +These three values are not exposed as **global** command-line flags. They can +be provided via environment variables, written directly into the configuration +file under `connections.` (see the YAML example in [Schema](#schema)), +or set with `sapcli config set-connection`: + +```bash +sapcli config set-connection my-cloud-system \ + --token-url https://my-tenant.authentication.eu10.hana.ondemand.com \ + --client-id sb-abap!t12345 \ + --client-secret +``` + +These fields describe the OAuth **application** registration on the target +system, not the individual user — that is why they sit under `connections:` +alongside `ashost`/`port`, while your user name still belongs under `users:`. +A typical setup has one OAuth client per tenant, shared by all team members, +each with their own `users..user`. + +#### How sapcli obtains a token + +The first time you run a command against an OAuth-enabled connection, sapcli +asks the OAuth server for a token using your user name and password. This is +the only step that needs your password. After that, the token is cached in +`~/.sapcli/tokens.json` (file permissions `0600`) and reused by all subsequent +commands. When the token approaches expiration, sapcli refreshes it +transparently using a refresh token — no password is needed for the refresh. + +If a valid cached token exists, sapcli does **not** prompt for a password, +even if `SAP_PASSWORD` is unset and the configuration file contains none. + +If the OAuth server rejects your credentials or is unreachable, sapcli prints +an `OAuthTokenError` with the HTTP status code and the server's response +body. Verify `token_url`, `client_id`, `client_secret`, your user name, and +your password. + +To force a fresh login (e.g. after rotating credentials), delete the cache +file: + +```bash +rm ~/.sapcli/tokens.json +``` + ## Config management commands ```bash @@ -295,6 +383,12 @@ sapcli config set-connection dev-server --ashost dev.example.com --client 100 -- # Update an existing connection (only specified fields change, others preserved) sapcli config set-connection dev-server --port 8443 +# Add OAuth 2.0 credentials to a connection (see "OAuth 2.0 authentication" below) +sapcli config set-connection cloud-srv \ + --token-url https://auth.example.com \ + --client-id sb-app!t12345 \ + --client-secret my-client-secret + # List all connections sapcli config get-connections @@ -403,6 +497,9 @@ targeting different systems. It also composes well with tools like - `SAP_PASSWORD` : default value for the command line parameter --password - `SAP_SSL_SERVER_CERT` : path to the public unencrypted server SSL certificate - `SAP_SSL_VERIFY` : if "no", SSL server certificate is no validated - this works only when SAP_SSL_SERVER_CERT is not configured +- `SAP_TOKEN_URL` : base URL of the OAuth 2.0 authorization server; corresponds to `connections..token_url` (enables OAuth authentication; see [OAuth 2.0 authentication](#oauth-20-authentication)) +- `SAP_CLIENT_ID` : OAuth 2.0 client ID; corresponds to `connections..client_id` +- `SAP_CLIENT_SECRET` : OAuth 2.0 client secret; corresponds to `connections..client_secret` - `SAP_CORRNR` : if a sapcli command accepts parameter '--corrnr', you can provide default value via this environment variable - `SAPCLI_CONFIG` : path to the configuration file (overrides the default `~/.sapcli/config.yml`) - `SAPCLI_CONTEXT` : name of the context to use (overrides `current-context` in the config file; overridden by `--context` CLI flag) diff --git a/sap/cli/config.py b/sap/cli/config.py index fde59ed6..072e270e 100644 --- a/sap/cli/config.py +++ b/sap/cli/config.py @@ -182,6 +182,12 @@ def _collect_fields(args, field_names): # -- set-connection ----------------------------------------------------------- +@CommandGroup.argument('--client-secret', dest='client_secret', default=None, + help='OAuth 2.0 client secret') +@CommandGroup.argument('--client-id', dest='client_id', default=None, + help='OAuth 2.0 client ID') +@CommandGroup.argument('--token-url', dest='token_url', default=None, + help='OAuth 2.0 authorization server base URL') @CommandGroup.argument('--snc-lib', dest='snc_lib', default=None, help='Path to SNC library') @CommandGroup.argument('--snc-partnername', dest='snc_partnername', default=None, diff --git a/test/unit/test_sap_cli_config.py b/test/unit/test_sap_cli_config.py index a4fbec8e..2243f457 100644 --- a/test/unit/test_sap_cli_config.py +++ b/test/unit/test_sap_cli_config.py @@ -631,6 +631,49 @@ def test_set_ssl_false(self): self.assertFalse(saved['connections']['no-ssl-srv']['ssl']) + def test_set_connection_persists_oauth_fields(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'config.yml') + args = SimpleNamespace( + config=path, name='cloud-srv', + ashost='cloud.example.com', client='100', port=443, + ssl=True, ssl_verify=None, ssl_server_cert=None, + sysnr=None, mshost=None, msserv=None, sysid=None, + group=None, snc_qop=None, snc_myname=None, + snc_partnername=None, snc_lib=None, + token_url='https://auth.example.com', + client_id='sb-app!t12345', + client_secret='secret-value', + ) + console = MagicMock() + with patch('sap.cli.core.get_console', return_value=console): + retval = sap.cli.config.set_connection(None, args) + + self.assertEqual(retval, 0) + + with open(path, 'r', encoding='utf-8') as f: + saved = yaml.safe_load(f) + + self.assertEqual(saved['connections']['cloud-srv']['token_url'], 'https://auth.example.com') + self.assertEqual(saved['connections']['cloud-srv']['client_id'], 'sb-app!t12345') + self.assertEqual(saved['connections']['cloud-srv']['client_secret'], 'secret-value') + + def test_set_connection_argparse_exposes_oauth_flags(self): + from argparse import ArgumentParser + parser = ArgumentParser() + sap.cli.config.CommandGroup().install_parser(parser) + + args = parser.parse_args([ + 'set-connection', 'cloud-srv', + '--token-url', 'https://auth.example.com', + '--client-id', 'sb-app!t12345', + '--client-secret', 'secret-value', + ]) + + self.assertEqual(args.token_url, 'https://auth.example.com') + self.assertEqual(args.client_id, 'sb-app!t12345') + self.assertEqual(args.client_secret, 'secret-value') + # --------------------------------------------------------------------------- # delete-connection From 3ab0cbf752896fa823a0de97ff4035fd823696b2 Mon Sep 17 00:00:00 2001 From: Jakub Filak Date: Wed, 29 Apr 2026 19:51:05 +0000 Subject: [PATCH 3/3] http: use pluginable and os agnostic Token store I was not so sure ~/.sapcli/tokens.json is the right place, I wanted to put it into the right directory (e.g. ~/.local/state/ etc.). The token store was used to be able to add more secure store in the future. --- pyproject.toml | 1 + requirements.txt | 1 + sap/cli/__init__.py | 14 +- sap/cli/_entry.py | 14 +- sap/http/oauth.py | 92 +++-- sap/http/token_cache.py | 226 +++++++++++ test/unit/mock.py | 25 +- test/unit/test_sap_cli.py | 26 ++ test/unit/test_sap_cli__entry.py | 29 ++ test/unit/test_sap_http_oauth.py | 368 ++++++++++++++--- test/unit/test_sap_http_token_cache.py | 540 +++++++++++++++++++++++++ 11 files changed, 1217 insertions(+), 119 deletions(-) create mode 100644 sap/http/token_cache.py create mode 100644 test/unit/test_sap_http_token_cache.py diff --git a/pyproject.toml b/pyproject.toml index aec7756f..3768dfde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "requests>=2.20.0", "pyodata>=1.7.0", "PyYAML>=6.0.1", + "platformdirs>=4.5.1", ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 8617a125..088b75fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ requests>=2.20.0 pyodata==1.7.0 PyYAML==6.0.1 +platformdirs>=4.5.1 diff --git a/sap/cli/__init__.py b/sap/cli/__init__.py index 82056404..d8457c11 100644 --- a/sap/cli/__init__.py +++ b/sap/cli/__init__.py @@ -11,6 +11,7 @@ from types import SimpleNamespace from sap import rfc from sap.config import SAPCliConfigError +from sap.errors import SAPCliError class CommandsCache: @@ -142,16 +143,21 @@ def _build_session_initializer(args): otherwise return None so HTTPClient falls back to BasicAuth. """ - token_url = getattr(args, 'token_url', None) - if not token_url: + token_url = args.token_url + client_id = args.client_id + client_secret = args.client_secret + if not token_url and not client_id and not client_secret: return None + if not token_url or not client_id or not client_secret: + raise SAPCliError('Invalid OAuth configuration: must set all three: token_url, client_id, client_secret') + from sap.http.oauth import OAuthHTTPSessionInitializer return OAuthHTTPSessionInitializer( token_url, - getattr(args, 'client_id', None), - getattr(args, 'client_secret', None), + client_id, + client_secret, args.user, args.password, ) diff --git a/sap/cli/_entry.py b/sap/cli/_entry.py index 0c6fd0bd..09770efc 100644 --- a/sap/cli/_entry.py +++ b/sap/cli/_entry.py @@ -16,7 +16,7 @@ import sap.rfc from sap.config import ConfigFile from sap.http import TimedOutRequestError as HttpTimedOutRequestError -from sap.http.oauth import get_cached_token, get_cached_refresh_token +import sap.http.oauth from sap.odata.errors import TimedOutRequestError as ODataTimedOutRequestError # pylint: disable=invalid-name @@ -158,15 +158,9 @@ def parse_command_line(argv): if not args.user: args.user = input('Login:') - token_url = getattr(args, 'token_url', None) - client_id = getattr(args, 'client_id', None) - has_valid_token = ( - token_url and client_id and ( - get_cached_token(token_url, client_id) - or get_cached_refresh_token(token_url, client_id) - ) - ) - if not args.password and not has_valid_token: + oauth_needs_password = sap.http.oauth.password_required(args.token_url, args.client_id) + + if not args.password and oauth_needs_password: args.password = getpass.getpass() return args diff --git a/sap/http/oauth.py b/sap/http/oauth.py index 0307d682..c318553c 100644 --- a/sap/http/oauth.py +++ b/sap/http/oauth.py @@ -1,17 +1,17 @@ """OAuth 2.0 password grant flow with token caching for BTP Steampunk.""" -import json -import os -import time -from pathlib import Path +from datetime import datetime, timedelta, timezone +from typing import Optional import requests from requests.auth import AuthBase from sap.errors import SAPCliError from sap.http.errors import UnauthorizedError +from sap.http.token_cache import get_token_store, Token + +DEFAULT_EXPIRES_IN = 3600 -TOKEN_CACHE_PATH = Path('~/.sapcli/tokens.json').expanduser() REFRESH_MARGIN = 60 @@ -34,60 +34,57 @@ def __call__(self, r): # Token cache # --------------------------------------------------------------------------- -def _load_token_cache(): - try: - with open(TOKEN_CACHE_PATH, 'r', encoding='utf-8') as f: - return json.load(f) - except (OSError, json.JSONDecodeError): - # Missing or corrupt cache files are not fatal: we simply have no - # cached tokens and will fetch fresh ones. - return {} +def _load_token(token_key: str) -> Optional[Token]: + return get_token_store().get(token_key) -def _save_token_cache(cache): - TOKEN_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True) - fd = os.open(TOKEN_CACHE_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) - with os.fdopen(fd, 'w', encoding='utf-8') as f: - json.dump(cache, f, indent=2) +def _save_token(token_key: str, token: Token) -> None: + get_token_store().set(token_key, token) -def _cache_key(token_url, client_id): +def _cache_key(token_url: str, client_id: str) -> str: return f'{token_url}|{client_id}' -def get_cached_token(token_url, client_id): +def get_cached_token(token_url: str, client_id: str) -> Optional[str]: """Return a non-expired cached access token, or None.""" - cache = _load_token_cache() - entry = cache.get(_cache_key(token_url, client_id)) - if not entry: + token = _load_token(_cache_key(token_url, client_id)) + + if not token: return None - if time.time() > entry.get('expires_at', 0) - REFRESH_MARGIN: + + if token.is_expired(leeway_seconds=REFRESH_MARGIN): return None - return entry['access_token'] + return token.access_token -def get_cached_refresh_token(token_url, client_id): + +def get_cached_refresh_token(token_url: str, client_id: str): """Return the cached refresh token, or None.""" - cache = _load_token_cache() - entry = cache.get(_cache_key(token_url, client_id)) - if not entry: + token = _load_token(_cache_key(token_url, client_id)) + + if not token: return None - return entry.get('refresh_token') + return token.refresh_token -def save_token_response(token_url, client_id, token_response): + +def save_token_response(token_url: str, client_id: str, token_response: dict) -> None: """Persist an access/refresh token pair into the token cache.""" - cache = _load_token_cache() - expires_in = token_response.get('expires_in', 3600) - cache[_cache_key(token_url, client_id)] = { - 'access_token': token_response['access_token'], - 'refresh_token': token_response.get('refresh_token'), - 'expires_at': time.time() + expires_in, - } - _save_token_cache(cache) + expires_in = token_response.get('expires_in', DEFAULT_EXPIRES_IN) + expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(expires_in)) + + new_token = Token( + access_token=token_response['access_token'], + token_type=token_response.get('token_type', 'Bearer'), + expires_at=expires_at, + refresh_token=token_response.get('refresh_token'), + scope=token_response.get('scope'), + ) + _save_token(_cache_key(token_url, client_id), new_token) # --------------------------------------------------------------------------- @@ -191,3 +188,20 @@ def build_unauthorized_error(self, req, res): """Build an UnauthorizedError carrying the configured user.""" return UnauthorizedError(req, res, self._user) + + +# --------------------------------------------------------------------------- +# Utilities +# --------------------------------------------------------------------------- + +def password_required(token_url: Optional[str], client_id: Optional[str]) -> bool: + """Returns true if user must provide password""" + + has_valid_token = ( + token_url and client_id and ( + get_cached_token(token_url, client_id) + or get_cached_refresh_token(token_url, client_id) + ) + ) + + return not has_valid_token diff --git a/sap/http/token_cache.py b/sap/http/token_cache.py new file mode 100644 index 00000000..f61aec31 --- /dev/null +++ b/sap/http/token_cache.py @@ -0,0 +1,226 @@ +"""Token storage for thetool with a swappable backend. + +Current implementation: plaintext JSON files in the per-user cache directory. +Future implementation: OS credential manager (Keychain / Credential Manager / +Secret Service) — drop in a new TokenStore subclass and change the factory. + +Why a Token dataclass instead of dict[str, Any]. The store's contract is a +typed thing, not a bag of strings. When you swap to the keyring backend, the +new implementation has to handle the same Token shape — that constraint is what +keeps the two interchangeable. If you let dict flow through the interface, +every backend will end up disagreeing on what keys are required and you've lost +the abstraction. + +Why key-by-string instead of methods like get_github_token(). OAuth helpers +usually grow a second provider eventually (GitLab, Azure, internal IdP). Keying +by client_id (or provider:client_id) means the same store handles all of them +without interface churn. If you only ever have one token, key="default" is +fine. + +The atomic write. write_text followed by os.replace is the standard +pattern for "never leave a corrupt file on disk if interrupted." Worth keeping +even though the failure mode is rare. +""" + +from __future__ import annotations + +import json +import os +import stat +import sys +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +from platformdirs import PlatformDirs + +APP_NAME = "sapcli" + + +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class Token: + """An OAuth bearer token plus the metadata we care about.""" + + access_token: str + token_type: str = "Bearer" + expires_at: Optional[datetime] = None # absolute, UTC + refresh_token: Optional[str] = None + scope: Optional[str] = None + + def is_expired(self, *, leeway_seconds: int = 30) -> bool: + """Checks if the token is expired. Returns True if expired, otherwise False. + The token is expired if the remaining time is lower than leeway_seconds. + """ + + if self.expires_at is None: + return False + + now = datetime.now(timezone.utc) + + return (self.expires_at - now).total_seconds() <= leeway_seconds + + def to_json(self) -> str: + """Returns JSON string of the token""" + + d = asdict(self) + + if self.expires_at is not None: + d["expires_at"] = self.expires_at.astimezone(timezone.utc).isoformat() + + return json.dumps(d, indent=2) + + @classmethod + def from_json(cls, raw: str) -> "Token": + """Factory methods turning JSON string to Token""" + + d = json.loads(raw) + + if d.get("expires_at"): + d["expires_at"] = datetime.fromisoformat(d["expires_at"]) + + return cls(**d) + + +# --------------------------------------------------------------------------- +# Abstract backend +# --------------------------------------------------------------------------- + +class TokenStore(ABC): + """Persist and retrieve OAuth tokens, keyed by client_id (or any string). + + Implementations must be safe to call concurrently from a single user's + processes; cross-user safety is the OS's problem. + """ + + @abstractmethod + def get(self, key: str) -> Optional[Token]: + """Return the stored token for `key`, or None if absent.""" + + @abstractmethod + def set(self, key: str, token: Token) -> None: + """Store `token` under `key`, overwriting any existing entry.""" + + @abstractmethod + def delete(self, key: str) -> None: + """Remove the entry for `key`. No-op if it doesn't exist.""" + + +# --------------------------------------------------------------------------- +# File-based implementation (today) +# --------------------------------------------------------------------------- + +class FileTokenStore(TokenStore): + """Stores tokens as plaintext JSON files, one per key. + + Layout: /tokens/.json + POSIX: directory chmod 0700, file chmod 0600 + Windows: relies on %LOCALAPPDATA% ACLs being user-scoped + """ + + def __init__(self, base_dir: Optional[Path] = None) -> None: + self._base_dir = base_dir or _default_cache_dir() + self._tokens_dir = self._base_dir / "tokens" + self._tokens_dir.mkdir(parents=True, exist_ok=True) + _harden_dir(self._tokens_dir) + + # -- TokenStore ----------------------------------------------------- + + def get(self, key: str) -> Optional[Token]: + path = self._path_for(key) + if not path.exists(): + return None + try: + return Token.from_json(path.read_text(encoding="utf-8")) + except (OSError, ValueError, KeyError): + # Corrupt or unreadable — treat as absent rather than crash. + return None + + def set(self, key: str, token: Token) -> None: + path = self._path_for(key) + # Write to a temp file and rename, so we never leave a half-written + # token behind if the process is killed mid-write. + tmp = path.with_suffix(path.suffix + '.tmp') + tmp.write_text(token.to_json(), encoding='utf-8') + _harden_file(tmp) + os.replace(tmp, path) + + def delete(self, key: str) -> None: + try: + self._path_for(key).unlink() + except FileNotFoundError: + # Do not crash on such lame reason + pass + + # -- internals ------------------------------------------------------ + + def _path_for(self, key: str) -> Path: + return self._tokens_dir / f"{_sanitize(key)}.json" + + +# --------------------------------------------------------------------------- +# Factory — the one place to change when swapping backends +# --------------------------------------------------------------------------- + +_token_store: Optional[TokenStore] = None + + +def get_token_store() -> TokenStore: + """Return the configured token store. + + Today: file-based. + Tomorrow: read an env var or config flag and return a + KeyringTokenStore / DPAPITokenStore / KeychainTokenStore instead. + """ + + global _token_store # pylint: disable=global-statement + if _token_store is None: + _token_store = FileTokenStore() + + return _token_store + + +# --------------------------------------------------------------------------- +# Path + permission helpers +# --------------------------------------------------------------------------- + +def _default_cache_dir() -> Path: + dirs = PlatformDirs(appname=APP_NAME, roaming=False) + if sys.platform == "linux": + path = Path(dirs.user_state_dir) + elif sys.platform == "darwin": + path = Path(dirs.user_data_dir) + elif sys.platform == "win32": + path = Path(dirs.user_data_dir) + else: + path = Path(dirs.user_state_dir) + path.mkdir(parents=True, exist_ok=True) + _harden_dir(path) + return path + + +def _harden_dir(path: Path) -> None: + if os.name == "posix": + try: + path.chmod(stat.S_IRWXU) # 0o700 + except OSError: + pass + + +def _harden_file(path: Path) -> None: + if os.name == "posix": + try: + path.chmod(stat.S_IRUSR | stat.S_IWUSR) # 0o600 + except OSError: + pass + + +def _sanitize(key: str) -> str: + """Make `key` safe for use as a filename component.""" + safe = "".join(c if c.isalnum() or c in "-._" else "_" for c in key) + return safe or "default" diff --git a/test/unit/mock.py b/test/unit/mock.py index 3d530fc8..ad627ef8 100644 --- a/test/unit/mock.py +++ b/test/unit/mock.py @@ -1,7 +1,7 @@ import copy import json import types -from typing import Dict, NamedTuple +from typing import Dict, NamedTuple, Optional from io import StringIO from argparse import ArgumentParser from contextlib import AbstractContextManager, contextmanager @@ -11,6 +11,7 @@ import sap.adt import sap.http +import sap.http.token_cache import sap.rest import sap.cli.core @@ -573,3 +574,25 @@ def __init__(self, buf): def __exit__(self, exc_type, exc_value, traceback): return False + + +class InMemoryTokenStore(sap.http.token_cache.TokenStore): + + def __init__(self, content: dict = None): + self.content = content or {} + + def get(self, key: str) -> Optional[sap.http.token_cache.Token]: + """Return the stored token for `key`, or None if absent.""" + + return self.content.get(key) + + def set(self, key: str, token: sap.http.token_cache.Token) -> None: + """Store `token` under `key`, overwriting any existing entry.""" + + self.content[key] = token + + def delete(self, key: str) -> None: + """Remove the entry for `key`. No-op if it doesn't exist.""" + + if key in self.content: + del self.content[key] diff --git a/test/unit/test_sap_cli.py b/test/unit/test_sap_cli.py index 1d6de64b..ed627a8a 100755 --- a/test/unit/test_sap_cli.py +++ b/test/unit/test_sap_cli.py @@ -9,6 +9,7 @@ import sap.cli import sap.cli.core from sap.config import ConfigFile, SAPCliConfigError +from sap.errors import SAPCliError from pathlib import Path @@ -789,6 +790,31 @@ def test_oauth_kwargs_not_passed_to_connection(self): self.assertNotIn('client_id', kwargs) self.assertNotIn('client_secret', kwargs) + def test_partial_oauth_config_only_token_url_raises(self): + args = self._make_args(token_url='https://auth.example.com') + + with self.assertRaises(SAPCliError) as cm: + sap.cli.adt_connection_from_args(args) + + self.assertIn('token_url', str(cm.exception)) + self.assertIn('client_id', str(cm.exception)) + self.assertIn('client_secret', str(cm.exception)) + + def test_partial_oauth_config_only_client_id_raises(self): + args = self._make_args(client_id='cid') + + with self.assertRaises(SAPCliError): + sap.cli.adt_connection_from_args(args) + + def test_partial_oauth_config_missing_secret_raises(self): + args = self._make_args( + token_url='https://auth.example.com', + client_id='cid', + ) + + with self.assertRaises(SAPCliError): + sap.cli.adt_connection_from_args(args) + if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_sap_cli__entry.py b/test/unit/test_sap_cli__entry.py index d4eec7f4..7ad04585 100644 --- a/test/unit/test_sap_cli__entry.py +++ b/test/unit/test_sap_cli__entry.py @@ -159,6 +159,35 @@ def test_args_ask_password(self): self.assertEqual(args.password, 'Down1oad') + def test_args_skip_password_prompt_when_token_cached(self): + """When OAuth has a usable cached token, getpass must not be called + even if --password was not supplied.""" + + test_params = get_tested_parameters() + remove_cmd_param_from_list(test_params, '--password') + + getpass_mock = Mock(return_value='should-not-be-used') + + with patch('sap.http.oauth.password_required', return_value=False) as mock_pwd_req, \ + patch('getpass.getpass', getpass_mock): + args = entry.parse_command_line(test_params) + + mock_pwd_req.assert_called_once() + getpass_mock.assert_not_called() + self.assertIsNone(args.password) + + def test_args_prompt_password_when_token_required(self): + """password_required returning True must trigger getpass prompt.""" + + test_params = get_tested_parameters() + remove_cmd_param_from_list(test_params, '--password') + + with patch('sap.http.oauth.password_required', return_value=True), \ + patch('getpass.getpass', lambda: 'prompted-pwd'): + args = entry.parse_command_line(test_params) + + self.assertEqual(args.password, 'prompted-pwd') + def test_args_ask_user_and_password(self): test_params = get_tested_parameters() remove_cmd_param_from_list(test_params, '--password') diff --git a/test/unit/test_sap_http_oauth.py b/test/unit/test_sap_http_oauth.py index 9169eb5b..25e34d5b 100644 --- a/test/unit/test_sap_http_oauth.py +++ b/test/unit/test_sap_http_oauth.py @@ -2,10 +2,9 @@ import time import unittest +from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch -from requests.auth import HTTPBasicAuth - from sap.http.client import HTTPClient from sap.http.errors import UnauthorizedError from sap.http.oauth import ( @@ -17,9 +16,39 @@ get_cached_token, get_cached_refresh_token, get_token, + password_required, refresh_access_token, save_token_response, ) +from sap.http.token_cache import Token + +from test.unit.mock import InMemoryTokenStore + + +# --------------------------------------------------------------------------- +# Disk-write guard +# +# A test must never hit the default FileTokenStore: that would write under the +# real user's home directory. We patch sap.http.oauth.get_token_store at module +# scope so any path that forgets a per-test patch still gets an in-memory store +# instead of falling through to disk. Per-test @patch decorators override this +# guard for their own scope, then it is restored automatically on teardown. +# --------------------------------------------------------------------------- + +_module_token_store_patcher = None + + +def setUpModule(): + global _module_token_store_patcher + _module_token_store_patcher = patch( + 'sap.http.oauth.get_token_store', + side_effect=InMemoryTokenStore, + ) + _module_token_store_patcher.start() + + +def tearDownModule(): + _module_token_store_patcher.stop() # --------------------------------------------------------------------------- @@ -155,64 +184,129 @@ def _make_token_response(self, access_token='access-123', refresh_token='refresh 'expires_in': expires_in, } - @patch('sap.http.oauth._save_token_cache') - @patch('sap.http.oauth._load_token_cache', return_value={}) - def test_save_token_response_stores_entry(self, mock_load, mock_save): + @patch('sap.http.oauth.get_token_store') + def test_save_token_response_stores_entry(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store save_token_response('https://auth.example.com', 'client-id', self._make_token_response()) - - saved = mock_save.call_args[0][0] key = _cache_key('https://auth.example.com', 'client-id') - self.assertIn(key, saved) - self.assertEqual(saved[key]['access_token'], 'access-123') - self.assertEqual(saved[key]['refresh_token'], 'refresh-456') - self.assertAlmostEqual(saved[key]['expires_at'], time.time() + 3600, delta=5) + token = inmemory_store.get(key) + self.assertIsNotNone(token) + self.assertEqual(token.access_token, 'access-123') + self.assertEqual(token.refresh_token, 'refresh-456') + self.assertAlmostEqual(token.expires_at.timestamp(), time.time() + 3600, delta=5) + + @patch('sap.http.oauth.get_token_store') + def test_save_token_response_persists_token_type_and_scope(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store + + response = self._make_token_response() + response['token_type'] = 'MAC' + response['scope'] = 'openid email' + + save_token_response('https://auth.example.com', 'client-id', response) + + token = inmemory_store.get(_cache_key('https://auth.example.com', 'client-id')) + self.assertEqual(token.token_type, 'MAC') + self.assertEqual(token.scope, 'openid email') + + @patch('sap.http.oauth.get_token_store') + def test_save_token_response_defaults_token_type_to_bearer(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store + + # Response without token_type or scope. + save_token_response('https://auth.example.com', 'client-id', self._make_token_response()) + + token = inmemory_store.get(_cache_key('https://auth.example.com', 'client-id')) + self.assertEqual(token.token_type, 'Bearer') + self.assertIsNone(token.scope) + + @patch('sap.http.oauth.get_token_store') + def test_save_token_response_falls_back_to_default_expires_in(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store + + # No expires_in field — implementation must fall back to DEFAULT_EXPIRES_IN (3600). + save_token_response( + 'https://auth.example.com', 'client-id', + {'access_token': 'a', 'refresh_token': 'r'}, + ) + + token = inmemory_store.get(_cache_key('https://auth.example.com', 'client-id')) + self.assertAlmostEqual(token.expires_at.timestamp(), time.time() + 3600, delta=5) + + @patch('sap.http.oauth.get_token_store') + def test_get_cached_token_returns_valid_token(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store - @patch('sap.http.oauth._load_token_cache') - def test_get_cached_token_returns_valid_token(self, mock_load): key = _cache_key('https://auth.example.com', 'client-id') - mock_load.return_value = { - key: { - 'access_token': 'valid-token', - 'expires_at': time.time() + 3600, - } - } + + inmemory_store.set( + key, + Token( + access_token='valid-token', + expires_at=datetime.now(timezone.utc) + timedelta(seconds=3600) + ) + ) token = get_cached_token('https://auth.example.com', 'client-id') self.assertEqual(token, 'valid-token') - @patch('sap.http.oauth._load_token_cache') - def test_get_cached_token_returns_none_when_expired(self, mock_load): + @patch('sap.http.oauth.get_token_store') + def test_get_cached_token_returns_none_when_expired(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store + key = _cache_key('https://auth.example.com', 'client-id') - mock_load.return_value = { - key: { - 'access_token': 'expired-token', - 'expires_at': time.time() - 10, # already expired - } - } + + inmemory_store.set( + key, + Token( + access_token='valid-token', + expires_at=datetime.now(timezone.utc) - timedelta(seconds=3600) + ) + ) token = get_cached_token('https://auth.example.com', 'client-id') self.assertIsNone(token) - @patch('sap.http.oauth._load_token_cache', return_value={}) - def test_get_cached_token_returns_none_when_missing(self, mock_load): + @patch('sap.http.oauth.get_token_store') + def test_get_cached_token_returns_none_when_missing(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store + token = get_cached_token('https://auth.example.com', 'client-id') self.assertIsNone(token) - @patch('sap.http.oauth._load_token_cache') - def test_get_cached_refresh_token_returns_value(self, mock_load): + @patch('sap.http.oauth.get_token_store') + def test_get_cached_refresh_token_returns_value(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store + key = _cache_key('https://auth.example.com', 'client-id') - mock_load.return_value = { - key: {'refresh_token': 'my-refresh-token', 'expires_at': time.time() - 1} - } + inmemory_store.set( + key, + Token( + access_token='expired_token', + refresh_token='my-refresh-token', + expires_at=datetime.now(timezone.utc) - timedelta(seconds=3600) + ) + ) refresh = get_cached_refresh_token('https://auth.example.com', 'client-id') self.assertEqual(refresh, 'my-refresh-token') - @patch('sap.http.oauth._load_token_cache', return_value={}) - def test_get_cached_refresh_token_returns_none_when_missing(self, mock_load): + @patch('sap.http.oauth.get_token_store') + def test_get_cached_refresh_token_returns_none_when_missing(self, fake_token_store): + inmemory_store = InMemoryTokenStore() + fake_token_store.return_value = inmemory_store + refresh = get_cached_refresh_token('https://auth.example.com', 'client-id') self.assertIsNone(refresh) @@ -223,36 +317,52 @@ def test_get_cached_refresh_token_returns_none_when_missing(self, mock_load): class TestRefreshAccessToken(unittest.TestCase): - @patch('sap.http.oauth.save_token_response') + fixture_token_url = 'https://auth.example.com' + fixture_client_id = 'client-id' + fixture_client_secret = 'client-secret' + @patch('sap.http.oauth.requests.post') - def test_refresh_success(self, mock_post, mock_save): + @patch('sap.http.oauth.get_token_store') + def test_refresh_success(self, fake_token_store, mock_post): + store = InMemoryTokenStore() + fake_token_store.return_value = store + mock_post.return_value = Mock( ok=True, json=lambda: {'access_token': 'new-token', 'expires_in': 3600} ) token = refresh_access_token( - 'https://auth.example.com', 'client-id', 'client-secret', 'old-refresh' + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret, 'old-refresh' ) self.assertEqual(token, 'new-token') mock_post.assert_called_once_with( 'https://auth.example.com/oauth/token', - auth=('client-id', 'client-secret'), + auth=(self.fixture_client_id, self.fixture_client_secret), data={'grant_type': 'refresh_token', 'refresh_token': 'old-refresh'}, timeout=30, ) - mock_save.assert_called_once() + + stored = store.get(_cache_key(self.fixture_token_url, self.fixture_client_id)) + self.assertIsNotNone(stored) + self.assertEqual(stored.access_token, 'new-token') @patch('sap.http.oauth.requests.post') - def test_refresh_failure_returns_none(self, mock_post): + @patch('sap.http.oauth.get_token_store') + def test_refresh_failure_returns_none(self, fake_token_store, mock_post): + store = InMemoryTokenStore() + fake_token_store.return_value = store + mock_post.return_value = Mock(ok=False, status_code=401, text='invalid') token = refresh_access_token( - 'https://auth.example.com', 'client-id', 'client-secret', 'bad-refresh' + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret, 'bad-refresh' ) self.assertIsNone(token) + # On failure nothing should have been written to the store. + self.assertIsNone(store.get(_cache_key(self.fixture_token_url, self.fixture_client_id))) # --------------------------------------------------------------------------- @@ -261,23 +371,30 @@ def test_refresh_failure_returns_none(self, mock_post): class TestFetchTokenWithCredentials(unittest.TestCase): - @patch('sap.http.oauth.save_token_response') + fixture_token_url = 'https://auth.example.com' + fixture_client_id = 'client-id' + fixture_client_secret = 'client-secret' + @patch('sap.http.oauth.requests.post') - def test_password_grant_success(self, mock_post, mock_save): + @patch('sap.http.oauth.get_token_store') + def test_password_grant_success(self, fake_token_store, mock_post): + store = InMemoryTokenStore() + fake_token_store.return_value = store + mock_post.return_value = Mock( ok=True, json=lambda: {'access_token': 'user-token', 'refresh_token': 'r-token', 'expires_in': 43200} ) token = fetch_token_with_credentials( - 'https://auth.example.com', 'client-id', 'client-secret', + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret, 'user@sap.com', 'mypassword' ) self.assertEqual(token, 'user-token') mock_post.assert_called_once_with( 'https://auth.example.com/oauth/token', - auth=('client-id', 'client-secret'), + auth=(self.fixture_client_id, self.fixture_client_secret), data={ 'grant_type': 'password', 'username': 'user@sap.com', @@ -285,20 +402,30 @@ def test_password_grant_success(self, mock_post, mock_save): }, timeout=30, ) - mock_save.assert_called_once() + + stored = store.get(_cache_key(self.fixture_token_url, self.fixture_client_id)) + self.assertIsNotNone(stored) + self.assertEqual(stored.access_token, 'user-token') + self.assertEqual(stored.refresh_token, 'r-token') @patch('sap.http.oauth.requests.post') - def test_password_grant_failure_raises(self, mock_post): + @patch('sap.http.oauth.get_token_store') + def test_password_grant_failure_raises(self, fake_token_store, mock_post): + store = InMemoryTokenStore() + fake_token_store.return_value = store + mock_post.return_value = Mock(ok=False, status_code=401, text='invalid_grant') with self.assertRaises(OAuthTokenError) as cm: fetch_token_with_credentials( - 'https://auth.example.com', 'client-id', 'client-secret', + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret, 'user@sap.com', 'wrongpass' ) self.assertIn('401', str(cm.exception)) self.assertIn('invalid_grant', str(cm.exception)) + # Failed grant must not pollute the store. + self.assertIsNone(store.get(_cache_key(self.fixture_token_url, self.fixture_client_id))) # --------------------------------------------------------------------------- @@ -318,41 +445,152 @@ def test_is_sapcli_error(self): class TestGetToken(unittest.TestCase): - @patch('sap.http.oauth.get_cached_token', return_value='cached-token') - def test_returns_cached_token_without_refresh_or_login(self, mock_cached): - token = get_token('https://auth.example.com', 'client-id', 'client-secret') + fixture_token_url = 'https://auth.example.com' + fixture_client_id = 'client-id' + fixture_client_secret = 'client-secret' + + def _key(self): + return _cache_key(self.fixture_token_url, self.fixture_client_id) + + @patch('sap.http.oauth.fetch_token_with_credentials') + @patch('sap.http.oauth.refresh_access_token') + @patch('sap.http.oauth.get_token_store') + def test_returns_cached_token_without_refresh_or_login( + self, fake_token_store, mock_refresh, mock_password): + store = InMemoryTokenStore() + fake_token_store.return_value = store + store.set(self._key(), Token( + access_token='cached-token', + expires_at=datetime.now(timezone.utc) + timedelta(seconds=3600), + )) + + token = get_token( + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret + ) self.assertEqual(token, 'cached-token') + mock_refresh.assert_not_called() + mock_password.assert_not_called() @patch('sap.http.oauth.fetch_token_with_credentials') @patch('sap.http.oauth.refresh_access_token', return_value='refreshed-token') - @patch('sap.http.oauth.get_cached_refresh_token', return_value='old-refresh') - @patch('sap.http.oauth.get_cached_token', return_value=None) + @patch('sap.http.oauth.get_token_store') def test_uses_refresh_token_when_access_token_expired( - self, mock_cached, mock_refresh_tok, mock_refresh, mock_password): - token = get_token('https://auth.example.com', 'client-id', 'client-secret') + self, fake_token_store, mock_refresh, mock_password): + store = InMemoryTokenStore() + fake_token_store.return_value = store + store.set(self._key(), Token( + access_token='expired', + refresh_token='old-refresh', + expires_at=datetime.now(timezone.utc) - timedelta(seconds=10), + )) + + token = get_token( + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret + ) self.assertEqual(token, 'refreshed-token') + mock_refresh.assert_called_once_with( + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret, 'old-refresh', + ) mock_password.assert_not_called() @patch('sap.http.oauth.fetch_token_with_credentials', return_value='new-login-token') @patch('sap.http.oauth.refresh_access_token', return_value=None) - @patch('sap.http.oauth.get_cached_refresh_token', return_value='stale-refresh') - @patch('sap.http.oauth.get_cached_token', return_value=None) + @patch('sap.http.oauth.get_token_store') def test_falls_back_to_password_when_refresh_fails( - self, mock_cached, mock_refresh_tok, mock_refresh, mock_password): - token = get_token('https://auth.example.com', 'client-id', 'client-secret') + self, fake_token_store, mock_refresh, mock_password): + store = InMemoryTokenStore() + fake_token_store.return_value = store + store.set(self._key(), Token( + access_token='expired', + refresh_token='stale-refresh', + expires_at=datetime.now(timezone.utc) - timedelta(seconds=10), + )) + + token = get_token( + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret + ) self.assertEqual(token, 'new-login-token') + mock_refresh.assert_called_once() + mock_password.assert_called_once() @patch('sap.http.oauth.fetch_token_with_credentials', return_value='login-token') - @patch('sap.http.oauth.get_cached_refresh_token', return_value=None) - @patch('sap.http.oauth.get_cached_token', return_value=None) + @patch('sap.http.oauth.refresh_access_token') + @patch('sap.http.oauth.get_token_store') def test_prompts_login_when_no_cache_at_all( - self, mock_cached, mock_refresh_tok, mock_password): - token = get_token('https://auth.example.com', 'client-id', 'client-secret') + self, fake_token_store, mock_refresh, mock_password): + store = InMemoryTokenStore() + fake_token_store.return_value = store + + token = get_token( + self.fixture_token_url, self.fixture_client_id, self.fixture_client_secret + ) self.assertEqual(token, 'login-token') + mock_refresh.assert_not_called() + mock_password.assert_called_once() + + +# --------------------------------------------------------------------------- +# password_required +# --------------------------------------------------------------------------- + +class TestPasswordRequired(unittest.TestCase): + + fixture_token_url = 'https://auth.example.com' + fixture_client_id = 'client-id' + + def _key(self): + return _cache_key(self.fixture_token_url, self.fixture_client_id) + + @patch('sap.http.oauth.get_token_store') + def test_returns_false_when_valid_cached_token(self, fake_token_store): + store = InMemoryTokenStore() + fake_token_store.return_value = store + store.set(self._key(), Token( + access_token='cached', + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + )) + + self.assertFalse(password_required(self.fixture_token_url, self.fixture_client_id)) + + @patch('sap.http.oauth.get_token_store') + def test_returns_false_when_only_refresh_token_cached(self, fake_token_store): + store = InMemoryTokenStore() + fake_token_store.return_value = store + store.set(self._key(), Token( + access_token='expired', + refresh_token='refresh-1', + expires_at=datetime.now(timezone.utc) - timedelta(seconds=10), + )) + + self.assertFalse(password_required(self.fixture_token_url, self.fixture_client_id)) + + @patch('sap.http.oauth.get_token_store') + def test_returns_true_when_nothing_cached(self, fake_token_store): + fake_token_store.return_value = InMemoryTokenStore() + + self.assertTrue(password_required(self.fixture_token_url, self.fixture_client_id)) + + @patch('sap.http.oauth.get_token_store') + def test_returns_true_when_token_url_is_none(self, fake_token_store): + store = InMemoryTokenStore() + fake_token_store.return_value = store + # Even if a token happens to exist under this key, missing token_url means OAuth is off. + store.set(self._key(), Token( + access_token='cached', + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + )) + + self.assertTrue(password_required(None, self.fixture_client_id)) + + @patch('sap.http.oauth.get_token_store') + def test_returns_true_when_client_id_is_none(self, fake_token_store): + fake_token_store.return_value = InMemoryTokenStore() + + self.assertTrue(password_required(self.fixture_token_url, None)) if __name__ == '__main__': diff --git a/test/unit/test_sap_http_token_cache.py b/test/unit/test_sap_http_token_cache.py new file mode 100644 index 00000000..ed3193d8 --- /dev/null +++ b/test/unit/test_sap_http_token_cache.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 + +"""Tests for sap.http.token_cache. + +All filesystem and platformdirs interaction is mocked. No real file is opened, +created, chmod'd, or replaced anywhere in this module — patches on +pathlib.Path methods, sap.http.token_cache.os, and PlatformDirs guarantee that. +""" + +import json +import stat +import unittest +from dataclasses import FrozenInstanceError +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import Mock, patch + +import sap.http.token_cache as token_cache +from sap.http.token_cache import ( + FileTokenStore, + Token, + TokenStore, + _default_cache_dir, + _harden_dir, + _harden_file, + _sanitize, + get_token_store, +) + + +# --------------------------------------------------------------------------- +# Disk-write guard +# +# Patches every Path method that could touch the filesystem at module scope, so +# any path that forgets a per-test patch still cannot reach disk. Per-test +# decorators override these guards within their scope and the originals are +# restored automatically on teardown. +# --------------------------------------------------------------------------- + +_module_patchers = [] + + +def setUpModule(): + targets = [ + ('pathlib.Path.mkdir', None), + ('pathlib.Path.chmod', None), + ('pathlib.Path.exists', False), + ('pathlib.Path.unlink', None), + ('pathlib.Path.read_text', ''), + ('pathlib.Path.write_text', None), + ('sap.http.token_cache.os.replace', None), + ] + for target, return_value in targets: + patcher = patch(target, return_value=return_value) + patcher.start() + _module_patchers.append(patcher) + + +def tearDownModule(): + while _module_patchers: + _module_patchers.pop().stop() + + +# --------------------------------------------------------------------------- +# Token +# --------------------------------------------------------------------------- + +class TestTokenDefaults(unittest.TestCase): + + def test_default_token_type_is_bearer(self): + token = Token(access_token='abc') + self.assertEqual(token.token_type, 'Bearer') + + def test_default_optional_fields_are_none(self): + token = Token(access_token='abc') + self.assertIsNone(token.expires_at) + self.assertIsNone(token.refresh_token) + self.assertIsNone(token.scope) + + def test_token_is_frozen(self): + token = Token(access_token='abc') + with self.assertRaises(FrozenInstanceError): + token.access_token = 'xyz' # type: ignore[misc] + + +class TestTokenIsExpired(unittest.TestCase): + + def test_returns_false_when_no_expiry(self): + token = Token(access_token='abc') + self.assertFalse(token.is_expired()) + + def test_returns_false_when_far_in_future(self): + token = Token( + access_token='abc', + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + self.assertFalse(token.is_expired()) + + def test_returns_true_when_already_past(self): + token = Token( + access_token='abc', + expires_at=datetime.now(timezone.utc) - timedelta(seconds=10), + ) + self.assertTrue(token.is_expired()) + + def test_returns_true_when_within_default_leeway(self): + # Default leeway is 30s; expiry 10s in the future is treated as expired. + token = Token( + access_token='abc', + expires_at=datetime.now(timezone.utc) + timedelta(seconds=10), + ) + self.assertTrue(token.is_expired()) + + def test_respects_custom_leeway(self): + # 45s in the future, leeway 60 → expired; leeway 10 → not expired. + token = Token( + access_token='abc', + expires_at=datetime.now(timezone.utc) + timedelta(seconds=45), + ) + self.assertTrue(token.is_expired(leeway_seconds=60)) + self.assertFalse(token.is_expired(leeway_seconds=10)) + + +class TestTokenJsonRoundTrip(unittest.TestCase): + + def test_full_token_round_trips(self): + original = Token( + access_token='access-1', + refresh_token='refresh-1', + scope='openid email', + expires_at=datetime(2030, 1, 2, 3, 4, 5, tzinfo=timezone.utc), + ) + + recovered = Token.from_json(original.to_json()) + + self.assertEqual(recovered, original) + + def test_token_without_expiry_round_trips(self): + original = Token(access_token='access-1') + + recovered = Token.from_json(original.to_json()) + + self.assertEqual(recovered, original) + self.assertIsNone(recovered.expires_at) + + def test_to_json_emits_iso_expires_at(self): + token = Token( + access_token='abc', + expires_at=datetime(2030, 1, 2, 3, 4, 5, tzinfo=timezone.utc), + ) + + payload = json.loads(token.to_json()) + + self.assertEqual(payload['expires_at'], '2030-01-02T03:04:05+00:00') + + def test_to_json_normalises_naive_or_aware_to_utc(self): + # A non-UTC tzinfo should be converted to UTC on serialization. + non_utc = timezone(timedelta(hours=2)) + token = Token( + access_token='abc', + expires_at=datetime(2030, 1, 2, 5, 0, 0, tzinfo=non_utc), + ) + + payload = json.loads(token.to_json()) + + # 05:00+02:00 == 03:00+00:00 + self.assertEqual(payload['expires_at'], '2030-01-02T03:00:00+00:00') + + def test_from_json_handles_null_expires_at(self): + raw = json.dumps({ + 'access_token': 'abc', + 'token_type': 'Bearer', + 'expires_at': None, + 'refresh_token': None, + 'scope': None, + }) + + token = Token.from_json(raw) + + self.assertIsNone(token.expires_at) + + +# --------------------------------------------------------------------------- +# TokenStore (abstract base) +# --------------------------------------------------------------------------- + +class TestTokenStoreAbstract(unittest.TestCase): + + def test_cannot_be_instantiated_directly(self): + with self.assertRaises(TypeError): + TokenStore() # type: ignore[abstract] + + +# --------------------------------------------------------------------------- +# FileTokenStore.__init__ +# --------------------------------------------------------------------------- + +class TestFileTokenStoreInit(unittest.TestCase): + + @patch.object(Path, 'mkdir') + def test_creates_tokens_directory_under_explicit_base(self, mock_mkdir): + FileTokenStore(base_dir=Path('/fake/base')) + + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + @patch('sap.http.token_cache._default_cache_dir', return_value=Path('/fake/default')) + @patch.object(Path, 'mkdir') + def test_uses_default_cache_dir_when_base_dir_is_none(self, _mock_mkdir, mock_default): + FileTokenStore() + + mock_default.assert_called_once() + + @patch.object(Path, 'chmod') + @patch.object(Path, 'mkdir') + def test_hardens_tokens_directory_on_posix(self, _mock_mkdir, mock_chmod): + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'posix' + FileTokenStore(base_dir=Path('/fake/base')) + + mock_chmod.assert_called_once_with(stat.S_IRWXU) + + @patch.object(Path, 'chmod') + @patch.object(Path, 'mkdir') + def test_does_not_chmod_on_non_posix(self, _mock_mkdir, mock_chmod): + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'nt' + FileTokenStore(base_dir=Path('/fake/base')) + + mock_chmod.assert_not_called() + + +# --------------------------------------------------------------------------- +# FileTokenStore.get +# --------------------------------------------------------------------------- + +class TestFileTokenStoreGet(unittest.TestCase): + + def setUp(self): + # __init__ runs through the module-level Path.mkdir/chmod guards. + self.store = FileTokenStore(base_dir=Path('/fake/base')) + + @patch.object(Path, 'read_text') + @patch.object(Path, 'exists', return_value=True) + def test_returns_token_when_file_exists(self, _mock_exists, mock_read): + original = Token(access_token='access-1', refresh_token='refresh-1') + mock_read.return_value = original.to_json() + + result = self.store.get('mykey') + + self.assertEqual(result, original) + + @patch.object(Path, 'exists', return_value=False) + def test_returns_none_when_file_missing(self, _mock_exists): + result = self.store.get('mykey') + + self.assertIsNone(result) + + @patch.object(Path, 'read_text', side_effect=OSError('disk error')) + @patch.object(Path, 'exists', return_value=True) + def test_returns_none_on_os_error(self, _mock_exists, _mock_read): + result = self.store.get('mykey') + + self.assertIsNone(result) + + @patch.object(Path, 'read_text', return_value='{ this is not valid JSON') + @patch.object(Path, 'exists', return_value=True) + def test_returns_none_on_invalid_json(self, _mock_exists, _mock_read): + result = self.store.get('mykey') + + self.assertIsNone(result) + + +# --------------------------------------------------------------------------- +# FileTokenStore.set +# --------------------------------------------------------------------------- + +class TestFileTokenStoreSet(unittest.TestCase): + + def setUp(self): + self.store = FileTokenStore(base_dir=Path('/fake/base')) + + @patch('sap.http.token_cache.os.replace') + @patch.object(Path, 'write_text') + def test_writes_token_then_renames_atomically(self, mock_write, mock_replace): + token = Token(access_token='abc') + + self.store.set('mykey', token) + + mock_write.assert_called_once() + mock_replace.assert_called_once() + + tmp_path, final_path = mock_replace.call_args.args + self.assertTrue(str(tmp_path).endswith('.tmp')) + self.assertFalse(str(final_path).endswith('.tmp')) + + @patch('sap.http.token_cache.os.replace') + @patch.object(Path, 'write_text') + def test_writes_serialized_json(self, mock_write, _mock_replace): + token = Token(access_token='hello-world', refresh_token='r-1') + + self.store.set('mykey', token) + + written_payload = mock_write.call_args.args[0] + self.assertIn('hello-world', written_payload) + self.assertIn('r-1', written_payload) + + @patch('sap.http.token_cache.os.replace') + @patch.object(Path, 'write_text') + def test_uses_utf8_encoding(self, mock_write, _mock_replace): + token = Token(access_token='abc') + + self.store.set('mykey', token) + + self.assertEqual(mock_write.call_args.kwargs['encoding'], 'utf-8') + + @patch('sap.http.token_cache.os.replace') + @patch.object(Path, 'chmod') + @patch.object(Path, 'write_text') + def test_hardens_tmp_file_on_posix(self, _mock_write, mock_chmod, _mock_replace): + with patch.object(token_cache, 'os') as mock_os: + # set() also calls os.replace — keep it as a mock call we don't care about + mock_os.name = 'posix' + self.store.set('mykey', Token(access_token='abc')) + + # The write_text path's chmod (file 0600). The dir's chmod happened in + # __init__ before this test, where it was patched away separately. + mock_chmod.assert_called_once_with(stat.S_IRUSR | stat.S_IWUSR) + + +# --------------------------------------------------------------------------- +# FileTokenStore.delete +# --------------------------------------------------------------------------- + +class TestFileTokenStoreDelete(unittest.TestCase): + + def setUp(self): + self.store = FileTokenStore(base_dir=Path('/fake/base')) + + @patch.object(Path, 'unlink') + def test_unlinks_existing_file(self, mock_unlink): + self.store.delete('mykey') + + mock_unlink.assert_called_once() + + @patch.object(Path, 'unlink', side_effect=FileNotFoundError()) + def test_silently_ignores_missing_file(self, _mock_unlink): + # Must not raise. + self.store.delete('mykey') + + +# --------------------------------------------------------------------------- +# _path_for / _sanitize +# --------------------------------------------------------------------------- + +class TestSanitize(unittest.TestCase): + + def test_alphanumeric_passes_through(self): + self.assertEqual(_sanitize('abcXYZ123'), 'abcXYZ123') + + def test_dash_dot_underscore_pass_through(self): + self.assertEqual(_sanitize('a-b.c_d'), 'a-b.c_d') + + def test_special_chars_become_underscore(self): + self.assertEqual(_sanitize('a/b\\c|d:e'), 'a_b_c_d_e') + + def test_pipe_replaced_with_underscore(self): + # OAuth helpers key by '|'. + self.assertEqual(_sanitize('https://x.example.com|client-1'), + 'https___x.example.com_client-1') + + def test_empty_input_returns_default(self): + self.assertEqual(_sanitize(''), 'default') + + def test_only_special_chars_does_not_collapse_to_default(self): + # The 'default' fallback only fires for empty input. + self.assertEqual(_sanitize('!!!'), '___') + + +class TestFileTokenStorePathFor(unittest.TestCase): + + def setUp(self): + self.store = FileTokenStore(base_dir=Path('/fake/base')) + + def test_path_lives_under_tokens_subdir(self): + path = self.store._path_for('mykey') + self.assertEqual(path, Path('/fake/base/tokens/mykey.json')) + + def test_path_sanitizes_special_chars(self): + path = self.store._path_for('https://x|client-1') + self.assertEqual(path, Path('/fake/base/tokens/https___x_client-1.json')) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +class TestGetTokenStore(unittest.TestCase): + + def setUp(self): + # Reset the module-level singleton so each test starts from a clean state. + token_cache._token_store = None + + def tearDown(self): + token_cache._token_store = None + + @patch('sap.http.token_cache._default_cache_dir', return_value=Path('/fake/default')) + def test_returns_file_token_store(self, _mock_default): + store = get_token_store() + + self.assertIsInstance(store, FileTokenStore) + + @patch('sap.http.token_cache._default_cache_dir', return_value=Path('/fake/default')) + def test_returns_same_instance_on_repeated_calls(self, _mock_default): + first = get_token_store() + second = get_token_store() + + self.assertIs(first, second) + + @patch('sap.http.token_cache.FileTokenStore') + @patch('sap.http.token_cache._default_cache_dir', return_value=Path('/fake/default')) + def test_constructs_file_token_store_only_once(self, _mock_default, mock_ctor): + get_token_store() + get_token_store() + get_token_store() + + mock_ctor.assert_called_once() + + +# --------------------------------------------------------------------------- +# _default_cache_dir +# --------------------------------------------------------------------------- + +class TestDefaultCacheDir(unittest.TestCase): + + def _patch_platform(self, platform_name, dirs_attrs): + """Patch sys.platform and PlatformDirs; return the mock so callers can assert.""" + + platform_dirs_mock = Mock(**dirs_attrs) + return ( + patch.object(token_cache.sys, 'platform', platform_name), + patch('sap.http.token_cache.PlatformDirs', return_value=platform_dirs_mock), + platform_dirs_mock, + ) + + def test_uses_user_state_dir_on_linux(self): + sys_p, pd_p, _pd_mock = self._patch_platform( + 'linux', {'user_state_dir': '/state', 'user_data_dir': '/data'} + ) + with sys_p, pd_p: + self.assertEqual(_default_cache_dir(), Path('/state')) + + def test_uses_user_data_dir_on_darwin(self): + sys_p, pd_p, _pd_mock = self._patch_platform( + 'darwin', {'user_state_dir': '/state', 'user_data_dir': '/data'} + ) + with sys_p, pd_p: + self.assertEqual(_default_cache_dir(), Path('/data')) + + def test_uses_user_data_dir_on_win32(self): + sys_p, pd_p, _pd_mock = self._patch_platform( + 'win32', {'user_state_dir': '/state', 'user_data_dir': '/data'} + ) + with sys_p, pd_p: + self.assertEqual(_default_cache_dir(), Path('/data')) + + def test_falls_back_to_user_state_dir_on_unknown_platform(self): + sys_p, pd_p, _pd_mock = self._patch_platform( + 'something-exotic', {'user_state_dir': '/state', 'user_data_dir': '/data'} + ) + with sys_p, pd_p: + self.assertEqual(_default_cache_dir(), Path('/state')) + + @patch.object(Path, 'mkdir') + def test_creates_directory(self, mock_mkdir): + sys_p, pd_p, _pd_mock = self._patch_platform( + 'linux', {'user_state_dir': '/state', 'user_data_dir': '/data'} + ) + with sys_p, pd_p: + _default_cache_dir() + + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + +# --------------------------------------------------------------------------- +# _harden_dir / _harden_file +# --------------------------------------------------------------------------- + +class TestHardenHelpers(unittest.TestCase): + + def test_harden_dir_chmods_0700_on_posix(self): + path = Mock() + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'posix' + _harden_dir(path) + + path.chmod.assert_called_once_with(stat.S_IRWXU) + + def test_harden_dir_is_noop_on_non_posix(self): + path = Mock() + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'nt' + _harden_dir(path) + + path.chmod.assert_not_called() + + def test_harden_dir_swallows_oserror(self): + path = Mock() + path.chmod.side_effect = OSError('permission denied') + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'posix' + # Must not raise. + _harden_dir(path) + + def test_harden_file_chmods_0600_on_posix(self): + path = Mock() + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'posix' + _harden_file(path) + + path.chmod.assert_called_once_with(stat.S_IRUSR | stat.S_IWUSR) + + def test_harden_file_is_noop_on_non_posix(self): + path = Mock() + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'nt' + _harden_file(path) + + path.chmod.assert_not_called() + + def test_harden_file_swallows_oserror(self): + path = Mock() + path.chmod.side_effect = OSError('readonly') + with patch.object(token_cache, 'os') as mock_os: + mock_os.name = 'posix' + # Must not raise. + _harden_file(path) + + +if __name__ == '__main__': + unittest.main()