diff --git a/src/printer/IRNodeFinder.cpp b/src/printer/IRNodeFinder.cpp index d526559..54b9df5 100644 --- a/src/printer/IRNodeFinder.cpp +++ b/src/printer/IRNodeFinder.cpp @@ -10,11 +10,15 @@ #include #include +#include #include #include #include #include +#include +#include #include +#include using namespace llvm; @@ -69,33 +73,67 @@ void IRNodeFinder::printFunction(const std::string& regex) const { applyToMatchingFunction(os, m, regex, [&](const Function* f) { f->print(os); }); } -void IRNodeFinder::printByLocation(unsigned line_start_, unsigned line_end_) const { - line_end_ = std::max(line_start_, line_end_); - std::string matches; - llvm::raw_string_ostream local_oss{matches}; - bool first_match{false}; - const auto* m = tool.getModule(); - for (const auto& f : *m) { - first_match = false; - for (const auto& bb : f) { - for (const auto& inst : bb) { +void IRNodeFinder::printByLocation(unsigned line_start, unsigned line_end) const { + const unsigned search_end = std::max(line_start, line_end); + const auto* module = tool.getModule(); + const auto main_file_path = [&](const llvm::Module* m) -> std::optional { + auto* CUs = m->getNamedMetadata("llvm.dbg.cu"); + if (!CUs || CUs->getNumOperands() == 0) { + return std::nullopt; + } + auto* cu = llvm::cast(CUs->getOperand(0)); + llvm::SmallString<128> path; + if (!llvm::sys::fs::real_path(cu->getFilename(), path)) { + return path.str().str(); + } + return std::nullopt; + }(module); + + const auto is_relevant_function = [&](const llvm::Function& func) -> bool { + const auto* sub = func.getSubprogram(); + if (!sub) { + return false; + } + // TODO: investigate w.r.t. inlining? + // if (sub->getLine() > search_end) { + // return false; + // } + if (main_file_path) { + llvm::SmallString<128> func_path; + if (!llvm::sys::fs::real_path(sub->getFilename(), func_path)) { + return func_path == *main_file_path; + } + } + return true; + }; + + std::string matches_buffer; + llvm::raw_string_ostream buffer_stream{matches_buffer}; + + for (const auto& func : llvm::make_filter_range(*module, is_relevant_function)) { + bool function_header_printed = false; + for (const auto& block : func) { + for (const auto& inst : block) { const auto& loc = inst.getDebugLoc(); - if (loc) { - const auto line = loc.getLine(); - if (line >= line_start_ && line <= line_end_) { - if (!first_match) { - local_oss << f.getName() << ":\n"; - first_match = true; - } - inst.print(local_oss); - local_oss << "\n"; + if (!loc) { + continue; + } + const unsigned line = loc.getLine(); + + if (line >= line_start && line <= search_end) { + if (!function_header_printed) { + buffer_stream << func.getName() << ":\n"; + function_header_printed = true; } + inst.print(buffer_stream); + buffer_stream << "\n"; } } } } - if (!local_oss.str().empty()) { - os << local_oss.str() << "\n"; + + if (!buffer_stream.str().empty()) { + os << buffer_stream.str() << "\n"; } }