Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion backend/codegen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ std::string CodeGen::generateInstanceCallBridge(

llvm::Function *CodeGen::generateBaselineMethod(
const FnMethodNode &method,
const std::vector<std::pair<std::string, ObjectTypeSet>> &captureInfo) {
const std::vector<std::pair<std::string, ObjectTypeSet>> &captureInfo,
const std::string &selfBindingName) {
CLJ_ASSERT(TSContext != nullptr, "Codegen was moved");

std::string funcName = "fn_";
Expand Down Expand Up @@ -298,6 +299,13 @@ llvm::Function *CodeGen::generateBaselineMethod(
variableTypesBindingsStack.push();
functionMetricsStack.push_back({0, false});

if (!selfBindingName.empty()) {
Value *selfPtr = types.getFrameSelfPtr(Builder, framePtr);
Value *selfVal = Builder.CreateLoad(types.RT_valueTy, selfPtr, "self");
variableBindingStack.set(selfBindingName, TypedValue(ObjectTypeSet(functionType, false), selfVal));
variableTypesBindingsStack.set(selfBindingName, ObjectTypeSet(functionType, false));
}

std::string loopId = method.loopid();
if (!loopId.empty()) {
recurTargets[loopId] = {F, framePtr, method.fixedarity(),
Expand Down
3 changes: 2 additions & 1 deletion backend/codegen/CodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ class CodeGen {

llvm::Function *generateBaselineMethod(
const FnMethodNode &method,
const std::vector<std::pair<std::string, ObjectTypeSet>> &captureInfo);
const std::vector<std::pair<std::string, ObjectTypeSet>> &captureInfo,
const std::string &selfBindingName = "");

TypedValue codegen(const Node &node, const ObjectTypeSet &typeRestrictions);

Expand Down
38 changes: 30 additions & 8 deletions backend/codegen/ops/FnNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,37 @@ TypedValue CodeGen::codegen(const Node &node, const FnNode &subnode,
for (int i = 0; i < (int)methods.size(); ++i) {
const FnMethodNode &m = *methods[i].node;

std::string selfBindingName = "";
if (subnode.has_local()) {
selfBindingName = subnode.local().subnode().binding().name();
}

// Calculate types and names for captures (closed-overs)
std::vector<std::pair<std::string, ObjectTypeSet>> captureInfo;
for (const auto &node : m.closedovers()) {
// Calculate type in outer scope.
// Note: getType() should work correctly here.
ObjectTypeSet type = getType(node, ObjectTypeSet::all());

// Determine name
// Determine name of the capture first
std::string name = "unknown";
if (node.subnode().has_local()) {
name = node.subnode().local().name();
} else if (node.subnode().has_binding()) {
name = node.subnode().binding().name();
}

// The analyzer places the named function's own name in its closed-overs.
// However, we bind 'self' manually inside the baseline method from the Frame.
// We must not treat it as a regular closed-over capture.
if (!selfBindingName.empty() && name == selfBindingName) {
continue;
}

// Calculate type in outer scope.
ObjectTypeSet type = getType(node, ObjectTypeSet::all());

captureInfo.push_back({name, type});
}

// Generate the IR for the method
llvm::Function *baselineF = generateBaselineMethod(m, captureInfo);
llvm::Function *baselineF = generateBaselineMethod(m, captureInfo, selfBindingName);

// Handle closed overs (captures) for the fillMethod call
// (these will be boxed RTValues in the runtime)
Expand All @@ -95,11 +107,21 @@ TypedValue CodeGen::codegen(const Node &node, const FnNode &subnode,
fillArgs.push_back(baselineF);
fillArgs.push_back(
Builder.CreateGlobalString(m.loopid())); // loopId (string)
fillArgs.push_back(llvm::ConstantInt::get(types.wordTy, m.closedovers_size()));
fillArgs.push_back(llvm::ConstantInt::get(types.wordTy, captureInfo.size()));

// Codegen and box each capture
for (int j = 0; j < m.closedovers_size(); ++j) {
TypedValue capture = codegen(m.closedovers(j), ObjectTypeSet::all());
const auto &node = m.closedovers(j);
std::string name = "unknown";
if (node.subnode().has_local()) {
name = node.subnode().local().name();
} else if (node.subnode().has_binding()) {
name = node.subnode().binding().name();
}
if (!selfBindingName.empty() && name == selfBindingName) {
continue;
}
TypedValue capture = codegen(node, ObjectTypeSet::all());
TypedValue boxed = valueEncoder.box(capture);
fillArgs.push_back(boxed.value);
}
Expand Down
6 changes: 4 additions & 2 deletions backend/codegen/ops/LocalNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ TypedValue CodeGen::codegen(const Node &node, const LocalNode &subnode,
case localTypeArg:
case localTypeLet:
case localTypeLoop:
case localTypeCatch: {
case localTypeCatch:
case localTypeFn: {
auto name = subnode.name();
auto *val = variableBindingStack.find(name);
if (!val) {
Expand All @@ -35,7 +36,8 @@ ObjectTypeSet CodeGen::getType(const Node &node, const LocalNode &subnode,
case localTypeArg:
case localTypeLet:
case localTypeLoop:
case localTypeCatch: {
case localTypeCatch:
case localTypeFn: {
auto name = subnode.name();
auto *type = variableTypesBindingsStack.find(name);
if (!type) {
Expand Down
159 changes: 159 additions & 0 deletions backend/tests/codegen/FnNode_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,172 @@ static void test_fn_bigint_arg(void **state) {
});
}

// (fn zorba [x] zorba)
static void test_named_fn(void **state) {
(void)state;
ASSERT_MEMORY_ALL_BALANCED({
rt::JITEngine engine;

Node fnNode;
fnNode.set_op(opFn);
auto *fn = fnNode.mutable_subnode()->mutable_fn();
fn->set_once(false);
fn->set_maxfixedarity(1);

auto *localNode = fn->mutable_local();
localNode->set_op(opBinding);
localNode->mutable_subnode()->mutable_binding()->set_name("zorba");

auto *m = fn->add_methods();
auto *mn = m->mutable_subnode()->mutable_fnmethod();
mn->set_fixedarity(1);
mn->set_isvariadic(false);

auto *param = mn->add_params();
param->set_op(opBinding);
param->mutable_subnode()->mutable_binding()->set_name("x");

auto *body = mn->mutable_body();
body->set_op(opLocal);
body->mutable_subnode()->mutable_local()->set_name("zorba");

auto *drop = body->add_dropmemory();
drop->set_variablename("zorba");
drop->set_requiredrefcountchange(1);

auto res = engine.compileAST(fnNode, "named_fn").get();

// Execute: returns a ClojureFunction object
RTValue funObj = res.address.toPtr<RTValue (*)()>()();
assert_true(RT_isPtr(funObj));
assert_int_equal(functionType, ::getType(funObj));

ClojureFunction *f = (ClojureFunction *)RT_unboxPtr(funObj);
assert_int_equal(1, f->methodCount);

// Call it manually to check if zorba is available inside
struct ExecutionContext ctx;
memset(&ctx, 0, sizeof(ctx));
RTValue arg = RT_boxInt32(1);
RTValue ret = RT_invokeMethod(&ctx, funObj, &f->methods[0], &arg, 1);

// The body returns `zorba`, which should be the same function object
assert_true(RT_isPtr(ret));
assert_int_equal(functionType, ::getType(ret));

// They should be equal
assert_ptr_equal(RT_unboxPtr(funObj), RT_unboxPtr(ret));

release(funObj);
release(ret);
});
}

// (fn zorba [x] (fn borba [y] zorba))
static void test_nested_named_fn(void **state) {
(void)state;
ASSERT_MEMORY_ALL_BALANCED({
rt::JITEngine engine;

Node outerFnNode;
outerFnNode.set_op(opFn);
auto *outerFn = outerFnNode.mutable_subnode()->mutable_fn();
outerFn->set_once(false);
outerFn->set_maxfixedarity(1);

auto *outerLocal = outerFn->mutable_local();
outerLocal->set_op(opBinding);
outerLocal->mutable_subnode()->mutable_binding()->set_name("zorba");

auto *outerM = outerFn->add_methods();
auto *outerMn = outerM->mutable_subnode()->mutable_fnmethod();
outerMn->set_fixedarity(1);
outerMn->set_isvariadic(false);

auto *outerParam = outerMn->add_params();
outerParam->set_op(opBinding);
outerParam->mutable_subnode()->mutable_binding()->set_name("x");

// The body of the outer function is the inner function
auto *innerFnBody = outerMn->mutable_body();
innerFnBody->set_op(opFn);
auto *innerFn = innerFnBody->mutable_subnode()->mutable_fn();
innerFn->set_once(false);
innerFn->set_maxfixedarity(1);

auto *innerLocal = innerFn->mutable_local();
innerLocal->set_op(opBinding);
innerLocal->mutable_subnode()->mutable_binding()->set_name("borba");

auto *innerFnDrop = innerFnBody->add_dropmemory();
innerFnDrop->set_variablename("zorba");
innerFnDrop->set_requiredrefcountchange(1);

auto *innerM = innerFn->add_methods();
auto *innerMn = innerM->mutable_subnode()->mutable_fnmethod();
innerMn->set_fixedarity(1);
innerMn->set_isvariadic(false);

auto *innerParam = innerMn->add_params();
innerParam->set_op(opBinding);
innerParam->mutable_subnode()->mutable_binding()->set_name("y");

// The inner function captures 'zorba'
auto *co = innerMn->add_closedovers();
co->set_op(opLocal);
co->mutable_subnode()->mutable_local()->set_name("zorba");

auto *innerBody = innerMn->mutable_body();
innerBody->set_op(opLocal);
innerBody->mutable_subnode()->mutable_local()->set_name("zorba");

// We don't bother setting up dropmemory for this manual test
// to keep it simple, it'll just test the compilation path

auto res = engine.compileAST(outerFnNode, "nested_named_fn").get();

// Execute: returns a ClojureFunction object
RTValue outerFunObj = res.address.toPtr<RTValue (*)()>()();
assert_true(RT_isPtr(outerFunObj));
assert_int_equal(functionType, ::getType(outerFunObj));

ClojureFunction *outerF = (ClojureFunction *)RT_unboxPtr(outerFunObj);
assert_int_equal(1, outerF->methodCount);

// Call outer function to get inner function
struct ExecutionContext ctx;
memset(&ctx, 0, sizeof(ctx));
RTValue arg = RT_boxInt32(1);
RTValue innerFunObj = RT_invokeMethod(&ctx, outerFunObj, &outerF->methods[0], &arg, 1);

assert_true(RT_isPtr(innerFunObj));
assert_int_equal(functionType, ::getType(innerFunObj));

ClojureFunction *innerF = (ClojureFunction *)RT_unboxPtr(innerFunObj);
assert_int_equal(1, innerF->methodCount);
// inner function should have 1 closed-over capture (zorba)
assert_int_equal(1, innerF->methods[0].closedOversCount);

// Call inner function to return zorba
RTValue retZorba = RT_invokeMethod(&ctx, innerFunObj, &innerF->methods[0], &arg, 1);

// It should be the same as outerFunObj
assert_ptr_equal(RT_unboxPtr(outerFunObj), RT_unboxPtr(retZorba));

release(outerFunObj);
release(innerFunObj);
});
}

int main(void) {
initialise_memory();
const struct CMUnitTest tests[] = {
cmocka_unit_test(test_simple_fn),
cmocka_unit_test(test_fn_capture_unboxing_int),
cmocka_unit_test(test_multi_arity_fn),
cmocka_unit_test(test_fn_bigint_arg),
cmocka_unit_test(test_named_fn),
cmocka_unit_test(test_nested_named_fn),
};

return cmocka_run_group_tests(tests, NULL, NULL);
Expand Down
Loading