diff --git a/src/sd_bot.cpp b/src/sd_bot.cpp index 9c586be..198740d 100644 --- a/src/sd_bot.cpp +++ b/src/sd_bot.cpp @@ -11,9 +11,124 @@ #include #include #include -#include + #include +struct Automatic1111_v1_Endpoint : public SDBot::EndpointI { + Automatic1111_v1_Endpoint(RegistryMessageModel& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {} + + bool handleResponse(Contact3 contact, ByteSpan data) override { + //std::cout << std::string_view{reinterpret_cast(data.ptr), data.size} << "\n"; + + // extract json result + const auto j = nlohmann::json::parse( + std::string_view{reinterpret_cast(data.ptr), data.size}, + nullptr, + false + ); + //std::cout << "json dump: " << j.dump() << "\n"; + + if (j.count("images") && !j.at("images").empty() && j.at("images").is_array()) { + for (const auto& i_j : j.at("images").items()) { + // decode data (base64) + std::vector png_data(data.size); // just init to upper bound + size_t decoded_size {0}; + sodium_base642bin( + png_data.data(), png_data.size(), + i_j.value().get().data(), i_j.value().get().size(), + " \n\t", + &decoded_size, + nullptr, + sodium_base64_VARIANT_ORIGINAL + ); + png_data.resize(decoded_size); + + std::filesystem::create_directories("sdbot_img_send"); + //const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; + const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; + const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; + + std::ofstream(tmp_img_file_path).write(reinterpret_cast(png_data.data()), png_data.size()); + _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); + } + } else { + std::cerr << "SDB json response did not contain images?\n"; + return false; + } + + return true; + } +}; + +struct SDcpp_wip1_Endpoint : public SDBot::EndpointI { + SDcpp_wip1_Endpoint(RegistryMessageModel& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {} + + bool handleResponse(Contact3 contact, ByteSpan data) override { + //std::cout << std::string_view{reinterpret_cast(data.ptr), data.size} << "\n"; + + std::string_view data_str {reinterpret_cast(data.ptr), data.size}; + auto nl_pos {std::string_view::npos}; + bool succ {false}; + do { + // for each line, should be "data: " or empty + nl_pos = data_str.find_first_of('\n'); + + // npos is also valid + auto line = data_str.substr(0, nl_pos); + + // at least minimum viable + if (line.size() >= std::string_view{"data: {}"}.size()) { + //std::cout << "got data line!!!!!!!!!!!:\n"; + //std::cout << line << "\n"; + line.remove_prefix(6); + + const auto j = nlohmann::json::parse( + line, + nullptr, + false + ); + + if ( + !j.empty() && + j.value("type", "notimag") == "image" && + j.contains("data") && + j.at("data").is_string() + ) { + const auto& img_data_str = j.at("data").get(); + // decode data (base64) + std::vector png_data(img_data_str.size()); // just init to upper bound + size_t decoded_size {0}; + sodium_base642bin( + png_data.data(), png_data.size(), + img_data_str.data(), img_data_str.size(), + " \n\t", + &decoded_size, + nullptr, + sodium_base64_VARIANT_ORIGINAL + ); + png_data.resize(decoded_size); + + std::filesystem::create_directories("sdbot_img_send"); + //const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; + const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; + const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; + + std::ofstream(tmp_img_file_path).write(reinterpret_cast(png_data.data()), png_data.size()); + succ = _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); + } + } + + if (nl_pos == std::string_view::npos || nl_pos+1 >= data_str.size()) { + break; + } + + data_str = data_str.substr(nl_pos+1); + } while (nl_pos != std::string_view::npos); + + return succ; + } +}; + SDBot::SDBot( Contact3Registry& cr, RegistryMessageModel& rmm, @@ -22,6 +137,26 @@ SDBot::SDBot( _rng.seed(std::random_device{}()); _rng.discard(3137); + if (!_conf.has_string("SDBot", "endpoint_type")) { + _conf.set("SDBot", "endpoint_type", std::string_view{"automatic1111_v1"}); // automatic11 default + } + + //HACKy + { // construct endpoint + const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value(); + if (endpoint_type == "automatic1111_v1") { + _endpoint = std::make_unique(_rmm, _rng); + } else if (endpoint_type == "sdcpp_wip1") { + _endpoint = std::make_unique(_rmm, _rng); + } else { + std::cerr << "SDB error: unknown endpoint type '" << endpoint_type << "'\n"; + // TODO: throw? + _endpoint = std::make_unique(_rmm, _rng); + } + } + + // TODO: use defaults based on endpoint_type + if (!_conf.has_string("SDBot", "server_host")) { _conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"}); } @@ -219,44 +354,10 @@ void SDBot::onHttpComplete(const happyhttp::Response* r) { std::cout << "SDB http complete " << r->getstatus() << " " << r->getreason() << "\n"; if (r->getstatus() == happyhttp::OK) { std::cout << "SDB data\n"; - //std::cout << std::string_view{reinterpret_cast(_con_data.data()), _con_data.size()} << "\n"; - // extract json result - const auto j = nlohmann::json::parse( - std::string_view{reinterpret_cast(_con_data.data()), _con_data.size()}, - nullptr, - false - ); - //std::cout << "json dump: " << j.dump() << "\n"; - - if (j.count("images") && !j.at("images").empty() && j.at("images").is_array()) { - for (const auto& i_j : j.at("images").items()) { - // decode data (base64) - std::vector png_data(_con_data.size()); // just init to upper bound - size_t decoded_size {0}; - sodium_base642bin( - png_data.data(), png_data.size(), - i_j.value().get().data(), i_j.value().get().size(), - " \n\t", - &decoded_size, - nullptr, - sodium_base64_VARIANT_ORIGINAL - ); - png_data.resize(decoded_size); - - // hand png to download manager - const auto& contact = _task_map.at(_current_task.value()); - - std::filesystem::create_directories("sdbot_img_send"); - //const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; - const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; - const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; - - std::ofstream(tmp_img_file_path).write(reinterpret_cast(png_data.data()), png_data.size()); - _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); - } - } else { - std::cerr << "SDB json response did not contain images?\n"; + const auto& contact = _task_map.at(_current_task.value()); + if (_endpoint->handleResponse(contact, ByteSpan{_con_data})) { + // error? } _task_map.erase(_current_task.value()); diff --git a/src/sd_bot.hpp b/src/sd_bot.hpp index d43a33d..08f692b 100644 --- a/src/sd_bot.hpp +++ b/src/sd_bot.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -33,6 +34,19 @@ class SDBot : public RegistryMessageModelEventI { std::default_random_engine _rng; + public: + struct EndpointI { + RegistryMessageModel& _rmm; + std::default_random_engine& _rng; + EndpointI(RegistryMessageModel& rmm, std::default_random_engine& rng) : _rmm(rmm), _rng(rng) {} + virtual ~EndpointI(void) {} + + virtual bool handleResponse(Contact3 contact, ByteSpan data) = 0; + }; + + private: + std::unique_ptr _endpoint; + public: SDBot( Contact3Registry& cr,