diff --git a/Makefile b/Makefile index 86c88a868..aecb968a0 100644 --- a/Makefile +++ b/Makefile @@ -168,4 +168,9 @@ backfill-gap-analysis-sync: python cre.py --cache_file "$$CRE_CACHE_FILE" --populate_neo4j_db && \ python cre.py --cache_file "$$CRE_CACHE_FILE" --ga_backfill_missing --ga_backfill_no_queue +backfill-opencre-ga: + @[ -d "./.venv" ] && . ./.venv/bin/activate || ([ -d "./venv" ] && . ./venv/bin/activate); \ + export FLASK_APP="$(CURDIR)/cre.py"; \ + python cre.py --cache_file "$${CRE_CACHE_FILE:-$(CURDIR)/standards_cache.sqlite}" --ga_backfill_opencre_direct + all: clean lint test dev dev-run diff --git a/application/cmd/cre_main.py b/application/cmd/cre_main.py index b87085aa6..b3dfe91c4 100644 --- a/application/cmd/cre_main.py +++ b/application/cmd/cre_main.py @@ -730,12 +730,14 @@ def backfill_gap_analysis_only( if os.environ.get("CRE_NO_NEO4J") != "1": populate_neo4j_db(db_connection_str) + gap_analysis.backfill_opencre_direct_pairs(collection, refresh=True) + missing = _missing_ga_pairs(collection) if max_pairs > 0: missing = missing[:max_pairs] total = len(missing) if total == 0: - logger.info("GA backfill: no missing pairs") + logger.info("GA backfill: no missing neo4j pairs") return logger.info( @@ -952,6 +954,9 @@ def run(args: argparse.Namespace) -> None: # pragma: no cover if args.preload_map_analysis_target_url: gap_analysis.preload(target_url=args.preload_map_analysis_target_url) + if getattr(args, "ga_backfill_opencre_direct", False): + collection = db_connect(path=args.cache_file) + gap_analysis.backfill_opencre_direct_pairs(collection, refresh=True) if getattr(args, "ga_backfill_missing", False): backfill_gap_analysis_only( args.cache_file, diff --git a/application/tests/opencre_gap_analysis_test.py b/application/tests/opencre_gap_analysis_test.py new file mode 100644 index 000000000..7ee71575b --- /dev/null +++ b/application/tests/opencre_gap_analysis_test.py @@ -0,0 +1,89 @@ +import json +import unittest +from unittest.mock import Mock, patch + +from application import create_app, sqla # type: ignore +from application.database import db +from application.defs import cre_defs as defs +from application.utils.gap_analysis import ( + OPENCRE_STANDARD_NAME, + backfill_opencre_direct_pairs, + make_resources_key, +) + + +class TestOpencreGapAnalysis(unittest.TestCase): + def tearDown(self) -> None: + sqla.session.remove() + sqla.drop_all() + self.app_context.pop() + + def setUp(self) -> None: + self.app = create_app(mode="test") + self.app_context = self.app.app_context() + self.app_context.push() + sqla.create_all() + self.collection = db.Node_collection() + + def test_backfill_populates_secure_headers_pair_from_auto_linked_nodes( + self, + ) -> None: + cre = self.collection.add_cre( + defs.CRE( + id="636-347", + name="HTTP security headers", + description="", + ) + ) + header_node = self.collection.add_node( + defs.Standard( + name="Secure Headers", + section="Prevent information disclosure via HTTP headers", + hyperlink="https://owasp.org/example", + ) + ) + self.collection.add_link( + cre=cre, + node=header_node, + ltype=defs.LinkTypes.AutomaticallyLinkedTo, + ) + + written = backfill_opencre_direct_pairs(self.collection, refresh=True) + cache_key = make_resources_key([OPENCRE_STANDARD_NAME, "Secure Headers"]) + + self.assertGreaterEqual(written, 1) + self.assertTrue(self.collection.gap_analysis_exists(cache_key)) + payload = json.loads(self.collection.get_gap_analysis_result(cache_key)) + self.assertIn("636-347", payload["result"]) + path = next(iter(payload["result"]["636-347"]["paths"].values())) + self.assertEqual("AUTOMATICALLY_LINKED_TO", path["path"][0]["relationship"]) + + @patch( + "application.utils.gap_analysis.build_direct_cre_overlap_map_analysis", + return_value={"result": {"x": {}}}, + ) + def test_backfill_refresh_recomputes_cached_pairs(self, build_mock: Mock) -> None: + collection = Mock() + collection.standards.return_value = ["ASVS"] + + backfill_opencre_direct_pairs(collection, refresh=False) + build_mock.assert_not_called() + + backfill_opencre_direct_pairs(collection, refresh=True) + self.assertEqual(2, build_mock.call_count) + + @patch( + "application.utils.gap_analysis.build_direct_cre_overlap_map_analysis", + return_value={"result": {"x": {}}}, + ) + def test_backfill_missing_only_skips_when_cache_exists( + self, build_mock: Mock + ) -> None: + collection = Mock() + collection.standards.return_value = ["ASVS"] + collection.gap_analysis_exists.return_value = True + + written = backfill_opencre_direct_pairs(collection, refresh=False) + + self.assertEqual(0, written) + build_mock.assert_not_called() diff --git a/application/tests/pci_dss_parser_test.py b/application/tests/pci_dss_parser_test.py index 3e51a8624..01815e5b9 100644 --- a/application/tests/pci_dss_parser_test.py +++ b/application/tests/pci_dss_parser_test.py @@ -1,7 +1,79 @@ import unittest from unittest.mock import Mock, patch -from application.utils.external_project_parsers.parsers.pci_dss import PciDss +from application.defs import cre_defs as defs +from application.utils.external_project_parsers.parsers import pci_dss as pci_mod +from application.utils.external_project_parsers.parsers.pci_dss import ( + PciDss, + PciDssLinkError, + best_cre_via_bridge_standard, + pci_control_embedding_text, + resolve_cre_for_pci_control, +) + + +class TestPciDssLinking(unittest.TestCase): + def test_pci_control_embedding_text_uses_id_section_and_description(self) -> None: + control = defs.Standard( + name="PCI DSS", + sectionID="1.2.3", + section="Requirement title", + description="Longer requirement body", + ) + text = pci_control_embedding_text(control) + self.assertIn("1.2.3", text) + self.assertIn("Requirement title", text) + self.assertIn("Longer requirement body", text) + self.assertNotIn("family:standard", text) + + def test_resolve_cre_uses_paginated_cre_match_first(self) -> None: + cache = Mock() + linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="") + cache.get_cre_by_db_id.return_value = linked_cre + prompt = Mock() + prompt.get_id_of_most_similar_cre_paginated.return_value = ("cre-db-id", 0.82) + + cre = resolve_cre_for_pci_control(prompt, cache, [0.1, 0.2]) + + self.assertEqual(linked_cre, cre) + prompt.get_id_of_most_similar_cre_paginated.assert_called() + prompt.get_id_of_most_similar_node.assert_not_called() + + def test_resolve_cre_falls_back_to_bridge_standard(self) -> None: + cache = Mock() + prompt = Mock() + prompt.get_id_of_most_similar_cre_paginated.return_value = (None, None) + bridge_cre = defs.CRE(id="999-001", name="Bridge CRE", description="") + + with patch.object(pci_mod, "PCI_BRIDGE_STANDARDS", ("S1", "S2")), patch.object( + pci_mod, "best_cre_via_bridge_standard", side_effect=[None, bridge_cre] + ) as bridge_mock: + cre = resolve_cre_for_pci_control(prompt, cache, [0.1, 0.2]) + + self.assertEqual(bridge_cre, cre) + self.assertEqual(2, bridge_mock.call_count) + + def test_best_cre_via_bridge_standard_picks_highest_similarity_linked_node( + self, + ) -> None: + cache = Mock() + low_node = defs.Standard(name="NIST 800-53 v5", section="low", sectionID="a") + high_node = defs.Standard(name="NIST 800-53 v5", section="high", sectionID="b") + low_cre = defs.CRE(id="111-111", name="Low", description="") + high_cre = defs.CRE(id="222-222", name="High", description="") + cache.get_nodes.return_value = [low_node, high_node] + cache.get_embeddings_for_doc.side_effect = [[0.0, 1.0], [1.0, 0.0]] + cache.find_cres_of_node.side_effect = [ + [Mock(id="low-db")], + [Mock(id="high-db")], + ] + cache.get_cre_by_db_id.side_effect = [low_cre, high_cre] + + cre = best_cre_via_bridge_standard( + cache, [1.0, 0.0], "NIST 800-53 v5", min_similarity=0.0 + ) + + self.assertEqual(high_cre, cre) class TestPciDssParser(unittest.TestCase): @@ -15,14 +87,11 @@ def test_parse_skips_standard_fallback_when_no_standard_id( cache = Mock() cache.get_nodes.return_value = None - cache.find_cres_of_node.return_value = [] - cache.get_cre_by_db_id.return_value = None - cache.get_embeddings_by_doc_type.return_value = {} + linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="") + cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]} prompt = Mock() prompt.get_text_embeddings.return_value = [0.1, 0.2] - prompt.get_id_of_most_similar_cre.return_value = None - prompt.get_id_of_most_similar_node.return_value = None prompt_handler_mock.return_value = prompt pci_file = { @@ -35,19 +104,115 @@ def test_parse_skips_standard_fallback_when_no_standard_id( ] } - out = parser.parse_4(pci_file=pci_file, cache=cache) + with patch.object( + pci_mod, "resolve_cre_for_pci_control", return_value=linked_cre + ): + out = parser.parse_4(pci_file=pci_file, cache=cache) self.assertEqual(1, len(out)) - self.assertEqual(1, cache.get_nodes.call_count) - self.assertEqual( - { - "name": "PCI DSS", - "section": "Test requirement text", - "sectionID": "1.1.1", - }, - cache.get_nodes.call_args.kwargs, - ) - prompt.generate_embeddings_for.assert_called_once() + self.assertEqual(1, len(out[0].links)) + prompt.generate_embeddings_for.assert_not_called() + + @patch( + "application.utils.external_project_parsers.parsers.pci_dss.prompt_client.PromptHandler" + ) + def test_parse_raises_when_control_cannot_be_linked(self, prompt_handler_mock): + parser = PciDss() + + cache = Mock() + cache.get_nodes.return_value = None + cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]} + + prompt = Mock() + prompt.get_text_embeddings.return_value = [0.1, 0.2] + prompt_handler_mock.return_value = prompt + + pci_file = { + "Original Content": [ + { + "Defined Approach Requirements": "Test requirement text", + "PCI DSS ID": "1.1.1", + "Requirement Description": "desc", + } + ] + } + + with patch.object(pci_mod, "resolve_cre_for_pci_control", return_value=None): + with self.assertRaises(PciDssLinkError): + parser.parse_4(pci_file=pci_file, cache=cache) + + @patch( + "application.utils.external_project_parsers.parsers.pci_dss.prompt_client.PromptHandler" + ) + def test_parse_raises_when_any_control_in_batch_is_unlinked( + self, prompt_handler_mock + ): + parser = PciDss() + cache = Mock() + cache.get_nodes.return_value = None + linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="") + cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]} + + prompt = Mock() + prompt.get_text_embeddings.return_value = [0.1, 0.2] + prompt_handler_mock.return_value = prompt + + pci_file = { + "Original Content": [ + { + "Defined Approach Requirements": "Linked requirement", + "PCI DSS ID": "1.1.1", + "Requirement Description": "desc", + }, + { + "Defined Approach Requirements": "Unlinked requirement", + "PCI DSS ID": "1.1.2", + "Requirement Description": "desc", + }, + ] + } + + with patch.object( + pci_mod, + "resolve_cre_for_pci_control", + side_effect=[linked_cre, None], + ): + with self.assertRaisesRegex(PciDssLinkError, "1 control\\(s\\) failed"): + parser.parse_4(pci_file=pci_file, cache=cache) + + @patch( + "application.utils.external_project_parsers.parsers.pci_dss.prompt_client.PromptHandler" + ) + def test_parse_adds_single_automatic_link_per_control(self, prompt_handler_mock): + parser = PciDss() + cache = Mock() + cache.get_nodes.return_value = None + linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="") + cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]} + + prompt = Mock() + prompt.get_text_embeddings.return_value = [0.1, 0.2] + prompt_handler_mock.return_value = prompt + + pci_file = { + "Original Content": [ + { + "Defined Approach Requirements": "Test requirement text", + "PCI DSS ID": "1.1.1", + "Requirement Description": "desc", + } + ] + } + + with patch.object( + pci_mod, "resolve_cre_for_pci_control", return_value=linked_cre + ): + out = parser.parse_4(pci_file=pci_file, cache=cache) + + self.assertEqual(1, len(out)) + self.assertEqual(1, len(out[0].links)) + self.assertEqual(defs.LinkTypes.AutomaticallyLinkedTo, out[0].links[0].ltype) + self.assertEqual("123-456", out[0].links[0].document.id) if __name__ == "__main__": diff --git a/application/tests/secure_headers_parser_test.py b/application/tests/secure_headers_parser_test.py index 094f0b862..1acb82219 100644 --- a/application/tests/secure_headers_parser_test.py +++ b/application/tests/secure_headers_parser_test.py @@ -3,6 +3,9 @@ from application import create_app, sqla # type: ignore from application.prompt_client.prompt_client import PromptHandler from application.utils.external_project_parsers.parsers import secure_headers +from application.utils.external_project_parsers.parsers.secure_headers import ( + SecureHeadersLinkError, +) from application.database import db from application.utils import git import tempfile @@ -45,7 +48,7 @@ class Repo: name="Secure Headers", hyperlink="https://example.com/foo/bar", section="headerAsection", - links=[defs.Link(document=cre, ltype=defs.LinkTypes.LinkedTo)], + links=[defs.Link(document=cre, ltype=defs.LinkTypes.AutomaticallyLinkedTo)], tags=[ "family:guidance", "subtype:cheatsheet", @@ -61,6 +64,119 @@ class Repo: self.assertEqual(len(nodes), 1) self.assertCountEqual(expected.todict(), nodes[0].todict()) + @patch.object(git, "clone") + def test_register_headers_creates_one_entry_per_opencre_link( + self, mock_clone + ) -> None: + class Repo: + working_dir = "" + + repo = Repo() + loc = tempfile.mkdtemp() + tmpdir = os.path.join(loc, "content") + os.mkdir(tmpdir) + repo.working_dir = loc + cre_a = defs.CRE(name="HTTP security headers", id="636-347") + cre_b = defs.CRE( + name="Do not disclose technical information in HTTP header or response", + id="743-110", + ) + self.collection.add_cre(cre_a) + self.collection.add_cre(cre_b) + md = """See [first](https://www.opencre.org/cre/636-347?name=Secure+Headers§ion=First&link=https%3A%2F%2Fexample.com%2Ffirst) +and [second](https://www.opencre.org/cre/403-005?name=Secure+Headers§ion=Second&link=https%3A%2F%2Fexample.com%2Fsecond) +""" + with open(os.path.join(tmpdir, "cs.md"), "w") as mdf: + mdf.write(md) + mock_clone.return_value = repo + entries = secure_headers.SecureHeaders().parse( + cache=self.collection, ph=PromptHandler(database=self.collection) + ) + nodes = entries.results[secure_headers.SecureHeaders().name] + self.assertEqual(2, len(nodes)) + self.assertEqual( + { + "First": "636-347", + "Second": "743-110", + }, + {node.section: node.links[0].document.id for node in nodes}, + ) + + @patch.object(git, "clone") + def test_register_headers_raises_for_unknown_cre_id(self, mock_clone) -> None: + class Repo: + working_dir = "" + + repo = Repo() + loc = tempfile.mkdtemp() + tmpdir = os.path.join(loc, "content") + os.mkdir(tmpdir) + repo.working_dir = loc + md = """See [missing](https://www.opencre.org/cre/999-999?name=Secure+Headers§ion=Missing&link=https%3A%2F%2Fexample.com%2Fmissing) +""" + with open(os.path.join(tmpdir, "cs.md"), "w") as mdf: + mdf.write(md) + mock_clone.return_value = repo + + with self.assertRaises(SecureHeadersLinkError): + secure_headers.SecureHeaders().parse( + cache=self.collection, ph=PromptHandler(database=self.collection) + ) + + @patch.object(git, "clone") + def test_register_headers_keeps_first_link_when_it_is_the_only_valid_one( + self, mock_clone + ) -> None: + class Repo: + working_dir = "" + + repo = Repo() + loc = tempfile.mkdtemp() + tmpdir = os.path.join(loc, "content") + os.mkdir(tmpdir) + repo.working_dir = loc + cre = defs.CRE(name="HTTP security headers", id="636-347") + self.collection.add_cre(cre) + md = """See [first](https://www.opencre.org/cre/636-347?name=Secure+Headers§ion=First&link=https%3A%2F%2Fexample.com%2Ffirst) +""" + with open(os.path.join(tmpdir, "cs.md"), "w") as mdf: + mdf.write(md) + mock_clone.return_value = repo + + entries = secure_headers.SecureHeaders().register_headers( + cache=self.collection, repo=repo, file_path="./", repo_path="" + ) + + self.assertEqual(1, len(entries)) + self.assertEqual("First", entries[0].section) + self.assertEqual("636-347", entries[0].links[0].document.id) + + @patch.object(git, "clone") + def test_register_headers_raises_when_later_link_is_unknown( + self, mock_clone + ) -> None: + class Repo: + working_dir = "" + + repo = Repo() + loc = tempfile.mkdtemp() + tmpdir = os.path.join(loc, "content") + os.mkdir(tmpdir) + repo.working_dir = loc + cre = defs.CRE(name="HTTP security headers", id="636-347") + self.collection.add_cre(cre) + md = """See [first](https://www.opencre.org/cre/636-347?name=Secure+Headers§ion=First&link=https%3A%2F%2Fexample.com%2Ffirst) +and [missing](https://www.opencre.org/cre/999-999?name=Secure+Headers§ion=Missing&link=https%3A%2F%2Fexample.com%2Fmissing) +""" + with open(os.path.join(tmpdir, "cs.md"), "w") as mdf: + mdf.write(md) + mock_clone.return_value = repo + + with self.assertRaises(SecureHeadersLinkError): + secure_headers.SecureHeaders().register_headers( + cache=self.collection, repo=repo, file_path="./", repo_path="" + ) + md = """ # Secure Headers 1. [Introduction](#1-Introduction) diff --git a/application/tests/web_main_test.py b/application/tests/web_main_test.py index e5d285acd..a42e887df 100644 --- a/application/tests/web_main_test.py +++ b/application/tests/web_main_test.py @@ -749,6 +749,22 @@ def test_gap_analysis_heroku_cache_miss_returns_404( self.assertEqual(404, response.status_code) redis_conn_mock.assert_not_called() + @patch.dict(os.environ, {"HEROKU": "True"}, clear=False) + @patch.object(db, "Node_collection") + @patch.object(redis, "from_url") + def test_map_analysis_opencre_heroku_cache_miss_returns_404( + self, redis_conn_mock, db_mock + ) -> None: + db_mock.return_value.gap_analysis_exists.return_value = False + with self.app.test_client() as client: + response = client.get( + "/rest/v1/map_analysis?standard=OpenCRE&standard=NIST%20800-53%20v5", + headers={"Content-Type": "application/json"}, + ) + self.assertEqual(404, response.status_code) + db_mock.return_value.get_nodes.assert_not_called() + redis_conn_mock.assert_not_called() + @patch.object(redis, "from_url") @patch.object(db, "Node_collection") def test_standards_from_db(self, node_mock, redis_conn_mock) -> None: @@ -776,20 +792,12 @@ def test_gap_analysis_supports_opencre_as_standard( compare.add_link( defs.Link(ltype=defs.LinkTypes.LinkedTo, document=shared_cre.shallow_copy()) ) - opencre = defs.CRE(id="170-772", name="Cryptography", description="") - opencre.add_link( - defs.Link(ltype=defs.LinkTypes.LinkedTo, document=compare.shallow_copy()) - ) db_mock.return_value.get_gap_analysis_result.return_value = None db_mock.return_value.gap_analysis_exists.return_value = False db_mock.return_value.get_nodes.side_effect = lambda name=None, **kwargs: ( [compare] if name == "OWASP Web Security Testing Guide (WSTG)" else [] ) - db_mock.return_value.session.query.return_value.all.return_value = [ - SimpleNamespace(id="cre-internal-1") - ] - db_mock.return_value.get_CREs.return_value = [opencre] with self.app.test_client() as client: response = client.get( @@ -800,15 +808,35 @@ def test_gap_analysis_supports_opencre_as_standard( payload = json.loads(response.data) self.assertEqual(200, response.status_code) self.assertIn("result", payload) - self.assertIn(opencre.id, payload["result"]) - self.assertEqual(1, len(payload["result"][opencre.id]["paths"])) - path = next(iter(payload["result"][opencre.id]["paths"].values())) + self.assertIn(shared_cre.id, payload["result"]) + self.assertEqual(1, len(payload["result"][shared_cre.id]["paths"])) + path = next(iter(payload["result"][shared_cre.id]["paths"].values())) self.assertEqual(compare.id, path["end"]["id"]) schedule_mock.assert_not_called() @patch.object(web_main.gap_analysis, "schedule") @patch.object(db, "Node_collection") - def test_gap_analysis_returns_only_direct_opencre_mappings( + def test_map_analysis_opencre_pair_returns_cached_result( + self, db_mock, schedule_mock + ) -> None: + expected = {"result": {"170-772": {"start": {"id": "170-772"}, "paths": {}}}} + db_mock.return_value.gap_analysis_exists.return_value = True + db_mock.return_value.get_gap_analysis_result.return_value = json.dumps(expected) + + with self.app.test_client() as client: + response = client.get( + "/rest/v1/map_analysis?standard=OpenCRE&standard=NIST%20800-53%20v5", + headers={"Content-Type": "application/json"}, + ) + + self.assertEqual(200, response.status_code) + self.assertEqual(expected, json.loads(response.data)) + db_mock.return_value.get_nodes.assert_not_called() + schedule_mock.assert_not_called() + + @patch.object(web_main.gap_analysis, "schedule") + @patch.object(db, "Node_collection") + def test_gap_analysis_opencre_mappings_include_linked_and_auto( self, db_mock, schedule_mock ) -> None: compare = defs.Standard( @@ -821,9 +849,6 @@ def test_gap_analysis_returns_only_direct_opencre_mappings( name="Set httponly attribute for cookie-based session tokens", description="", ) - direct_cre.add_link( - defs.Link(ltype=defs.LinkTypes.LinkedTo, document=compare.shallow_copy()) - ) auto_linked_cres = [] for i, cre_id in enumerate( [ @@ -842,33 +867,24 @@ def test_gap_analysis_returns_only_direct_opencre_mappings( name=f"Automatically mapped CRE {i}", description="", ) - cre.add_link( + auto_linked_cres.append(cre) + + compare.add_link( + defs.Link(ltype=defs.LinkTypes.LinkedTo, document=direct_cre.shallow_copy()) + ) + for cre in auto_linked_cres: + compare.add_link( defs.Link( ltype=defs.LinkTypes.AutomaticallyLinkedTo, - document=compare.shallow_copy(), + document=cre.shallow_copy(), ) ) - auto_linked_cres.append(cre) - - opencre_documents = [direct_cre] + auto_linked_cres - internal_ids = [ - SimpleNamespace(id=f"cre-internal-{i}") - for i in range(len(opencre_documents)) - ] db_mock.return_value.get_gap_analysis_result.return_value = None db_mock.return_value.gap_analysis_exists.return_value = False db_mock.return_value.get_nodes.side_effect = lambda name=None, **kwargs: ( [compare] if name == "CWE" else [] ) - db_mock.return_value.session.query.return_value.all.return_value = internal_ids - db_mock.return_value.get_CREs.side_effect = lambda internal_id=None, **kwargs: [ - next( - cre - for index, cre in enumerate(opencre_documents) - if internal_id == f"cre-internal-{index}" - ) - ] with self.app.test_client() as client: response = client.get( @@ -880,12 +896,17 @@ def test_gap_analysis_returns_only_direct_opencre_mappings( self.assertEqual(200, response.status_code) self.assertIn("result", payload) self.assertEqual([compare.id], list(payload["result"].keys())) - self.assertEqual(1, len(payload["result"][compare.id]["paths"])) - path = next(iter(payload["result"][compare.id]["paths"].values())) + self.assertEqual(8, len(payload["result"][compare.id]["paths"])) + path = payload["result"][compare.id]["paths"][direct_cre.id] self.assertEqual(compare.id, payload["result"][compare.id]["start"]["id"]) self.assertEqual(direct_cre.name, path["end"]["name"]) self.assertEqual("", path["path"][0]["start"]["id"]) self.assertEqual(direct_cre.id, path["path"][0]["end"]["id"]) + self.assertEqual("LINKED_TO", path["path"][0]["relationship"]) + auto_path = payload["result"][compare.id]["paths"][auto_linked_cres[0].id] + self.assertEqual( + "AUTOMATICALLY_LINKED_TO", auto_path["path"][0]["relationship"] + ) schedule_mock.assert_not_called() @patch.object(web_main.gap_analysis, "schedule") @@ -903,40 +924,26 @@ def test_gap_analysis_returns_only_direct_opencre_mappings_when_opencre_is_left( name="Set httponly attribute for cookie-based session tokens", description="", ) - direct_cre.add_link( - defs.Link(ltype=defs.LinkTypes.LinkedTo, document=compare.shallow_copy()) - ) indirect_cre = defs.CRE( id="117-371", name="Use a centralized access control mechanism", description="", ) - indirect_cre.add_link( + compare.add_link( + defs.Link(ltype=defs.LinkTypes.LinkedTo, document=direct_cre.shallow_copy()) + ) + compare.add_link( defs.Link( ltype=defs.LinkTypes.AutomaticallyLinkedTo, - document=compare.shallow_copy(), + document=indirect_cre.shallow_copy(), ) ) - opencre_documents = [direct_cre, indirect_cre] - internal_ids = [ - SimpleNamespace(id=f"cre-internal-{i}") - for i in range(len(opencre_documents)) - ] - db_mock.return_value.get_gap_analysis_result.return_value = None db_mock.return_value.gap_analysis_exists.return_value = False db_mock.return_value.get_nodes.side_effect = lambda name=None, **kwargs: ( [compare] if name == "CWE" else [] ) - db_mock.return_value.session.query.return_value.all.return_value = internal_ids - db_mock.return_value.get_CREs.side_effect = lambda internal_id=None, **kwargs: [ - next( - cre - for index, cre in enumerate(opencre_documents) - if internal_id == f"cre-internal-{index}" - ) - ] with self.app.test_client() as client: response = client.get( @@ -946,13 +953,19 @@ def test_gap_analysis_returns_only_direct_opencre_mappings_when_opencre_is_left( payload = json.loads(response.data) self.assertEqual(200, response.status_code) - self.assertEqual([direct_cre.id], list(payload["result"].keys())) + self.assertEqual( + sorted([direct_cre.id, indirect_cre.id]), + sorted(payload["result"].keys()), + ) self.assertEqual(1, len(payload["result"][direct_cre.id]["paths"])) path = next(iter(payload["result"][direct_cre.id]["paths"].values())) self.assertEqual(direct_cre.id, payload["result"][direct_cre.id]["start"]["id"]) self.assertEqual(compare.id, path["end"]["id"]) - self.assertEqual(direct_cre.id, path["path"][0]["start"]["id"]) - self.assertEqual(compare.id, path["path"][0]["end"]["id"]) + self.assertEqual("LINKED_TO", path["path"][0]["relationship"]) + auto_path = next(iter(payload["result"][indirect_cre.id]["paths"].values())) + self.assertEqual( + "AUTOMATICALLY_LINKED_TO", auto_path["path"][0]["relationship"] + ) schedule_mock.assert_not_called() @patch.object(cre_main, "resource_name_ga_eligible_in_db") diff --git a/application/utils/external_project_parsers/parsers/pci_dss.py b/application/utils/external_project_parsers/parsers/pci_dss.py index 51f681692..33884436e 100644 --- a/application/utils/external_project_parsers/parsers/pci_dss.py +++ b/application/utils/external_project_parsers/parsers/pci_dss.py @@ -1,7 +1,7 @@ from pprint import pprint import logging import os -from typing import Dict, Any +from typing import Dict, Any, List, Optional from application.database import db from application.defs import cre_defs as defs import re @@ -17,6 +17,154 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +_DEFAULT_PCI_DSS_CRE_SIMILARITY_THRESHOLDS = (0.55, 0.45, 0.35) +_DEFAULT_PCI_BRIDGE_STANDARDS = ("NIST 800-53 v5", "ISO 27001", "ASVS", "CWE") +_DEFAULT_PCI_BRIDGE_MIN_SIMILARITY = 0.4 + + +def _parse_float_env(name: str, default: float) -> float: + raw = os.environ.get(name, "").strip() + if not raw: + return default + try: + return float(raw) + except ValueError: + logger.warning("Invalid %s=%r; using default %s", name, raw, default) + return default + + +def _parse_float_tuple_env(name: str, default: tuple[float, ...]) -> tuple[float, ...]: + raw = os.environ.get(name, "").strip() + if not raw: + return default + try: + values = tuple(float(part.strip()) for part in raw.split(",") if part.strip()) + except ValueError: + logger.warning("Invalid %s=%r; using defaults %s", name, raw, default) + return default + return values or default + + +def _parse_str_tuple_env(name: str, default: tuple[str, ...]) -> tuple[str, ...]: + raw = os.environ.get(name, "").strip() + if not raw: + return default + values = tuple(part.strip() for part in raw.split(",") if part.strip()) + return values or default + + +PCI_DSS_CRE_SIMILARITY_THRESHOLDS = _parse_float_tuple_env( + "PCI_DSS_CRE_SIMILARITY_THRESHOLDS", _DEFAULT_PCI_DSS_CRE_SIMILARITY_THRESHOLDS +) +PCI_BRIDGE_STANDARDS = _parse_str_tuple_env( + "PCI_DSS_BRIDGE_STANDARDS", _DEFAULT_PCI_BRIDGE_STANDARDS +) +PCI_BRIDGE_MIN_SIMILARITY = _parse_float_env( + "PCI_DSS_BRIDGE_MIN_SIMILARITY", _DEFAULT_PCI_BRIDGE_MIN_SIMILARITY +) + + +class PciDssLinkError(Exception): + """Raised when one or more PCI DSS controls cannot be linked to a CRE.""" + + +def pci_control_embedding_text(control: defs.Standard) -> str: + """Text used for PCI→CRE similarity (avoid full Standard repr JSON noise).""" + return "\n".join( + part.strip() + for part in (control.sectionID, control.section, control.description) + if part and str(part).strip() + ) + + +def best_cre_via_bridge_standard( + cache: db.Node_collection, + control_embedding: List[float], + standard_name: str, + *, + min_similarity: float = PCI_BRIDGE_MIN_SIMILARITY, +) -> Optional[defs.CRE]: + """Pick the best CRE linked to ``standard_name`` by node embedding similarity.""" + import numpy as np + from scipy import sparse + from sklearn.metrics.pairwise import cosine_similarity + + if not control_embedding: + return None + + embedding_array = sparse.csr_matrix( + np.array(control_embedding, dtype=np.float64).reshape(1, -1) + ) + best_similarity = -1.0 + best_cre: Optional[defs.CRE] = None + + for node in cache.get_nodes(name=standard_name) or []: + node_embedding = cache.get_embeddings_for_doc(node) + if not node_embedding: + continue + node_array = sparse.csr_matrix( + np.array(node_embedding, dtype=np.float64).reshape(1, -1) + ) + similarity = float(cosine_similarity(embedding_array, node_array)[0][0]) + if similarity < min_similarity or similarity <= best_similarity: + continue + linked_cres = cache.find_cres_of_node(node) + if not linked_cres: + continue + cre = cache.get_cre_by_db_id(linked_cres[0].id) + if cre: + best_similarity = similarity + best_cre = cre + + if best_cre: + logger.info( + "PCI DSS bridge match via %s (similarity %.3f)", + standard_name, + best_similarity, + ) + return best_cre + + +def resolve_cre_for_pci_control( + prompt: prompt_client.PromptHandler, + cache: db.Node_collection, + control_embedding: List[float], +) -> Optional[defs.CRE]: + """Resolve a CRE for one PCI control using staged similarity + bridge fallbacks.""" + for threshold in PCI_DSS_CRE_SIMILARITY_THRESHOLDS: + match = prompt.get_id_of_most_similar_cre_paginated( + control_embedding, similarity_threshold=threshold + ) + if match and match[0]: + cre = cache.get_cre_by_db_id(match[0]) + if cre: + logger.info( + "PCI DSS CRE similarity match %.3f (threshold %s)", + match[1], + threshold, + ) + return cre + + for standard_name in PCI_BRIDGE_STANDARDS: + cre = best_cre_via_bridge_standard(cache, control_embedding, standard_name) + if cre: + return cre + + standard_id = prompt.get_id_of_most_similar_node(control_embedding) + if standard_id: + nodes = cache.get_nodes(db_id=standard_id) + if nodes: + linked_cres = cache.find_cres_of_node(nodes[0]) + if linked_cres: + cre = cache.get_cre_by_db_id(linked_cres[0].id) + if cre: + logger.info( + "PCI DSS linked via global standard fallback (%s)", + nodes[0].name, + ) + return cre + return None + class PciDss(ParserInterface): name = "PCI DSS" @@ -70,6 +218,7 @@ def __parse( prompt = prompt_client.PromptHandler(cache) self._ensure_similarity_prereqs(cache, prompt) standard_entries = [] + unlinked_controls: list[str] = [] for row in pci_file.get(pci_file_tab): pci_control = defs.Standard( name=self.name, @@ -117,56 +266,40 @@ def __parse( f"Node {pci_control.todict()} already exists and has embeddings, skipping" ) - control_embeddings = prompt.get_text_embeddings(pci_control.__repr__()) + control_embeddings = prompt.get_text_embeddings( + pci_control_embedding_text(pci_control) + ) pci_control.embeddings = control_embeddings - pci_control.embeddings_text = pci_control.__repr__() - # these embeddings are different to the ones generated from --generate embeddings, this is because we want these embedding to include the optional "description" field, it is not a big difference and cosine similarity works reasonably accurately without it but good to have - cre = None - cre_id = prompt.get_id_of_most_similar_cre(control_embeddings) - if not cre_id: - logger.info( - f"could not find an appropriate CRE for pci {pci_control.section}, findings similarities with standards instead" - ) - standard_id = prompt.get_id_of_most_similar_node(control_embeddings) - if standard_id: - dbstandard = cache.get_nodes(db_id=standard_id) - if dbstandard: - logger.info( - "found an appropriate standard for pci %s, it is: %s", - pci_control.section, - dbstandard.section, - ) - cres = cache.find_cres_of_node(dbstandard) - if cres: - cre_id = cres[0].id - else: - logger.info( - "no standard record found for fallback standard id %s (pci section %s)", - standard_id, - pci_control.section, - ) - else: - logger.info( - "could not find a similar standard for pci %s; skipping fallback link", - pci_control.section, - ) - if cre_id: - cre = cache.get_cre_by_db_id(cre_id) - ctrl_copy = pci_control.shallow_copy() + pci_control.embeddings_text = pci_control_embedding_text(pci_control) + cre = resolve_cre_for_pci_control(prompt, cache, control_embeddings) pci_control.description = "" if cre: pci_control.add_link( defs.Link(document=cre, ltype=defs.LinkTypes.AutomaticallyLinkedTo) ) - pci_control.add_link( - defs.Link(ltype=defs.LinkTypes.AutomaticallyLinkedTo, document=cre) - ) logger.info(f"successfully stored {pci_control.__repr__()}") else: - logger.info( - f"stored pci control: {pci_control.__repr__()} but could not link it to any CRE reliably" + unlinked_controls.append( + f"{pci_control.sectionID}: {pci_control.section}" + ) + logger.error( + "PCI DSS control %s (%s) could not be linked to any CRE", + pci_control.sectionID, + pci_control.section, ) standard_entries.append(pci_control) + if unlinked_controls: + sample = unlinked_controls[:5] + extra = ( + f" (and {len(unlinked_controls) - len(sample)} more)" + if len(unlinked_controls) > len(sample) + else "" + ) + raise PciDssLinkError( + "PCI DSS import requires every control to link to a CRE; " + f"{len(unlinked_controls)} control(s) failed: " + f"{'; '.join(sample)}{extra}" + ) return standard_entries def parse_3_2(self, pci_file: Dict[str, Any], cache: db.Node_collection): diff --git a/application/utils/external_project_parsers/parsers/secure_headers.py b/application/utils/external_project_parsers/parsers/secure_headers.py index 384ef103e..4148f3943 100644 --- a/application/utils/external_project_parsers/parsers/secure_headers.py +++ b/application/utils/external_project_parsers/parsers/secure_headers.py @@ -1,12 +1,13 @@ # script to parse secure headers md files find the links to opencre.org and add the page to CRE -from pprint import pprint from typing import List -from application.database import db -from application.utils import git -from application.defs import cre_defs as defs +import logging import os import re from urllib.parse import urlparse, parse_qs + +from application.database import db +from application.defs import cre_defs as defs +from application.utils import git from application.utils.external_project_parsers import base_parser_defs from application.utils.external_project_parsers.base_parser_defs import ( ParserInterface, @@ -14,7 +15,21 @@ ) from application.prompt_client import prompt_client as prompt_client -# GENERIC Markdown file parser for self-contained links! when we have more projects using this setup add them in the list +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# GENERIC Markdown file parser + +# OWASP markdown may reference retired CRE ids; map to current OpenCRE ids. +LEGACY_CRE_ID_REMAP = { + # tab_bestpractices.md still links 403-005; corpus uses 743-110 for this topic. + "403-005": "743-110", +} + + +class SecureHeadersLinkError(Exception): + """Raised when a Secure Headers markdown CRE reference cannot be resolved.""" class SecureHeaders(ParserInterface): @@ -35,6 +50,28 @@ def entry(self, section: str, hyperlink: str, tags: List[str]) -> defs.Standard: hyperlink=hyperlink, ) + def resolve_cre_external_id( + self, cache: db.Node_collection, external_id: str + ) -> tuple[list[defs.CRE], str]: + candidates = [external_id] + remapped = LEGACY_CRE_ID_REMAP.get(external_id) + if remapped and remapped not in candidates: + candidates.append(remapped) + for candidate in candidates: + cres = cache.get_CREs(external_id=candidate) + if cres: + if candidate != external_id: + logger.info( + "Secure Headers remapped stale CRE id %s -> %s", + external_id, + candidate, + ) + return cres, candidate + raise SecureHeadersLinkError( + f"Secure Headers markdown references unknown CRE id {external_id!r}" + + (f" (also tried remap {remapped!r})" if remapped else "") + ) + def parse(self, cache: db.Node_collection, ph: prompt_client.PromptHandler): sh_repo = "https://github.com/owasp/www-project-secure-headers.git" file_path = "./" @@ -56,36 +93,40 @@ def register_headers(self, cache: db.Node_collection, repo, file_path, repo_path entries = [] for path, _, files in os.walk(repo.working_dir): for mdfile in files: + if not mdfile.endswith(".md"): + continue pth = os.path.join(path, mdfile) if not os.path.isfile(pth): continue - with open(pth) as mdf: - mdtext = mdf.read() + try: + with open(pth, encoding="utf-8") as mdf: + mdtext = mdf.read() + except UnicodeDecodeError: + logger.warning("Skipping non-UTF-8 markdown file: %s", pth) + continue - if "opencre.org" not in mdtext: - continue - links = re.finditer(cre_link, mdtext, re.MULTILINE) - for cre in links: - if cre: - parsed = urlparse(cre.group("url")) - creID = cre.group("creID") - queries = parse_qs(parsed.query) - name = queries.get("name") - section = queries.get("section") - link = queries.get("link") - cres = cache.get_CREs(external_id=creID) - cs = self.entry( - section=section[0] if section else "", - hyperlink=link[0] if link else "", - tags=[], + if "opencre.org" not in mdtext: + continue + links = re.finditer(cre_link, mdtext, re.MULTILINE) + for cre in links: + parsed = urlparse(cre.group("url")) + creID = cre.group("creID") + queries = parse_qs(parsed.query) + section = queries.get("section") + link = queries.get("link") + cres, _resolved_id = self.resolve_cre_external_id(cache, creID) + cs = self.entry( + section=section[0] if section else "", + hyperlink=link[0] if link else "", + tags=[], + ) + for dbcre in cres: + cs.add_link( + defs.Link( + document=dbcre, + ltype=defs.LinkTypes.AutomaticallyLinkedTo, ) - for dbcre in cres: - cs.add_link( - defs.Link( - document=dbcre, - ltype=defs.LinkTypes.AutomaticallyLinkedTo, - ) - ) + ) entries.append(cs) return entries diff --git a/application/utils/gap_analysis.py b/application/utils/gap_analysis.py index 8a320a462..da76390c4 100644 --- a/application/utils/gap_analysis.py +++ b/application/utils/gap_analysis.py @@ -22,6 +22,11 @@ } GAP_ANALYSIS_TIMEOUT = "129600s" # 36 hours +OPENCRE_STANDARD_NAME = "OpenCRE" +OPENCRE_OVERLAP_LINK_TYPES = ( + defs.LinkTypes.LinkedTo, + defs.LinkTypes.AutomaticallyLinkedTo, +) def make_resources_key(array: List[str]): @@ -86,6 +91,168 @@ def get_next_id(step, previous_id): return step["start"].id +def _link_type_to_path_relationship(ltype: defs.LinkTypes) -> str: + if ltype == defs.LinkTypes.AutomaticallyLinkedTo: + return "AUTOMATICALLY_LINKED_TO" + return "LINKED_TO" + + +def _opencre_overlap_link_sort_key(link: defs.Link) -> int: + if link.ltype == defs.LinkTypes.LinkedTo: + return 0 + if link.ltype == defs.LinkTypes.AutomaticallyLinkedTo: + return 1 + return 2 + + +def _build_direct_link_path( + start_document: defs.Document, + end_document: defs.Document, + *, + ltype: defs.LinkTypes = defs.LinkTypes.LinkedTo, +) -> Dict[str, Any]: + segment_start = start_document.shallow_copy() + if segment_start.doctype != defs.Credoctypes.CRE.value: + segment_start.id = "" + return { + "end": end_document.shallow_copy(), + "path": [ + { + "start": segment_start, + "end": end_document.shallow_copy(), + "relationship": _link_type_to_path_relationship(ltype), + "score": 0, + } + ], + "score": 0, + } + + +def _add_direct_link_result( + grouped_paths: Dict[str, Dict[str, Any]], + start_document: defs.Document, + end_document: defs.Document, + *, + ltype: defs.LinkTypes = defs.LinkTypes.LinkedTo, +) -> None: + shared_paths = grouped_paths.setdefault( + start_document.id, + { + "start": start_document.shallow_copy(), + "paths": {}, + "extra": 0, + }, + )["paths"] + path_key = end_document.id + if path_key in shared_paths: + return + shared_paths[path_key] = _build_direct_link_path( + start_document, end_document, ltype=ltype + ) + + +def build_direct_cre_overlap_map_analysis( + standards: List[str], + standards_hash: str, + collection: Any, +) -> Optional[Dict[str, Any]]: + """Compute one-step OpenCRE links (manual and automatic) for a standard pair.""" + if len(standards) < 2: + return None + + base_standard = standards[0] + compare_standard = standards[1] + base_is_opencre = base_standard == OPENCRE_STANDARD_NAME + compare_is_opencre = compare_standard == OPENCRE_STANDARD_NAME + if not base_is_opencre and not compare_is_opencre: + return None + + standard_name = compare_standard if base_is_opencre else base_standard + standard_nodes = collection.get_nodes(name=standard_name) + if not standard_nodes: + return None + + grouped_paths: Dict[str, Dict[str, Any]] = {} + for standard_node in standard_nodes: + cre_links = [ + link + for link in (standard_node.links or []) + if link.ltype in OPENCRE_OVERLAP_LINK_TYPES + and link.document.doctype == defs.Credoctypes.CRE.value + ] + for link in sorted(cre_links, key=_opencre_overlap_link_sort_key): + linked_document = link.document + if base_is_opencre: + _add_direct_link_result( + grouped_paths, + linked_document, + standard_node, + ltype=link.ltype, + ) + else: + _add_direct_link_result( + grouped_paths, + standard_node, + linked_document, + ltype=link.ltype, + ) + + if not grouped_paths: + return None + + result = {"result": grouped_paths} + collection.add_gap_analysis_result( + cache_key=standards_hash, ga_object=flask_json.dumps(result) + ) + return result + + +def opencre_direct_pairs(standard_names: List[str]) -> List[List[str]]: + """Directed OpenCRE pairs for every real standard name.""" + pairs: List[List[str]] = [] + for name in sorted({str(s).strip() for s in standard_names if str(s).strip()}): + if name == OPENCRE_STANDARD_NAME: + continue + pairs.append([OPENCRE_STANDARD_NAME, name]) + pairs.append([name, OPENCRE_STANDARD_NAME]) + return pairs + + +def missing_opencre_direct_pairs(collection: Any) -> List[List[str]]: + missing: List[List[str]] = [] + for pair in opencre_direct_pairs(collection.standards()): + cache_key = make_resources_key(pair) + if not collection.gap_analysis_exists(cache_key): + missing.append(pair) + return missing + + +def backfill_opencre_direct_pairs(collection: Any, *, refresh: bool = False) -> int: + """Populate SQL cache rows for OpenCRE map analysis pairs (manual + automatic links).""" + pairs = opencre_direct_pairs(collection.standards()) + if refresh: + todo = pairs + logger.info("OpenCRE direct GA backfill: refreshing all pairs=%s", len(todo)) + else: + todo = missing_opencre_direct_pairs(collection) + if not todo: + logger.info("OpenCRE direct GA backfill: no missing pairs") + return 0 + logger.info("OpenCRE direct GA backfill: missing_pairs=%s", len(todo)) + + written = 0 + for pair in todo: + cache_key = make_resources_key(pair) + if build_direct_cre_overlap_map_analysis(pair, cache_key, collection): + written += 1 + logger.info( + "OpenCRE direct GA backfill: wrote=%s remaining=%s", + written, + len(missing_opencre_direct_pairs(collection)), + ) + return written + + def perform(standards: List[str], database): return run_gap_pair(standards[0], standards[1], database) diff --git a/application/web/web_main.py b/application/web/web_main.py index 4049f8981..32f77679a 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -48,7 +48,7 @@ ITEMS_PER_PAGE = 20 -OPENCRE_STANDARD_NAME = "OpenCRE" +OPENCRE_STANDARD_NAME = gap_analysis.OPENCRE_STANDARD_NAME app = Blueprint( "web", @@ -298,116 +298,6 @@ def find_document_by_tag() -> Any: abort(404, "Tag does not exist") -def _get_opencre_documents(collection: db.Node_collection) -> list[defs.CRE]: - return [ - collection.get_CREs(internal_id=cre.id)[0] - for cre in collection.session.query(db.CRE).all() - ] - - -def _get_map_analysis_documents( - standard: str, collection: db.Node_collection -) -> list[defs.Document]: - if standard == OPENCRE_STANDARD_NAME: - return _get_opencre_documents(collection) - return collection.get_nodes(name=standard) - - -def _build_direct_link_path( - start_document: defs.Document, end_document: defs.Document -) -> dict[str, Any]: - segment_start = start_document.shallow_copy() - # The current gap-analysis popup mutates non-CRE row ids during display - # before it resolves the one-step direct path. Keep this direct-link fast - # path compatible by mirroring that display-only shape in the segment start. - if segment_start.doctype != defs.Credoctypes.CRE.value: - segment_start.id = "" - return { - "end": end_document.shallow_copy(), - "path": [ - { - "start": segment_start, - "end": end_document.shallow_copy(), - "relationship": "LINKED_TO", - "score": 0, - } - ], - "score": 0, - } - - -def _make_direct_link_path_key(end_document: defs.Document) -> str: - return end_document.id - - -def _add_direct_link_result( - grouped_paths: dict[str, dict[str, Any]], - start_document: defs.Document, - end_document: defs.Document, -) -> None: - shared_paths = grouped_paths.setdefault( - start_document.id, - { - "start": start_document.shallow_copy(), - "paths": {}, - "extra": 0, - }, - )["paths"] - shared_paths.setdefault( - _make_direct_link_path_key(end_document), - _build_direct_link_path(start_document, end_document), - ) - - -def _build_direct_cre_overlap_map_analysis( - standards: list[str], - standards_hash: str, - collection: db.Node_collection, -) -> dict[str, Any] | None: - if len(standards) < 2: - return None - - base_standard = standards[0] - compare_standard = standards[1] - base_nodes = _get_map_analysis_documents(base_standard, collection) - compare_nodes = _get_map_analysis_documents(compare_standard, collection) - if not base_nodes or not compare_nodes: - return None - - base_is_opencre = base_standard == OPENCRE_STANDARD_NAME - opencre_nodes = base_nodes if base_is_opencre else compare_nodes - standard_nodes = compare_nodes if base_is_opencre else base_nodes - - standard_nodes_by_id = { - standard_node.id: standard_node for standard_node in standard_nodes - } - direct_pairs: list[tuple[defs.CRE, defs.Document]] = [] - for opencre_node in opencre_nodes: - for link in opencre_node.links: - if link.ltype != defs.LinkTypes.LinkedTo: - continue - standard_node = standard_nodes_by_id.get(link.document.id) - if not standard_node: - continue - direct_pairs.append((opencre_node, standard_node)) - - grouped_paths: dict[str, dict[str, Any]] = {} - for opencre_node, standard_node in direct_pairs: - if base_is_opencre: - _add_direct_link_result(grouped_paths, opencre_node, standard_node) - else: - _add_direct_link_result(grouped_paths, standard_node, opencre_node) - - if not grouped_paths: - return None - - result = {"result": grouped_paths} - collection.add_gap_analysis_result( - cache_key=standards_hash, ga_object=flask_json.dumps(result) - ) - return result - - @app.route("/rest/v1/map_analysis", methods=["GET"]) def map_analysis() -> Any: standards = request.args.getlist("standard") @@ -420,9 +310,17 @@ def map_analysis() -> Any: standards = standards[:2] standards_hash = gap_analysis.make_resources_key(standards) - # ----- PR #825: OpenCRE fast path ----- + # ----- PR #825: OpenCRE fast path (SQL cache only on Heroku) ----- if OPENCRE_STANDARD_NAME in standards: - direct_gap_analysis = _build_direct_cre_overlap_map_analysis( + if database.gap_analysis_exists(standards_hash): + cached = database.get_gap_analysis_result(cache_key=standards_hash) + if cached: + parsed = json.loads(cached) + if "result" in parsed: + return jsonify({"result": parsed.get("result")}) + if os.environ.get("HEROKU"): + abort(404, "No such Cache") + direct_gap_analysis = gap_analysis.build_direct_cre_overlap_map_analysis( standards, standards_hash, database ) if direct_gap_analysis: diff --git a/cre.py b/cre.py index 80dd48617..559e595e4 100644 --- a/cre.py +++ b/cre.py @@ -224,6 +224,11 @@ def main() -> None: default="", help="preload map analysis for all possible 2 standards combinations, use target url as an OpenCRE base", ) + parser.add_argument( + "--ga_backfill_opencre_direct", + action="store_true", + help="refresh OpenCRE map-analysis cache rows (manual + automatic CRE links)", + ) parser.add_argument( "--ga_backfill_missing", action="store_true", diff --git a/scripts/compute_pci_dss_cre_mappings.py b/scripts/compute_pci_dss_cre_mappings.py new file mode 100644 index 000000000..eb8f5e7a8 --- /dev/null +++ b/scripts/compute_pci_dss_cre_mappings.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +""" +Compute PCI DSS v4 control → CRE mappings using Gemini embeddings + staged similarity. + +Reads the public PCI DSS spreadsheet CSV, embeds each control, and resolves CRE links +using the same logic as application/utils/external_project_parsers/parsers/pci_dss.py. + +Usage: + python scripts/compute_pci_dss_cre_mappings.py \\ + --cache-file standards_cache.sqlite \\ + --output data/pci_dss_cre_mappings.json +""" + +from __future__ import annotations + +import argparse +import csv +import io +import json +import logging +import os +import sys +import time +import urllib.request +from typing import Any, Dict, List, Optional + +try: + from dotenv import load_dotenv + + load_dotenv() +except ImportError: + pass + +# Repo root on sys.path when invoked as a script. +_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from application.cmd.cre_main import db_connect # noqa: E402 +from application.defs import cre_defs as defs # noqa: E402 +from application.prompt_client import prompt_client # noqa: E402 +from application.utils.external_project_parsers.parsers.pci_dss import ( # noqa: E402 + PCI_BRIDGE_STANDARDS, + PCI_DSS_CRE_SIMILARITY_THRESHOLDS, + best_cre_via_bridge_standard, + pci_control_embedding_text, +) + +PCI_SHEET_CSV_URL = ( + "https://docs.google.com/spreadsheets/d/" + "18weo-qbik_C7SdYq7FSP2OMgUmsWdWWI1eaXcAfMz8I/export?format=csv" +) + +logger = logging.getLogger(__name__) + + +def _configure_llm_env() -> None: + embed_model = os.environ.get("CRE_EMBED_MODEL") + if not embed_model: + vertex_embed = os.environ.get( + "VERTEX_EMBED_CONTENT_MODEL", "gemini-embedding-001" + ) + os.environ["CRE_EMBED_MODEL"] = f"gemini/{vertex_embed}" + os.environ.setdefault("CRE_EMBED_EXPECTED_DIM", "3072") + os.environ.setdefault("CRE_VALIDATE_EMBED_DIM_ON_INIT", "0") + + +def fetch_pci_rows(url: str = PCI_SHEET_CSV_URL) -> List[Dict[str, str]]: + with urllib.request.urlopen(url, timeout=120) as resp: + raw = resp.read().decode("utf-8-sig") + reader = csv.DictReader(io.StringIO(raw)) + rows = [row for row in reader if (row.get("PCI DSS ID") or "").strip()] + if not rows: + raise RuntimeError(f"no PCI rows found at {url}") + return rows + + +def resolve_with_method( + prompt: prompt_client.PromptHandler, + cache, + control_embedding: List[float], +) -> tuple[Optional[defs.CRE], str, Optional[float]]: + for threshold in PCI_DSS_CRE_SIMILARITY_THRESHOLDS: + match = prompt.get_id_of_most_similar_cre_paginated( + control_embedding, similarity_threshold=threshold + ) + if match and match[0]: + cre = cache.get_cre_by_db_id(match[0]) + if cre: + return cre, f"cre_similarity>={threshold}", float(match[1]) + + for standard_name in PCI_BRIDGE_STANDARDS: + cre = best_cre_via_bridge_standard(cache, control_embedding, standard_name) + if cre: + return cre, f"bridge:{standard_name}", None + + standard_id = prompt.get_id_of_most_similar_node(control_embedding) + if standard_id: + nodes = cache.get_nodes(db_id=standard_id) + if nodes: + linked_cres = cache.find_cres_of_node(nodes[0]) + if linked_cres: + cre = cache.get_cre_by_db_id(linked_cres[0].id) + if cre: + return cre, f"global_standard:{nodes[0].name}", None + return None, "unlinked", None + + +def compute_mappings( + cache, + rows: List[Dict[str, str]], + *, + limit: Optional[int] = None, +) -> List[Dict[str, Any]]: + prompt = prompt_client.PromptHandler(cache) + mappings: List[Dict[str, Any]] = [] + total = len(rows) if limit is None else min(limit, len(rows)) + + for index, row in enumerate(rows[:total], start=1): + section_id = str(row.get("PCI DSS ID", "")).strip() + section = str(row.get("Defined Approach Requirements", "")).strip() + description = str( + row.get("Requirement Description", "") or row.get("Guidance", "") + ).strip() + control = defs.Standard( + name="PCI DSS", + sectionID=section_id, + section=section, + description=description, + version="4", + ) + if control.section.startswith(control.sectionID): + control.section = control.section[len(control.sectionID) :].strip() + + embedding_text = pci_control_embedding_text(control) + t0 = time.time() + embedding = prompt.get_text_embeddings(embedding_text) + cre, method, similarity = resolve_with_method(prompt, cache, embedding) + elapsed = time.time() - t0 + + entry: Dict[str, Any] = { + "pci_dss_id": section_id, + "section": control.section, + "cre_id": cre.id if cre else None, + "cre_name": cre.name if cre else None, + "method": method, + "similarity": similarity, + "elapsed_seconds": round(elapsed, 2), + } + mappings.append(entry) + status = cre.id if cre else "UNLINKED" + logger.info("[%s/%s] %s -> %s (%s)", index, total, section_id, status, method) + + return mappings + + +def main() -> int: + parser = argparse.ArgumentParser(description="Compute PCI DSS → CRE mappings") + parser.add_argument( + "--cache-file", + default=os.environ.get( + "CRE_CACHE_FILE", os.path.join(_REPO_ROOT, "standards_cache.sqlite") + ), + ) + parser.add_argument( + "--output", + default=os.path.join(_REPO_ROOT, "data", "pci_dss_cre_mappings.json"), + ) + parser.add_argument("--sheet-url", default=PCI_SHEET_CSV_URL) + parser.add_argument( + "--limit", type=int, default=None, help="process only first N controls" + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s") + _configure_llm_env() + + rows = fetch_pci_rows(args.sheet_url) + logger.info("loaded %s PCI DSS controls from spreadsheet", len(rows)) + + cache = db_connect(path=args.cache_file) + mappings = compute_mappings(cache, rows, limit=args.limit) + + linked = [m for m in mappings if m["cre_id"]] + unlinked = [m for m in mappings if not m["cre_id"]] + summary = { + "total": len(mappings), + "linked": len(linked), + "unlinked": len(unlinked), + "unlinked_ids": [m["pci_dss_id"] for m in unlinked], + "thresholds": list(PCI_DSS_CRE_SIMILARITY_THRESHOLDS), + "bridge_standards": list(PCI_BRIDGE_STANDARDS), + "embed_model": os.environ.get("CRE_EMBED_MODEL"), + } + payload = {"summary": summary, "mappings": mappings} + + os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) + with open(args.output, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2) + handle.write("\n") + + logger.info( + "wrote %s mappings to %s (%s linked, %s unlinked)", + len(mappings), + args.output, + len(linked), + len(unlinked), + ) + if unlinked: + logger.error("unlinked controls: %s", ", ".join(summary["unlinked_ids"][:10])) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/sync_gap_analysis_table.py b/scripts/sync_gap_analysis_table.py new file mode 100644 index 000000000..5f461fd95 --- /dev/null +++ b/scripts/sync_gap_analysis_table.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +"""Copy ``gap_analysis_results`` rows between databases.""" + +from __future__ import annotations + +import argparse +import sqlite3 +import sys +import urllib.parse +from typing import List, Optional, Sequence, Tuple + +import psycopg2 +from psycopg2 import extras + + +def _normalize_pg_url(url: str) -> str: + if url.startswith("postgres://"): + return "postgresql://" + url[len("postgres://") :] + return url + + +def _pg_host_is_loopback(url: str) -> bool: + p = urllib.parse.urlparse(_normalize_pg_url(url)) + h = (p.hostname or "").lower() + return h in ("127.0.0.1", "localhost", "::1") or h == "" + + +def _fetch_sqlite_rows(path: str) -> List[Tuple[str, Optional[str]]]: + conn = sqlite3.connect(path) + cur = conn.execute("SELECT cache_key, ga_object FROM gap_analysis_results") + rows = [(str(k), None if v is None else str(v)) for k, v in cur.fetchall()] + conn.close() + return rows + + +def _replace_postgres_rows( + pg_url: str, rows: Sequence[Tuple[str, Optional[str]]] +) -> None: + conn = psycopg2.connect(_normalize_pg_url(pg_url)) + conn.autocommit = False + try: + cur = conn.cursor() + cur.execute("DELETE FROM public.gap_analysis_results") + if rows: + extras.execute_batch( + cur, + "INSERT INTO public.gap_analysis_results (cache_key, ga_object) VALUES (%s, %s)", + list(rows), + page_size=200, + ) + conn.commit() + cur.close() + finally: + conn.close() + + +def main() -> int: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--from-sqlite", required=True, metavar="PATH") + p.add_argument("--to-postgres", required=True, metavar="URL") + p.add_argument("--require-local-destination", action="store_true") + p.add_argument("--allow-nonloopback-destination", action="store_true") + args = p.parse_args() + + if args.require_local_destination and not _pg_host_is_loopback(args.to_postgres): + print("error: destination is not loopback", file=sys.stderr) + return 2 + if ( + not _pg_host_is_loopback(args.to_postgres) + and not args.allow_nonloopback_destination + ): + print( + "error: remote destination requires --allow-nonloopback-destination", + file=sys.stderr, + ) + return 2 + + rows = _fetch_sqlite_rows(args.from_sqlite) + print(f"read {len(rows)} row(s) from {args.from_sqlite!r}") + _replace_postgres_rows(args.to_postgres, rows) + print(f"wrote {len(rows)} row(s) to postgres") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())