hacked in endpoint for sd.cpp wip
This commit is contained in:
parent
5490509fc5
commit
fba8417ddf
177
src/sd_bot.cpp
177
src/sd_bot.cpp
@ -11,9 +11,124 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <cstdint>
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
|
struct Automatic1111_v1_Endpoint : public SDBot::EndpointI {
|
||||||
|
Automatic1111_v1_Endpoint(RegistryMessageModel& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {}
|
||||||
|
|
||||||
|
bool handleResponse(Contact3 contact, ByteSpan data) override {
|
||||||
|
//std::cout << std::string_view{reinterpret_cast<const char*>(data.ptr), data.size} << "\n";
|
||||||
|
|
||||||
|
// extract json result
|
||||||
|
const auto j = nlohmann::json::parse(
|
||||||
|
std::string_view{reinterpret_cast<const char*>(data.ptr), data.size},
|
||||||
|
nullptr,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
//std::cout << "json dump: " << j.dump() << "\n";
|
||||||
|
|
||||||
|
if (j.count("images") && !j.at("images").empty() && j.at("images").is_array()) {
|
||||||
|
for (const auto& i_j : j.at("images").items()) {
|
||||||
|
// decode data (base64)
|
||||||
|
std::vector<uint8_t> png_data(data.size); // just init to upper bound
|
||||||
|
size_t decoded_size {0};
|
||||||
|
sodium_base642bin(
|
||||||
|
png_data.data(), png_data.size(),
|
||||||
|
i_j.value().get<std::string>().data(), i_j.value().get<std::string>().size(),
|
||||||
|
" \n\t",
|
||||||
|
&decoded_size,
|
||||||
|
nullptr,
|
||||||
|
sodium_base64_VARIANT_ORIGINAL
|
||||||
|
);
|
||||||
|
png_data.resize(decoded_size);
|
||||||
|
|
||||||
|
std::filesystem::create_directories("sdbot_img_send");
|
||||||
|
//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
|
||||||
|
const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
|
||||||
|
const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
|
||||||
|
|
||||||
|
std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
|
||||||
|
_rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::cerr << "SDB json response did not contain images?\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SDcpp_wip1_Endpoint : public SDBot::EndpointI {
|
||||||
|
SDcpp_wip1_Endpoint(RegistryMessageModel& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {}
|
||||||
|
|
||||||
|
bool handleResponse(Contact3 contact, ByteSpan data) override {
|
||||||
|
//std::cout << std::string_view{reinterpret_cast<const char*>(data.ptr), data.size} << "\n";
|
||||||
|
|
||||||
|
std::string_view data_str {reinterpret_cast<const char*>(data.ptr), data.size};
|
||||||
|
auto nl_pos {std::string_view::npos};
|
||||||
|
bool succ {false};
|
||||||
|
do {
|
||||||
|
// for each line, should be "data: <json>" or empty
|
||||||
|
nl_pos = data_str.find_first_of('\n');
|
||||||
|
|
||||||
|
// npos is also valid
|
||||||
|
auto line = data_str.substr(0, nl_pos);
|
||||||
|
|
||||||
|
// at least minimum viable
|
||||||
|
if (line.size() >= std::string_view{"data: {}"}.size()) {
|
||||||
|
//std::cout << "got data line!!!!!!!!!!!:\n";
|
||||||
|
//std::cout << line << "\n";
|
||||||
|
line.remove_prefix(6);
|
||||||
|
|
||||||
|
const auto j = nlohmann::json::parse(
|
||||||
|
line,
|
||||||
|
nullptr,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
if (
|
||||||
|
!j.empty() &&
|
||||||
|
j.value("type", "notimag") == "image" &&
|
||||||
|
j.contains("data") &&
|
||||||
|
j.at("data").is_string()
|
||||||
|
) {
|
||||||
|
const auto& img_data_str = j.at("data").get<std::string>();
|
||||||
|
// decode data (base64)
|
||||||
|
std::vector<uint8_t> png_data(img_data_str.size()); // just init to upper bound
|
||||||
|
size_t decoded_size {0};
|
||||||
|
sodium_base642bin(
|
||||||
|
png_data.data(), png_data.size(),
|
||||||
|
img_data_str.data(), img_data_str.size(),
|
||||||
|
" \n\t",
|
||||||
|
&decoded_size,
|
||||||
|
nullptr,
|
||||||
|
sodium_base64_VARIANT_ORIGINAL
|
||||||
|
);
|
||||||
|
png_data.resize(decoded_size);
|
||||||
|
|
||||||
|
std::filesystem::create_directories("sdbot_img_send");
|
||||||
|
//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
|
||||||
|
const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
|
||||||
|
const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
|
||||||
|
|
||||||
|
std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
|
||||||
|
succ = _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nl_pos == std::string_view::npos || nl_pos+1 >= data_str.size()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
data_str = data_str.substr(nl_pos+1);
|
||||||
|
} while (nl_pos != std::string_view::npos);
|
||||||
|
|
||||||
|
return succ;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
SDBot::SDBot(
|
SDBot::SDBot(
|
||||||
Contact3Registry& cr,
|
Contact3Registry& cr,
|
||||||
RegistryMessageModel& rmm,
|
RegistryMessageModel& rmm,
|
||||||
@ -22,6 +137,26 @@ SDBot::SDBot(
|
|||||||
_rng.seed(std::random_device{}());
|
_rng.seed(std::random_device{}());
|
||||||
_rng.discard(3137);
|
_rng.discard(3137);
|
||||||
|
|
||||||
|
if (!_conf.has_string("SDBot", "endpoint_type")) {
|
||||||
|
_conf.set("SDBot", "endpoint_type", std::string_view{"automatic1111_v1"}); // automatic11 default
|
||||||
|
}
|
||||||
|
|
||||||
|
//HACKy
|
||||||
|
{ // construct endpoint
|
||||||
|
const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value();
|
||||||
|
if (endpoint_type == "automatic1111_v1") {
|
||||||
|
_endpoint = std::make_unique<Automatic1111_v1_Endpoint>(_rmm, _rng);
|
||||||
|
} else if (endpoint_type == "sdcpp_wip1") {
|
||||||
|
_endpoint = std::make_unique<SDcpp_wip1_Endpoint>(_rmm, _rng);
|
||||||
|
} else {
|
||||||
|
std::cerr << "SDB error: unknown endpoint type '" << endpoint_type << "'\n";
|
||||||
|
// TODO: throw?
|
||||||
|
_endpoint = std::make_unique<Automatic1111_v1_Endpoint>(_rmm, _rng);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: use defaults based on endpoint_type
|
||||||
|
|
||||||
if (!_conf.has_string("SDBot", "server_host")) {
|
if (!_conf.has_string("SDBot", "server_host")) {
|
||||||
_conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"});
|
_conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"});
|
||||||
}
|
}
|
||||||
@ -219,44 +354,10 @@ void SDBot::onHttpComplete(const happyhttp::Response* r) {
|
|||||||
std::cout << "SDB http complete " << r->getstatus() << " " << r->getreason() << "\n";
|
std::cout << "SDB http complete " << r->getstatus() << " " << r->getreason() << "\n";
|
||||||
if (r->getstatus() == happyhttp::OK) {
|
if (r->getstatus() == happyhttp::OK) {
|
||||||
std::cout << "SDB data\n";
|
std::cout << "SDB data\n";
|
||||||
//std::cout << std::string_view{reinterpret_cast<const char*>(_con_data.data()), _con_data.size()} << "\n";
|
|
||||||
|
|
||||||
// extract json result
|
const auto& contact = _task_map.at(_current_task.value());
|
||||||
const auto j = nlohmann::json::parse(
|
if (_endpoint->handleResponse(contact, ByteSpan{_con_data})) {
|
||||||
std::string_view{reinterpret_cast<const char*>(_con_data.data()), _con_data.size()},
|
// error?
|
||||||
nullptr,
|
|
||||||
false
|
|
||||||
);
|
|
||||||
//std::cout << "json dump: " << j.dump() << "\n";
|
|
||||||
|
|
||||||
if (j.count("images") && !j.at("images").empty() && j.at("images").is_array()) {
|
|
||||||
for (const auto& i_j : j.at("images").items()) {
|
|
||||||
// decode data (base64)
|
|
||||||
std::vector<uint8_t> png_data(_con_data.size()); // just init to upper bound
|
|
||||||
size_t decoded_size {0};
|
|
||||||
sodium_base642bin(
|
|
||||||
png_data.data(), png_data.size(),
|
|
||||||
i_j.value().get<std::string>().data(), i_j.value().get<std::string>().size(),
|
|
||||||
" \n\t",
|
|
||||||
&decoded_size,
|
|
||||||
nullptr,
|
|
||||||
sodium_base64_VARIANT_ORIGINAL
|
|
||||||
);
|
|
||||||
png_data.resize(decoded_size);
|
|
||||||
|
|
||||||
// hand png to download manager
|
|
||||||
const auto& contact = _task_map.at(_current_task.value());
|
|
||||||
|
|
||||||
std::filesystem::create_directories("sdbot_img_send");
|
|
||||||
//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
|
|
||||||
const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
|
|
||||||
const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
|
|
||||||
|
|
||||||
std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
|
|
||||||
_rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
std::cerr << "SDB json response did not contain images?\n";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_task_map.erase(_current_task.value());
|
_task_map.erase(_current_task.value());
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <solanaceae/util/span.hpp>
|
||||||
#include <solanaceae/message3/registry_message_model.hpp>
|
#include <solanaceae/message3/registry_message_model.hpp>
|
||||||
#include <solanaceae/contact/contact_model3.hpp>
|
#include <solanaceae/contact/contact_model3.hpp>
|
||||||
|
|
||||||
@ -33,6 +34,19 @@ class SDBot : public RegistryMessageModelEventI {
|
|||||||
|
|
||||||
std::default_random_engine _rng;
|
std::default_random_engine _rng;
|
||||||
|
|
||||||
|
public:
|
||||||
|
struct EndpointI {
|
||||||
|
RegistryMessageModel& _rmm;
|
||||||
|
std::default_random_engine& _rng;
|
||||||
|
EndpointI(RegistryMessageModel& rmm, std::default_random_engine& rng) : _rmm(rmm), _rng(rng) {}
|
||||||
|
virtual ~EndpointI(void) {}
|
||||||
|
|
||||||
|
virtual bool handleResponse(Contact3 contact, ByteSpan data) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<EndpointI> _endpoint;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
SDBot(
|
SDBot(
|
||||||
Contact3Registry& cr,
|
Contact3Registry& cr,
|
||||||
|
Loading…
Reference in New Issue
Block a user