Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,138 @@ def process_data(data):
DetectorContext ctx = DetectorTestUtils.contextFor("tasks.py", "python", code);
DetectorTestUtils.assertDeterministic(detector, ctx);
}

@Test
void detectsSharedTask() {
String code = """
@shared_task
def cleanup():
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("tasks.py", "python", code);
DetectorResult result = detector.detect(ctx);

assertEquals(2, result.nodes().size());
var queueNode = result.nodes().stream()
.filter(n -> n.getKind() == NodeKind.QUEUE).findFirst().orElseThrow();
assertEquals("celery", queueNode.getProperties().get("broker"));
}

@Test
void taskQueueNodeHasTaskNameProperty() {
String code = """
@app.task
def process(data):
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("tasks.py", "python", code);
DetectorResult result = detector.detect(ctx);

var queueNode = result.nodes().stream()
.filter(n -> n.getKind() == NodeKind.QUEUE).findFirst().orElseThrow();
assertEquals("process", queueNode.getProperties().get("task_name"));
assertEquals("process", queueNode.getProperties().get("function"));
}

@Test
void taskMethodNodeHasFqn() {
String code = """
@app.task
def my_task():
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("tasks.py", "python", code);
DetectorResult result = detector.detect(ctx);

var methodNode = result.nodes().stream()
.filter(n -> n.getKind() == NodeKind.METHOD).findFirst().orElseThrow();
assertNotNull(methodNode.getFqn());
assertTrue(methodNode.getFqn().contains("my_task"));
}

@Test
void consumesEdgeGoesFromMethodToQueue() {
String code = """
@app.task
def my_task():
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("tasks.py", "python", code);
DetectorResult result = detector.detect(ctx);

var consumesEdge = result.edges().stream()
.filter(e -> e.getKind() == EdgeKind.CONSUMES).findFirst().orElseThrow();
assertNotNull(consumesEdge.getSourceId());
assertTrue(consumesEdge.getSourceId().startsWith("method:"));
}

@Test
void detectsApplyAsync() {
String code = """
send_email.apply_async(args=["user@test.com"], countdown=60)
""";
DetectorContext ctx = DetectorTestUtils.contextFor("views.py", "python", code);
DetectorResult result = detector.detect(ctx);

assertEquals(1, result.edges().size());
assertEquals(EdgeKind.PRODUCES, result.edges().get(0).getKind());
}

@Test
void detectsSignatureCall() {
String code = """
task_sig = my_task.s(arg1)
""";
DetectorContext ctx = DetectorTestUtils.contextFor("views.py", "python", code);
DetectorResult result = detector.detect(ctx);

assertEquals(1, result.edges().size());
assertEquals(EdgeKind.PRODUCES, result.edges().get(0).getKind());
}

@Test
void multipleTaskDefinitions() {
String code = """
@app.task
def task_a():
pass

@shared_task
def task_b():
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("tasks.py", "python", code);
DetectorResult result = detector.detect(ctx);

long queueCount = result.nodes().stream()
.filter(n -> n.getKind() == NodeKind.QUEUE).count();
long methodCount = result.nodes().stream()
.filter(n -> n.getKind() == NodeKind.METHOD).count();
assertEquals(2, queueCount);
assertEquals(2, methodCount);
}

@Test
void noMatchOnEmptyContent() {
DetectorContext ctx = DetectorTestUtils.contextFor("python", "");
DetectorResult result = detector.detect(ctx);

assertEquals(0, result.nodes().size());
assertEquals(0, result.edges().size());
}

@Test
void explicitTaskNameOverridesFunctionName() {
String code = """
@app.task(name='myapp.tasks.send_notification')
def notify_user(user_id):
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("tasks.py", "python", code);
DetectorResult result = detector.detect(ctx);

var queueNode = result.nodes().stream()
.filter(n -> n.getKind() == NodeKind.QUEUE).findFirst().orElseThrow();
assertEquals("celery:myapp.tasks.send_notification", queueNode.getLabel());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,149 @@ def view2(request):
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorTestUtils.assertDeterministic(detector, ctx);
}

@Test
void detectsPermissionRequiredMixin() {
String code = """
class AdminView(PermissionRequiredMixin, View):
permission_required = 'app.can_admin'
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

assertEquals(1, result.nodes().size());
assertEquals("PermissionRequiredMixin", result.nodes().get(0).getProperties().get("mixin"));
assertEquals("AdminView", result.nodes().get(0).getProperties().get("class_name"));
}

@Test
void detectsUserPassesTestMixin() {
String code = """
class StaffView(UserPassesTestMixin, View):
def test_func(self):
return self.request.user.is_staff
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

assertEquals(1, result.nodes().size());
assertEquals("UserPassesTestMixin", result.nodes().get(0).getProperties().get("mixin"));
}

@Test
void loginRequiredHasAuthType() {
String code = """
@login_required
def secure_view(request):
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

assertEquals("django", result.nodes().get(0).getProperties().get("auth_type"));
assertEquals(true, result.nodes().get(0).getProperties().get("auth_required"));
}

@Test
void loginRequiredHasAnnotations() {
String code = """
@login_required
def secured(request):
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

var node = result.nodes().get(0);
assertTrue(node.getAnnotations().contains("@login_required"));
}

@Test
void permissionRequiredHasPermissionsProperty() {
String code = """
@permission_required("myapp.view_report")
def report_view(request):
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

@SuppressWarnings("unchecked")
List<String> perms = (List<String>) result.nodes().get(0).getProperties().get("permissions");
assertNotNull(perms);
assertFalse(perms.isEmpty());
assertEquals("myapp.view_report", perms.get(0));
}

@Test
void userPassesTestHasTestFunctionProperty() {
String code = """
@user_passes_test(lambda u: u.is_active)
def restricted_view(request):
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

// test_function should be set from arg
assertNotNull(result.nodes().get(0).getProperties().get("test_function"));
}

@Test
void mixinAnnotationFormat() {
String code = """
class SecureList(LoginRequiredMixin, ListView):
model = Item
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

var node = result.nodes().get(0);
assertTrue(node.getAnnotations().stream().anyMatch(a -> a.contains("LoginRequiredMixin")));
}

@Test
void noMatchOnEmptyContent() {
DetectorContext ctx = DetectorTestUtils.contextFor("python", "");
DetectorResult result = detector.detect(ctx);

assertEquals(0, result.nodes().size());
}

@Test
void multipleDecoratorsCapturedSeparately() {
String code = """
@login_required
def view_a(request):
pass

@login_required
def view_b(request):
pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

assertEquals(2, result.nodes().size());
}

@Test
void allAuthTypeIsDjango() {
String code = """
@login_required
def v1(request): pass

@permission_required("x.y")
def v2(request): pass

@user_passes_test(lambda u: True)
def v3(request): pass
""";
DetectorContext ctx = DetectorTestUtils.contextFor("python", code);
DetectorResult result = detector.detect(ctx);

assertEquals(3, result.nodes().size());
assertTrue(result.nodes().stream()
.allMatch(n -> "django".equals(n.getProperties().get("auth_type"))));
}
}
Loading
Loading