-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
155 lines (133 loc) · 7.21 KB
/
Copy pathmain.cpp
File metadata and controls
155 lines (133 loc) · 7.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <algorithm>
#include <cctype>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/opencv.hpp>

#include "input_parser.hpp"
#include "logger.hpp"
#include "model_adapter_cls.hpp"
#include "model_adapter_det.hpp"
namespace fs = std::filesystem;
namespace {
/// Prints every command-line flag understood by main(), with example values, to stdout.
void PrintUsage()
{
    std::cout
        << "Usage: ./main [--image-dir /path/to/images] [--output-json predictions.json]\n"
        << "Defaults: --image-dir /app/input/ --output-json /app/output/result.json\n"
        << "Detection model flags: --model-address localhost:8001 --model-version 1 --model-name torch_cnn_model\n"
        << " --input-name images --output-name output0 --input-width 640 --input-height 480\n"
        << " --input-channels 3 --max-detections 20 --fields-per-detection 6 --score-threshold 0.25\n"
        << " --boxes-normalized <bool> (default true) --channel-first <bool> (default false)\n"
        << "Classifier flags: --classifier-model-name sklearn_classifier --classifier-input-name X\n"
        << " --classifier-output-name probabilities --classifier-input-features 8\n"
        << " --classifier-output-classes 2\n"
        << " --classifier-resize-width 128 --classifier-resize-height 128\n"
        << " (the classifier reuses --model-address, --model-version and --model-cuda-shared-mem)\n"
        << "Runtime flags: --model-cuda-shared-mem <bool> --enable-timer-logging <bool>\n"
        << " --enable-debug-logging <bool> (all default false); --help prints this message\n";
}

/// Returns true when `path` has a .jpg, .jpeg or .png extension.
/// The comparison is case-insensitive, so camera-style names such as
/// "IMG_0001.JPG" are accepted.
bool HasSupportedImageExtension(const std::filesystem::path &path)
{
    std::string extension = path.extension().string();
    // std::tolower requires a value representable as unsigned char.
    std::transform(extension.begin(), extension.end(), extension.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return extension == ".jpg" || extension == ".jpeg" || extension == ".png";
}

/// Lists the supported image files directly inside `directory_path` (not recursive).
/// Only regular files (symlinks are followed) with a supported extension are kept.
///
/// @param directory_path Directory to scan.
/// @return Full paths of the matching files, sorted lexicographically so the
///         output order does not depend on directory iteration order.
/// @throws std::filesystem::filesystem_error if the directory cannot be iterated.
std::vector<std::string> GetImagePaths(const std::string &directory_path)
{
    std::vector<std::string> image_paths;
    for (const auto &entry : std::filesystem::directory_iterator(directory_path)) {
        if (std::filesystem::is_regular_file(entry.path()) && HasSupportedImageExtension(entry.path())) {
            // Explicit .string(): fs::path only converts implicitly to std::string
            // where the native path character type is char (not on Windows).
            image_paths.push_back(entry.path().string());
        }
    }
    std::sort(image_paths.begin(), image_paths.end());
    return image_paths;
}
} // namespace
int main(int argc, char **argv)
{
InputParser input(argc, argv);
if (input.CmdOptionExists("--help")) {
PrintUsage();
return 0;
}
Logger::EnableTimer(input.GetCmdOption<bool>("--enable-timer-logging", false));
Logger::EnableDebugLevel(input.GetCmdOption<bool>("--enable-debug-logging", false));
const std::string image_dir = input.GetCmdOption<std::string>("--image-dir", "/app/input/");
const std::string output_json_path = input.GetCmdOption<std::string>("--output-json", "/app/output/result.json");
if (!fs::exists(image_dir)) {
Logger::Error("Image directory does not exist: " + image_dir);
return 1;
}
ModelClientSettings model_client_settings;
model_client_settings.model_url = input.GetCmdOption<std::string>("--model-address", "localhost:8001");
model_client_settings.model_version = input.GetCmdOption<std::string>("--model-version", "1");
model_client_settings.model_name = input.GetCmdOption<std::string>("--model-name", "torch_cnn_model");
model_client_settings.cuda_shared_mem = input.GetCmdOption<bool>("--model-cuda-shared-mem", false);
ModelClientSettings classifier_model_client_settings = model_client_settings;
classifier_model_client_settings.model_name =
input.GetCmdOption<std::string>("--classifier-model-name", "sklearn_classifier");
ObjectDetectionAdapterSettings adapter_settings;
adapter_settings.input_name = input.GetCmdOption<std::string>("--input-name", "images");
adapter_settings.output_name = input.GetCmdOption<std::string>("--output-name", "output0");
adapter_settings.input_width = input.GetCmdOption<int>("--input-width", 640);
adapter_settings.input_height = input.GetCmdOption<int>("--input-height", 480);
adapter_settings.input_channels = input.GetCmdOption<int>("--input-channels", 3);
adapter_settings.max_detections = input.GetCmdOption<int>("--max-detections", 20);
adapter_settings.output_fields_per_detection = input.GetCmdOption<int>("--fields-per-detection", 6);
adapter_settings.confidence_threshold = input.GetCmdOption<float>("--score-threshold", 0.25f);
adapter_settings.boxes_are_normalized = input.GetCmdOption<bool>("--boxes-normalized", true);
adapter_settings.channel_first = input.GetCmdOption<bool>("--channel-first", false);
ModelAdpaterCassifierSettings classifier_settings;
classifier_settings.input_name = input.GetCmdOption<std::string>("--classifier-input-name", "X");
classifier_settings.output_name = input.GetCmdOption<std::string>("--classifier-output-name", "probabilities");
classifier_settings.input_feature_count = input.GetCmdOption<int>("--classifier-input-features", 8);
classifier_settings.output_class_count = input.GetCmdOption<int>("--classifier-output-classes", 2);
classifier_settings.feature_extractor_settings.resize_width =
input.GetCmdOption<int>("--classifier-resize-width", 128);
classifier_settings.feature_extractor_settings.resize_height =
input.GetCmdOption<int>("--classifier-resize-height", 128);
const std::vector<std::string> image_paths = GetImagePaths(image_dir);
if (image_paths.empty()) {
Logger::Error("No images found in directory: " + image_dir);
return 1;
}
Logger::Log("Found images: " + std::to_string(image_paths.size()));
ModelAdapterObjectDetection model_adapter(std::move(model_client_settings), std::move(adapter_settings));
ModelAdpaterCassifier image_classifier_adapter(std::move(classifier_model_client_settings),
std::move(classifier_settings));
nlohmann::json images_json = nlohmann::json::array();
for (const auto &image_path : image_paths) {
Logger::Log("Running inference on: " + image_path);
cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
if (image.empty()) {
Logger::Error("Failed to read image: " + image_path);
return 1;
}
std::vector<DetectedObject> detections;
ClassifierResult classification_result;
try {
detections = model_adapter.Run(image);
classification_result = image_classifier_adapter.Run(image);
}
catch (const std::exception &error) {
Logger::Error("Inference failed for " + image_path + ": " + std::string(error.what()));
continue;
}
images_json.push_back(nlohmann::json{{"file_name", fs::path(image_path).filename().string()},
{"image_width", image.cols},
{"image_height", image.rows},
{"detections", detections},
{"classification",
{{"class_id", classification_result.class_id},
{"class_name", classification_result.classification},
{"confidence", classification_result.probability},
{"probabilities", classification_result.probabilities}}}});
}
const nlohmann::json payload = {{"images", images_json}};
if (!output_json_path.empty()) {
std::ofstream output_file(output_json_path);
if (!output_file.is_open()) {
Logger::Error("Failed to open output file: " + output_json_path);
return 1;
}
output_file << payload.dump(2) << std::endl;
}
std::cout << payload.dump(2) << std::endl;
return 0;
}