|
3 | 3 | #include <regex>
|
4 | 4 | #include <set>
|
5 | 5 | #include <unordered_map>
|
| 6 | +#include <unordered_set> |
6 | 7 |
|
7 | 8 | void build_graph_linear(it_lab_ai::Tensor& input, it_lab_ai::Tensor& output,
|
8 | 9 | const std::string& json_path, bool comments,
|
@@ -288,6 +289,10 @@ void build_graph(it_lab_ai::Tensor& input, it_lab_ai::Tensor& output,
|
288 | 289 | it_lab_ai::ImplType impl1 = parallel ? it_lab_ai::kTBB : it_lab_ai::kDefault;
|
289 | 290 | it_lab_ai::ImplType impl2 = parallel ? it_lab_ai::kSTL : it_lab_ai::kDefault;
|
290 | 291 |
|
| 292 | + std::unordered_map<std::string, std::vector<std::string>> concat_connections; |
| 293 | + std::unordered_map<std::string, std::vector<int>> concat_orders; |
| 294 | + std::unordered_map<std::string, std::unordered_set<std::string>> concat_connected_inputs; |
| 295 | + |
291 | 296 | std::unordered_map<std::string, std::vector<int64_t>> layer_parameters;
|
292 | 297 | std::unordered_map<std::string, float> float_parameters;
|
293 | 298 | std::string last_constant_name;
|
@@ -557,17 +562,39 @@ void build_graph(it_lab_ai::Tensor& input, it_lab_ai::Tensor& output,
|
557 | 562 | if (layer_data["attributes"].contains("axis")) {
|
558 | 563 | axis = layer_data["attributes"]["axis"];
|
559 | 564 | }
|
| 565 | + if (layer_data.contains("inputs")) { |
| 566 | + for (const auto& input_name : layer_data["inputs"]) { |
| 567 | + std::string input_tensor = input_name.get<std::string>(); |
| 568 | + std::string base_input_name = get_base_layer_name(input_tensor); |
| 569 | + concat_connections[layer_name].push_back(base_input_name); |
| 570 | + } |
| 571 | + } |
560 | 572 | auto concat_layer = std::make_shared<it_lab_ai::ConcatLayer>(axis);
|
561 | 573 | concat_layer->setName(it_lab_ai::kConcat);
|
562 | 574 | layer = concat_layer;
|
| 575 | + concat_connected_inputs[layer_name] = std::unordered_set<std::string>(); |
563 | 576 | } else if (layer_type == "Split") {
|
564 | 577 | int axis = 0;
|
565 |
| - std::vector<int> splits; |
| 578 | + std::vector<int64_t> splits; |
566 | 579 | size_t num_outputs = 2;
|
567 | 580 |
|
568 | 581 | if (layer_data["attributes"].contains("axis")) {
|
569 | 582 | axis = layer_data["attributes"]["axis"];
|
570 | 583 | }
|
| 584 | + if (layer_data.contains("inputs") && layer_data["inputs"].is_array()) { |
| 585 | + auto inputs = layer_data["inputs"]; |
| 586 | + if (inputs.size() >= 2) { |
| 587 | + std::string constant_name = inputs[1].get<std::string>(); |
| 588 | + constant_name = get_base_layer_name(constant_name); |
| 589 | + |
| 590 | + if (layer_parameters.count(constant_name)) { |
| 591 | + splits = layer_parameters[constant_name]; |
| 592 | + } else if (constant_name.find("onnx::") != constant_name.npos) { |
| 593 | + splits = last_constant_value; |
| 594 | + layer_parameters[constant_name] = last_constant_value; |
| 595 | + } |
| 596 | + } |
| 597 | + } |
571 | 598 | if (layer_data.contains("weights") &&
|
572 | 599 | layer_data["weights"].is_array()) {
|
573 | 600 | for (const auto& s : layer_data["weights"]) {
|
@@ -642,6 +669,10 @@ void build_graph(it_lab_ai::Tensor& input, it_lab_ai::Tensor& output,
|
642 | 669 | std::make_shared<it_lab_ai::EWLayer>(ew_operation, value, 0.0f);
|
643 | 670 | ew_layer->setName(it_lab_ai::kElementWise);
|
644 | 671 | layer = ew_layer;
|
| 672 | + /*if (comments) { |
| 673 | + std::cout << "Created binary " << layer_type << " operation with " |
| 674 | + << value <<"scalar" << std::endl; |
| 675 | + }*/ |
645 | 676 | } else if (layer_type == "Add") {
|
646 | 677 | ew_operation = "linear";
|
647 | 678 | auto ew_layer =
|
@@ -676,11 +707,7 @@ void build_graph(it_lab_ai::Tensor& input, it_lab_ai::Tensor& output,
|
676 | 707 | auto bin_layer = std::make_shared<it_lab_ai::BinaryOpLayer>(op);
|
677 | 708 | bin_layer->setName(it_lab_ai::kBinaryOp);
|
678 | 709 | layer = bin_layer;
|
679 |
| - |
680 |
| - if (comments) { |
681 |
| - std::cout << "Created binary " << layer_type |
682 |
| - << " operation with tensor inputs" << std::endl; |
683 |
| - } |
| 710 | + |
684 | 711 | }
|
685 | 712 | } else if (layer_type == "Gemm") {
|
686 | 713 | it_lab_ai::Tensor tensor = it_lab_ai::create_tensor_from_json(
|
@@ -1019,6 +1046,8 @@ void build_graph(it_lab_ai::Tensor& input, it_lab_ai::Tensor& output,
|
1019 | 1046 | for (const auto& input_name : layer_data["inputs"]) {
|
1020 | 1047 | std::string input_tensor = input_name.get<std::string>();
|
1021 | 1048 |
|
| 1049 | + |
| 1050 | + |
1022 | 1051 | // Проверяем, является ли вход выходом сплит-слоя
|
1023 | 1052 | std::regex split_output_pattern("(.+)_output_(\\d+)$");
|
1024 | 1053 | std::smatch matches;
|
@@ -1206,50 +1235,78 @@ void build_graph(it_lab_ai::Tensor& input, it_lab_ai::Tensor& output,
|
1206 | 1235 | /*if (comments) {
|
1207 | 1236 | std::cout << "\n=== ESTABLISHING CONNECTIONS ===" << std::endl;
|
1208 | 1237 | }*/
|
| 1238 | + std::vector<int> order = {}; |
1209 | 1239 |
|
1210 | 1240 | for (const auto& [source_name, target_name] : connection_list) {
|
1211 |
| - // Убираем проверку на сплит-выходы - они тоже должны быть подключены |
1212 |
| - |
1213 | 1241 | if (name_to_layer.count(source_name) && name_to_layer.count(target_name)) {
|
| 1242 | + // Обработка Concat слоев |
| 1243 | + if (target_name.find("Concat") != std::string::npos || |
| 1244 | + name_to_layer[target_name]->getName() == it_lab_ai::kConcat) { |
| 1245 | + // Проверяем, есть ли этот concat в нашем списке |
| 1246 | + if (concat_connections.find(target_name) != concat_connections.end()) { |
| 1247 | + // Находим индекс этого источника в ожидаемых входах concat |
| 1248 | + const auto& expected_inputs = concat_connections[target_name]; |
| 1249 | + auto it = std::find(expected_inputs.begin(), expected_inputs.end(), |
| 1250 | + source_name); |
| 1251 | + |
| 1252 | + if (it != expected_inputs.end()) { |
| 1253 | + int input_index = static_cast<int>(std::distance(expected_inputs.begin(), it)); |
| 1254 | + |
| 1255 | + // Добавляем индекс в порядок для этого concat |
| 1256 | + concat_orders[target_name].push_back(input_index); |
| 1257 | + |
| 1258 | + // Отмечаем, что этот вход подключен |
| 1259 | + concat_connected_inputs[target_name].insert(source_name); |
| 1260 | + |
| 1261 | + if (comments) { |
| 1262 | + std::cout << "Concat connection: " << source_name << " -> " |
| 1263 | + << target_name << " (index: " << input_index << ")" |
| 1264 | + << std::endl; |
| 1265 | + } |
| 1266 | + |
| 1267 | + // Проверяем, все ли входы подключены |
| 1268 | + if (concat_connected_inputs[target_name].size() == |
| 1269 | + concat_connections[target_name].size()) { |
| 1270 | + // Все входы подключены - устанавливаем порядок |
| 1271 | + auto concat_layer = |
| 1272 | + std::dynamic_pointer_cast<it_lab_ai::ConcatLayer>( |
| 1273 | + name_to_layer[target_name]); |
| 1274 | + if (concat_layer) { |
| 1275 | + concat_layer->setInputOrder(concat_orders[target_name]); |
| 1276 | + |
| 1277 | + if (comments) { |
| 1278 | + std::cout |
| 1279 | + << "=== ALL INPUTS CONNECTED TO CONCAT: " << target_name |
| 1280 | + << " ===" << std::endl; |
| 1281 | + std::cout << "Expected inputs: "; |
| 1282 | + for (const auto& inp : concat_connections[target_name]) { |
| 1283 | + std::cout << inp << " "; |
| 1284 | + } |
| 1285 | + std::cout << std::endl; |
| 1286 | + |
| 1287 | + std::cout << "Actual order: "; |
| 1288 | + for (size_t i = 0; i < concat_orders[target_name].size(); |
| 1289 | + ++i) { |
| 1290 | + std::cout << concat_orders[target_name][i]; |
| 1291 | + if (i < concat_orders[target_name].size() - 1) |
| 1292 | + std::cout << ", "; |
| 1293 | + } |
| 1294 | + std::cout << std::endl; |
| 1295 | + } |
| 1296 | + } |
| 1297 | + } |
| 1298 | + } |
| 1299 | + } |
| 1300 | + } |
| 1301 | + |
1214 | 1302 | try {
|
1215 |
| - //if (comments) { |
1216 |
| - // std::cout << "Connecting: " << source_name << " -> " << target_name; |
1217 |
| - // std::cout << " (ID: " << name_to_layer[source_name]->getID() |
1218 |
| - // << " -> ID: " << name_to_layer[target_name]->getID() << ")" |
1219 |
| - // << std::endl; |
1220 |
| - |
1221 |
| - // // Дополнительная информация для сплит-соединений |
1222 |
| - // std::regex split_output_pattern("(.+)_output_(\\d+)$"); |
1223 |
| - // std::smatch matches; |
1224 |
| - // if (std::regex_search(source_name, matches, split_output_pattern)) { |
1225 |
| - // std::string split_layer_name = matches[1].str(); |
1226 |
| - // int output_index = std::stoi(matches[2].str()); |
1227 |
| - // std::cout << " [SPLIT] Output index: " << output_index |
1228 |
| - // << std::endl; |
1229 |
| - // } |
1230 |
| - //} |
1231 | 1303 | graph.makeConnection(*name_to_layer[source_name],
|
1232 | 1304 | *name_to_layer[target_name]);
|
1233 |
| - /*if (comments) { |
1234 |
| - std::cout << " Success" << std::endl; |
1235 |
| - }*/ |
| 1305 | + |
1236 | 1306 | } catch (const std::exception& e) {
|
1237 | 1307 | std::cerr << "Failed: " << source_name << " -> " << target_name << " : "
|
1238 | 1308 | << e.what() << std::endl;
|
1239 | 1309 | }
|
1240 |
| - } else { |
1241 |
| - /*if (comments) { |
1242 |
| - std::cerr << "Warning: Missing layer for connection " << source_name |
1243 |
| - << " -> " << target_name << std::endl; |
1244 |
| - if (!name_to_layer.count(source_name)) { |
1245 |
| - std::cerr << " Source layer '" << source_name << "' not found" |
1246 |
| - << std::endl; |
1247 |
| - } |
1248 |
| - if (!name_to_layer.count(target_name)) { |
1249 |
| - std::cerr << " Target layer '" << target_name << "' not found" |
1250 |
| - << std::endl; |
1251 |
| - } |
1252 |
| - }*/ |
1253 | 1310 | }
|
1254 | 1311 | }
|
1255 | 1312 | for (auto& split_dist : split_distribution) {
|
|
0 commit comments