From dd399c554d5a0307721a044e0ab316e2296beca0 Mon Sep 17 00:00:00 2001 From: Marek Lipert Date: Sat, 30 May 2026 14:43:58 +0200 Subject: [PATCH] Implement local-fn. --- backend/codegen/CodeGen.cpp | 10 +- backend/codegen/CodeGen.h | 3 +- backend/codegen/ops/FnNode.cpp | 38 ++++-- backend/codegen/ops/LocalNode.cpp | 6 +- backend/tests/codegen/FnNode_test.cpp | 159 ++++++++++++++++++++++++++ 5 files changed, 204 insertions(+), 12 deletions(-) diff --git a/backend/codegen/CodeGen.cpp b/backend/codegen/CodeGen.cpp index 7460d2b2..84bbbf2b 100644 --- a/backend/codegen/CodeGen.cpp +++ b/backend/codegen/CodeGen.cpp @@ -217,7 +217,8 @@ std::string CodeGen::generateInstanceCallBridge( llvm::Function *CodeGen::generateBaselineMethod( const FnMethodNode &method, - const std::vector> &captureInfo) { + const std::vector> &captureInfo, + const std::string &selfBindingName) { CLJ_ASSERT(TSContext != nullptr, "Codegen was moved"); std::string funcName = "fn_"; @@ -298,6 +299,13 @@ llvm::Function *CodeGen::generateBaselineMethod( variableTypesBindingsStack.push(); functionMetricsStack.push_back({0, false}); + if (!selfBindingName.empty()) { + Value *selfPtr = types.getFrameSelfPtr(Builder, framePtr); + Value *selfVal = Builder.CreateLoad(types.RT_valueTy, selfPtr, "self"); + variableBindingStack.set(selfBindingName, TypedValue(ObjectTypeSet(functionType, false), selfVal)); + variableTypesBindingsStack.set(selfBindingName, ObjectTypeSet(functionType, false)); + } + std::string loopId = method.loopid(); if (!loopId.empty()) { recurTargets[loopId] = {F, framePtr, method.fixedarity(), diff --git a/backend/codegen/CodeGen.h b/backend/codegen/CodeGen.h index 11cf1666..993a1f35 100644 --- a/backend/codegen/CodeGen.h +++ b/backend/codegen/CodeGen.h @@ -132,7 +132,8 @@ class CodeGen { llvm::Function *generateBaselineMethod( const FnMethodNode &method, - const std::vector> &captureInfo); + const std::vector> &captureInfo, + const std::string &selfBindingName = ""); TypedValue codegen(const Node &node, const ObjectTypeSet &typeRestrictions); diff --git a/backend/codegen/ops/FnNode.cpp b/backend/codegen/ops/FnNode.cpp index 64be1512..861c04f2 100644 --- a/backend/codegen/ops/FnNode.cpp +++ b/backend/codegen/ops/FnNode.cpp @@ -59,25 +59,37 @@ TypedValue CodeGen::codegen(const Node &node, const FnNode &subnode, for (int i = 0; i < (int)methods.size(); ++i) { const FnMethodNode &m = *methods[i].node; + std::string selfBindingName = ""; + if (subnode.has_local()) { + selfBindingName = subnode.local().subnode().binding().name(); + } + // Calculate types and names for captures (closed-overs) std::vector> captureInfo; for (const auto &node : m.closedovers()) { - // Calculate type in outer scope. - // Note: getType() should work correctly here. - ObjectTypeSet type = getType(node, ObjectTypeSet::all()); - - // Determine name + // Determine name of the capture first std::string name = "unknown"; if (node.subnode().has_local()) { name = node.subnode().local().name(); } else if (node.subnode().has_binding()) { name = node.subnode().binding().name(); } + + // The analyzer places the named function's own name in its closed-overs. + // However, we bind 'self' manually inside the baseline method from the Frame. + // We must not treat it as a regular closed-over capture. + if (!selfBindingName.empty() && name == selfBindingName) { + continue; + } + + // Calculate type in outer scope. + ObjectTypeSet type = getType(node, ObjectTypeSet::all()); + captureInfo.push_back({name, type}); } // Generate the IR for the method - llvm::Function *baselineF = generateBaselineMethod(m, captureInfo); + llvm::Function *baselineF = generateBaselineMethod(m, captureInfo, selfBindingName); // Handle closed overs (captures) for the fillMethod call // (these will be boxed RTValues in the runtime) @@ -95,11 +107,21 @@ TypedValue CodeGen::codegen(const Node &node, const FnNode &subnode, fillArgs.push_back(baselineF); fillArgs.push_back( Builder.CreateGlobalString(m.loopid())); // loopId (string) - fillArgs.push_back(llvm::ConstantInt::get(types.wordTy, m.closedovers_size())); + fillArgs.push_back(llvm::ConstantInt::get(types.wordTy, captureInfo.size())); // Codegen and box each capture for (int j = 0; j < m.closedovers_size(); ++j) { - TypedValue capture = codegen(m.closedovers(j), ObjectTypeSet::all()); + const auto &node = m.closedovers(j); + std::string name = "unknown"; + if (node.subnode().has_local()) { + name = node.subnode().local().name(); + } else if (node.subnode().has_binding()) { + name = node.subnode().binding().name(); + } + if (!selfBindingName.empty() && name == selfBindingName) { + continue; + } + TypedValue capture = codegen(node, ObjectTypeSet::all()); TypedValue boxed = valueEncoder.box(capture); fillArgs.push_back(boxed.value); } diff --git a/backend/codegen/ops/LocalNode.cpp b/backend/codegen/ops/LocalNode.cpp index 33ca1627..51352c24 100644 --- a/backend/codegen/ops/LocalNode.cpp +++ b/backend/codegen/ops/LocalNode.cpp @@ -12,7 +12,8 @@ TypedValue CodeGen::codegen(const Node &node, const LocalNode &subnode, case localTypeArg: case localTypeLet: case localTypeLoop: - case localTypeCatch: { + case localTypeCatch: + case localTypeFn: { auto name = subnode.name(); auto *val = variableBindingStack.find(name); if (!val) { @@ -35,7 +36,8 @@ ObjectTypeSet CodeGen::getType(const Node &node, const LocalNode &subnode, case localTypeArg: case localTypeLet: case localTypeLoop: - case localTypeCatch: { + case localTypeCatch: + case localTypeFn: { auto name = subnode.name(); auto *type = variableTypesBindingsStack.find(name); if (!type) { diff --git a/backend/tests/codegen/FnNode_test.cpp b/backend/tests/codegen/FnNode_test.cpp index 0a3add41..67513aad 100644 --- a/backend/tests/codegen/FnNode_test.cpp +++ b/backend/tests/codegen/FnNode_test.cpp @@ -268,6 +268,163 @@ static void test_fn_bigint_arg(void **state) { }); } +// (fn zorba [x] zorba) +static void test_named_fn(void **state) { + (void)state; + ASSERT_MEMORY_ALL_BALANCED({ + rt::JITEngine engine; + + Node fnNode; + fnNode.set_op(opFn); + auto *fn = fnNode.mutable_subnode()->mutable_fn(); + fn->set_once(false); + fn->set_maxfixedarity(1); + + auto *localNode = fn->mutable_local(); + localNode->set_op(opBinding); + localNode->mutable_subnode()->mutable_binding()->set_name("zorba"); + + auto *m = fn->add_methods(); + auto *mn = m->mutable_subnode()->mutable_fnmethod(); + mn->set_fixedarity(1); + mn->set_isvariadic(false); + + auto *param = mn->add_params(); + param->set_op(opBinding); + param->mutable_subnode()->mutable_binding()->set_name("x"); + + auto *body = mn->mutable_body(); + body->set_op(opLocal); + body->mutable_subnode()->mutable_local()->set_name("zorba"); + + auto *drop = body->add_dropmemory(); + drop->set_variablename("zorba"); + drop->set_requiredrefcountchange(1); + + auto res = engine.compileAST(fnNode, "named_fn").get(); + + // Execute: returns a ClojureFunction object + RTValue funObj = res.address.toPtr()(); + assert_true(RT_isPtr(funObj)); + assert_int_equal(functionType, ::getType(funObj)); + + ClojureFunction *f = (ClojureFunction *)RT_unboxPtr(funObj); + assert_int_equal(1, f->methodCount); + + // Call it manually to check if zorba is available inside + struct ExecutionContext ctx; + memset(&ctx, 0, sizeof(ctx)); + RTValue arg = RT_boxInt32(1); + RTValue ret = RT_invokeMethod(&ctx, funObj, &f->methods[0], &arg, 1); + + // The body returns `zorba`, which should be the same function object + assert_true(RT_isPtr(ret)); + assert_int_equal(functionType, ::getType(ret)); + + // They should be equal + assert_ptr_equal(RT_unboxPtr(funObj), RT_unboxPtr(ret)); + + release(funObj); + release(ret); + }); +} + +// (fn zorba [x] (fn borba [y] zorba)) +static void test_nested_named_fn(void **state) { + (void)state; + ASSERT_MEMORY_ALL_BALANCED({ + rt::JITEngine engine; + + Node outerFnNode; + outerFnNode.set_op(opFn); + auto *outerFn = outerFnNode.mutable_subnode()->mutable_fn(); + outerFn->set_once(false); + outerFn->set_maxfixedarity(1); + + auto *outerLocal = outerFn->mutable_local(); + outerLocal->set_op(opBinding); + outerLocal->mutable_subnode()->mutable_binding()->set_name("zorba"); + + auto *outerM = outerFn->add_methods(); + auto *outerMn = outerM->mutable_subnode()->mutable_fnmethod(); + outerMn->set_fixedarity(1); + outerMn->set_isvariadic(false); + + auto *outerParam = outerMn->add_params(); + outerParam->set_op(opBinding); + outerParam->mutable_subnode()->mutable_binding()->set_name("x"); + + // The body of the outer function is the inner function + auto *innerFnBody = outerMn->mutable_body(); + innerFnBody->set_op(opFn); + auto *innerFn = innerFnBody->mutable_subnode()->mutable_fn(); + innerFn->set_once(false); + innerFn->set_maxfixedarity(1); + + auto *innerLocal = innerFn->mutable_local(); + innerLocal->set_op(opBinding); + innerLocal->mutable_subnode()->mutable_binding()->set_name("borba"); + + auto *innerFnDrop = innerFnBody->add_dropmemory(); + innerFnDrop->set_variablename("zorba"); + innerFnDrop->set_requiredrefcountchange(1); + + auto *innerM = innerFn->add_methods(); + auto *innerMn = innerM->mutable_subnode()->mutable_fnmethod(); + innerMn->set_fixedarity(1); + innerMn->set_isvariadic(false); + + auto *innerParam = innerMn->add_params(); + innerParam->set_op(opBinding); + innerParam->mutable_subnode()->mutable_binding()->set_name("y"); + + // The inner function captures 'zorba' + auto *co = innerMn->add_closedovers(); + co->set_op(opLocal); + co->mutable_subnode()->mutable_local()->set_name("zorba"); + + auto *innerBody = innerMn->mutable_body(); + innerBody->set_op(opLocal); + innerBody->mutable_subnode()->mutable_local()->set_name("zorba"); + + // We don't bother setting up dropmemory for this manual test + // to keep it simple, it'll just test the compilation path + + auto res = engine.compileAST(outerFnNode, "nested_named_fn").get(); + + // Execute: returns a ClojureFunction object + RTValue outerFunObj = res.address.toPtr()(); + assert_true(RT_isPtr(outerFunObj)); + assert_int_equal(functionType, ::getType(outerFunObj)); + + ClojureFunction *outerF = (ClojureFunction *)RT_unboxPtr(outerFunObj); + assert_int_equal(1, outerF->methodCount); + + // Call outer function to get inner function + struct ExecutionContext ctx; + memset(&ctx, 0, sizeof(ctx)); + RTValue arg = RT_boxInt32(1); + RTValue innerFunObj = RT_invokeMethod(&ctx, outerFunObj, &outerF->methods[0], &arg, 1); + + assert_true(RT_isPtr(innerFunObj)); + assert_int_equal(functionType, ::getType(innerFunObj)); + + ClojureFunction *innerF = (ClojureFunction *)RT_unboxPtr(innerFunObj); + assert_int_equal(1, innerF->methodCount); + // inner function should have 1 closed-over capture (zorba) + assert_int_equal(1, innerF->methods[0].closedOversCount); + + // Call inner function to return zorba + RTValue retZorba = RT_invokeMethod(&ctx, innerFunObj, &innerF->methods[0], &arg, 1); + + // It should be the same as outerFunObj + assert_ptr_equal(RT_unboxPtr(outerFunObj), RT_unboxPtr(retZorba)); + + release(outerFunObj); + release(innerFunObj); + }); +} + int main(void) { initialise_memory(); const struct CMUnitTest tests[] = { @@ -275,6 +432,8 @@ int main(void) { cmocka_unit_test(test_fn_capture_unboxing_int), cmocka_unit_test(test_multi_arity_fn), cmocka_unit_test(test_fn_bigint_arg), + cmocka_unit_test(test_named_fn), + cmocka_unit_test(test_nested_named_fn), }; return cmocka_run_group_tests(tests, NULL, NULL);