Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,9 @@ backfill-gap-analysis-sync:
python cre.py --cache_file "$$CRE_CACHE_FILE" --populate_neo4j_db && \
python cre.py --cache_file "$$CRE_CACHE_FILE" --ga_backfill_missing --ga_backfill_no_queue

backfill-opencre-ga:
@[ -d "./.venv" ] && . ./.venv/bin/activate || ([ -d "./venv" ] && . ./venv/bin/activate); \
export FLASK_APP="$(CURDIR)/cre.py"; \
python cre.py --cache_file "$${CRE_CACHE_FILE:-$(CURDIR)/standards_cache.sqlite}" --ga_backfill_opencre_direct

all: clean lint test dev dev-run
7 changes: 6 additions & 1 deletion application/cmd/cre_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,12 +730,14 @@ def backfill_gap_analysis_only(
if os.environ.get("CRE_NO_NEO4J") != "1":
populate_neo4j_db(db_connection_str)

gap_analysis.backfill_opencre_direct_pairs(collection, refresh=True)

missing = _missing_ga_pairs(collection)
if max_pairs > 0:
missing = missing[:max_pairs]
total = len(missing)
if total == 0:
logger.info("GA backfill: no missing pairs")
logger.info("GA backfill: no missing neo4j pairs")
return

logger.info(
Expand Down Expand Up @@ -952,6 +954,9 @@ def run(args: argparse.Namespace) -> None: # pragma: no cover

if args.preload_map_analysis_target_url:
gap_analysis.preload(target_url=args.preload_map_analysis_target_url)
if getattr(args, "ga_backfill_opencre_direct", False):
collection = db_connect(path=args.cache_file)
gap_analysis.backfill_opencre_direct_pairs(collection, refresh=True)
if getattr(args, "ga_backfill_missing", False):
backfill_gap_analysis_only(
args.cache_file,
Expand Down
89 changes: 89 additions & 0 deletions application/tests/opencre_gap_analysis_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json
import unittest
from unittest.mock import Mock, patch

from application import create_app, sqla # type: ignore
from application.database import db
from application.defs import cre_defs as defs
from application.utils.gap_analysis import (
OPENCRE_STANDARD_NAME,
backfill_opencre_direct_pairs,
make_resources_key,
)


class TestOpencreGapAnalysis(unittest.TestCase):
def tearDown(self) -> None:
sqla.session.remove()
sqla.drop_all()
self.app_context.pop()

def setUp(self) -> None:
self.app = create_app(mode="test")
self.app_context = self.app.app_context()
self.app_context.push()
sqla.create_all()
self.collection = db.Node_collection()

def test_backfill_populates_secure_headers_pair_from_auto_linked_nodes(
self,
) -> None:
cre = self.collection.add_cre(
defs.CRE(
id="636-347",
name="HTTP security headers",
description="",
)
)
header_node = self.collection.add_node(
defs.Standard(
name="Secure Headers",
section="Prevent information disclosure via HTTP headers",
hyperlink="https://owasp.org/example",
)
)
self.collection.add_link(
cre=cre,
node=header_node,
ltype=defs.LinkTypes.AutomaticallyLinkedTo,
)

written = backfill_opencre_direct_pairs(self.collection, refresh=True)
cache_key = make_resources_key([OPENCRE_STANDARD_NAME, "Secure Headers"])

self.assertGreaterEqual(written, 1)
self.assertTrue(self.collection.gap_analysis_exists(cache_key))
payload = json.loads(self.collection.get_gap_analysis_result(cache_key))
self.assertIn("636-347", payload["result"])
path = next(iter(payload["result"]["636-347"]["paths"].values()))
self.assertEqual("AUTOMATICALLY_LINKED_TO", path["path"][0]["relationship"])

@patch(
"application.utils.gap_analysis.build_direct_cre_overlap_map_analysis",
return_value={"result": {"x": {}}},
)
def test_backfill_refresh_recomputes_cached_pairs(self, build_mock: Mock) -> None:
collection = Mock()
collection.standards.return_value = ["ASVS"]

backfill_opencre_direct_pairs(collection, refresh=False)
build_mock.assert_not_called()

backfill_opencre_direct_pairs(collection, refresh=True)
self.assertEqual(2, build_mock.call_count)

@patch(
"application.utils.gap_analysis.build_direct_cre_overlap_map_analysis",
return_value={"result": {"x": {}}},
)
def test_backfill_missing_only_skips_when_cache_exists(
self, build_mock: Mock
) -> None:
collection = Mock()
collection.standards.return_value = ["ASVS"]
collection.gap_analysis_exists.return_value = True

written = backfill_opencre_direct_pairs(collection, refresh=False)

self.assertEqual(0, written)
build_mock.assert_not_called()
199 changes: 182 additions & 17 deletions application/tests/pci_dss_parser_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,79 @@
import unittest
from unittest.mock import Mock, patch

from application.utils.external_project_parsers.parsers.pci_dss import PciDss
from application.defs import cre_defs as defs
from application.utils.external_project_parsers.parsers import pci_dss as pci_mod
from application.utils.external_project_parsers.parsers.pci_dss import (
PciDss,
PciDssLinkError,
best_cre_via_bridge_standard,
pci_control_embedding_text,
resolve_cre_for_pci_control,
)


class TestPciDssLinking(unittest.TestCase):
def test_pci_control_embedding_text_uses_id_section_and_description(self) -> None:
control = defs.Standard(
name="PCI DSS",
sectionID="1.2.3",
section="Requirement title",
description="Longer requirement body",
)
text = pci_control_embedding_text(control)
self.assertIn("1.2.3", text)
self.assertIn("Requirement title", text)
self.assertIn("Longer requirement body", text)
self.assertNotIn("family:standard", text)

def test_resolve_cre_uses_paginated_cre_match_first(self) -> None:
cache = Mock()
linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="")
cache.get_cre_by_db_id.return_value = linked_cre
prompt = Mock()
prompt.get_id_of_most_similar_cre_paginated.return_value = ("cre-db-id", 0.82)

cre = resolve_cre_for_pci_control(prompt, cache, [0.1, 0.2])

self.assertEqual(linked_cre, cre)
prompt.get_id_of_most_similar_cre_paginated.assert_called()
prompt.get_id_of_most_similar_node.assert_not_called()

def test_resolve_cre_falls_back_to_bridge_standard(self) -> None:
cache = Mock()
prompt = Mock()
prompt.get_id_of_most_similar_cre_paginated.return_value = (None, None)
bridge_cre = defs.CRE(id="999-001", name="Bridge CRE", description="")

with patch.object(pci_mod, "PCI_BRIDGE_STANDARDS", ("S1", "S2")), patch.object(
pci_mod, "best_cre_via_bridge_standard", side_effect=[None, bridge_cre]
) as bridge_mock:
cre = resolve_cre_for_pci_control(prompt, cache, [0.1, 0.2])

self.assertEqual(bridge_cre, cre)
self.assertEqual(2, bridge_mock.call_count)

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def test_best_cre_via_bridge_standard_picks_highest_similarity_linked_node(
self,
) -> None:
cache = Mock()
low_node = defs.Standard(name="NIST 800-53 v5", section="low", sectionID="a")
high_node = defs.Standard(name="NIST 800-53 v5", section="high", sectionID="b")
low_cre = defs.CRE(id="111-111", name="Low", description="")
high_cre = defs.CRE(id="222-222", name="High", description="")
cache.get_nodes.return_value = [low_node, high_node]
cache.get_embeddings_for_doc.side_effect = [[0.0, 1.0], [1.0, 0.0]]
cache.find_cres_of_node.side_effect = [
[Mock(id="low-db")],
[Mock(id="high-db")],
]
cache.get_cre_by_db_id.side_effect = [low_cre, high_cre]

cre = best_cre_via_bridge_standard(
cache, [1.0, 0.0], "NIST 800-53 v5", min_similarity=0.0
)

self.assertEqual(high_cre, cre)


class TestPciDssParser(unittest.TestCase):
Expand All @@ -15,14 +87,11 @@ def test_parse_skips_standard_fallback_when_no_standard_id(

cache = Mock()
cache.get_nodes.return_value = None
cache.find_cres_of_node.return_value = []
cache.get_cre_by_db_id.return_value = None
cache.get_embeddings_by_doc_type.return_value = {}
linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="")
cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]}

prompt = Mock()
prompt.get_text_embeddings.return_value = [0.1, 0.2]
prompt.get_id_of_most_similar_cre.return_value = None
prompt.get_id_of_most_similar_node.return_value = None
prompt_handler_mock.return_value = prompt

pci_file = {
Expand All @@ -35,19 +104,115 @@ def test_parse_skips_standard_fallback_when_no_standard_id(
]
}

out = parser.parse_4(pci_file=pci_file, cache=cache)
with patch.object(
pci_mod, "resolve_cre_for_pci_control", return_value=linked_cre
):
out = parser.parse_4(pci_file=pci_file, cache=cache)

self.assertEqual(1, len(out))
self.assertEqual(1, cache.get_nodes.call_count)
self.assertEqual(
{
"name": "PCI DSS",
"section": "Test requirement text",
"sectionID": "1.1.1",
},
cache.get_nodes.call_args.kwargs,
)
prompt.generate_embeddings_for.assert_called_once()
self.assertEqual(1, len(out[0].links))
prompt.generate_embeddings_for.assert_not_called()

@patch(
"application.utils.external_project_parsers.parsers.pci_dss.prompt_client.PromptHandler"
)
def test_parse_raises_when_control_cannot_be_linked(self, prompt_handler_mock):
parser = PciDss()

cache = Mock()
cache.get_nodes.return_value = None
cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]}

prompt = Mock()
prompt.get_text_embeddings.return_value = [0.1, 0.2]
prompt_handler_mock.return_value = prompt

pci_file = {
"Original Content": [
{
"Defined Approach Requirements": "Test requirement text",
"PCI DSS ID": "1.1.1",
"Requirement Description": "desc",
}
]
}

with patch.object(pci_mod, "resolve_cre_for_pci_control", return_value=None):
with self.assertRaises(PciDssLinkError):
parser.parse_4(pci_file=pci_file, cache=cache)

@patch(
"application.utils.external_project_parsers.parsers.pci_dss.prompt_client.PromptHandler"
)
def test_parse_raises_when_any_control_in_batch_is_unlinked(
self, prompt_handler_mock
):
parser = PciDss()
cache = Mock()
cache.get_nodes.return_value = None
linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="")
cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]}

prompt = Mock()
prompt.get_text_embeddings.return_value = [0.1, 0.2]
prompt_handler_mock.return_value = prompt

pci_file = {
"Original Content": [
{
"Defined Approach Requirements": "Linked requirement",
"PCI DSS ID": "1.1.1",
"Requirement Description": "desc",
},
{
"Defined Approach Requirements": "Unlinked requirement",
"PCI DSS ID": "1.1.2",
"Requirement Description": "desc",
},
]
}

with patch.object(
pci_mod,
"resolve_cre_for_pci_control",
side_effect=[linked_cre, None],
):
with self.assertRaisesRegex(PciDssLinkError, "1 control\\(s\\) failed"):
parser.parse_4(pci_file=pci_file, cache=cache)

@patch(
"application.utils.external_project_parsers.parsers.pci_dss.prompt_client.PromptHandler"
)
def test_parse_adds_single_automatic_link_per_control(self, prompt_handler_mock):
parser = PciDss()
cache = Mock()
cache.get_nodes.return_value = None
linked_cre = defs.CRE(id="123-456", name="Linked CRE", description="")
cache.get_embeddings_by_doc_type.return_value = {"cre-1": [0.1]}

prompt = Mock()
prompt.get_text_embeddings.return_value = [0.1, 0.2]
prompt_handler_mock.return_value = prompt

pci_file = {
"Original Content": [
{
"Defined Approach Requirements": "Test requirement text",
"PCI DSS ID": "1.1.1",
"Requirement Description": "desc",
}
]
}

with patch.object(
pci_mod, "resolve_cre_for_pci_control", return_value=linked_cre
):
out = parser.parse_4(pci_file=pci_file, cache=cache)

self.assertEqual(1, len(out))
self.assertEqual(1, len(out[0].links))
self.assertEqual(defs.LinkTypes.AutomaticallyLinkedTo, out[0].links[0].ltype)
self.assertEqual("123-456", out[0].links[0].document.id)


if __name__ == "__main__":
Expand Down
Loading
Loading