add support for sd.cpp new/merged openai compatible api

This commit is contained in:
Green Sky
2025-12-16 23:03:26 +01:00
parent c9dd7259f2
commit 9090c60d62
6 changed files with 232 additions and 21 deletions

View File

@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.9 FATAL_ERROR) cmake_minimum_required(VERSION 3.16 FATAL_ERROR)
# cmake setup begin # cmake setup begin
project(solanaceae_sdbot-webui) project(solanaceae_sdbot-webui)

View File

@@ -1,9 +1,11 @@
cmake_minimum_required(VERSION 3.9 FATAL_ERROR) cmake_minimum_required(VERSION 3.16 FATAL_ERROR)
add_library(solanaceae_sdbot-webui STATIC add_library(solanaceae_sdbot-webui STATIC
./webapi_interface.hpp ./webapi_interface.hpp
./webapi_sdcpp_stduhpf_wip2.hpp ./webapi_sdcpp_stduhpf_wip2.hpp
./webapi_sdcpp_stduhpf_wip2.cpp ./webapi_sdcpp_stduhpf_wip2.cpp
./webapi_sdcpp_openai.hpp
./webapi_sdcpp_openai.cpp
./sd_bot.hpp ./sd_bot.hpp
./sd_bot.cpp ./sd_bot.cpp

View File

@@ -7,10 +7,10 @@
#include <solanaceae/message3/components.hpp> #include <solanaceae/message3/components.hpp>
#include "./webapi_sdcpp_stduhpf_wip2.hpp" #include "./webapi_sdcpp_stduhpf_wip2.hpp"
#include "./webapi_sdcpp_openai.hpp"
#include <fstream> #include <fstream>
#include <filesystem> #include <filesystem>
#include <chrono>
#include <iostream> #include <iostream>
#include <stdexcept> #include <stdexcept>
@@ -24,7 +24,7 @@ SDBot::SDBot(
_rng.discard(3137); _rng.discard(3137);
if (!_conf.has_string("SDBot", "endpoint_type")) { if (!_conf.has_string("SDBot", "endpoint_type")) {
_conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_stduhpf_wip2"}); // default _conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_openai"}); // default
} }
//HACKy //HACKy
@@ -32,6 +32,8 @@ SDBot::SDBot(
const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value(); const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value();
if (endpoint_type == "sdcpp_stduhpf_wip2") { if (endpoint_type == "sdcpp_stduhpf_wip2") {
_endpoint = std::make_unique<WebAPI_sdcpp_stduhpf_wip2>(_conf); _endpoint = std::make_unique<WebAPI_sdcpp_stduhpf_wip2>(_conf);
} else if (endpoint_type == "sdcpp_openai") {
_endpoint = std::make_unique<WebAPI_sdcpp_openai>(_conf);
} else { } else {
throw std::runtime_error("missing endpoint type in config, cant continue!"); throw std::runtime_error("missing endpoint type in config, cant continue!");
} }
@@ -42,24 +44,13 @@ SDBot::SDBot(
if (!_conf.has_string("SDBot", "server_host")) { if (!_conf.has_string("SDBot", "server_host")) {
_conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"}); _conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"});
} }
if (!_conf.has_int("SDBot", "server_port")) {
_conf.set("SDBot", "server_port", int64_t(7860)); // automatic11 default
}
if (!_conf.has_string("SDBot", "url_txt2img")) {
_conf.set("SDBot", "url_txt2img", std::string_view{"/sdapi/v1/txt2img"}); // automatic11 default
}
if (!_conf.has_int("SDBot", "width")) { if (!_conf.has_int("SDBot", "width")) {
_conf.set("SDBot", "width", int64_t(512)); _conf.set("SDBot", "width", int64_t(512));
} }
if (!_conf.has_int("SDBot", "height")) { if (!_conf.has_int("SDBot", "height")) {
_conf.set("SDBot", "height", int64_t(512)); _conf.set("SDBot", "height", int64_t(512));
} }
if (!_conf.has_int("SDBot", "steps")) {
_conf.set("SDBot", "steps", int64_t(20));
}
if (!_conf.has_double("SDBot", "cfg_scale") && !_conf.has_int("SDBot", "cfg_scale")) {
_conf.set("SDBot", "cfg_scale", 6.5);
}
_rmm_sr.subscribe(RegistryMessageModel_Event::message_construct); _rmm_sr.subscribe(RegistryMessageModel_Event::message_construct);
} }

170
src/webapi_sdcpp_openai.cpp Normal file
View File

@@ -0,0 +1,170 @@
#include "./webapi_sdcpp_openai.hpp"
#include <httplib.h>
#include <nlohmann/json.hpp>
#include <sodium.h>
#include <optional>
#include <stdexcept>
#include <string>
#include <chrono>
#include <future>
std::shared_ptr<httplib::Client> WebAPI_sdcpp_openai::getCl(void) {
if (_cl == nullptr) {
const std::string server_host {_conf.get_string("SDBot", "server_host").value_or("127.0.0.1")};
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value_or(1234));
_cl->set_read_timeout(std::chrono::minutes(10));
}
return _cl;
}
WebAPI_sdcpp_openai::WebAPI_sdcpp_openai(ConfigModelI& conf) :
_conf(conf)
{
if (!_conf.has_int("SDBot", "server_port")) {
_conf.set("SDBot", "server_port", int64_t(1234));
}
if (!_conf.has_string("SDBot", "url_txt2img")) {
_conf.set("SDBot", "url_txt2img", std::string_view{"/v1/images/generations"});
}
}
WebAPI_sdcpp_openai::~WebAPI_sdcpp_openai(void) {
}
std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_openai::txt2img(
std::string_view prompt,
int16_t width,
int16_t height
// more
) {
auto cl = getCl();
if (!cl) {
return nullptr;
}
std::string body;
nlohmann::json j_body;
// eg
//{
// "model": "unused",
// "prompt": "A lovely cat<sd_cpp_extra_args>{\"seed\": 357925}</sd_cpp_extra_args>",
// "n": 1,
// "size": "128x128",
// "response_format": "b64_json"
//}
try {
j_body["model"] = "unused";
j_body["size"] = std::to_string(width) + "x" + std::to_string(height);
j_body["n"] = 1;
j_body["response_format"] = "b64_json";
nlohmann::json j_sd_extra_args;
j_sd_extra_args["seed"] = -1;
if (_conf.has_int("SDBot", "steps")) {
j_sd_extra_args["steps"] = _conf.get_int("SDBot", "steps").value();
}
if (_conf.has_double("SDBot", "cfg_scale")) {
j_sd_extra_args["cfg_scale"] = _conf.get_double("SDBot", "steps").value();
} else if (_conf.has_int("SDBot", "cfg_scale")) {
j_sd_extra_args["cfg_scale"] = _conf.get_int("SDBot", "steps").value();
}
j_body["prompt"] =
std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")}
+ std::string{prompt}
+ "<sd_cpp_extra_args>"
+ j_sd_extra_args.dump(-1, ' ', true)
+ "</sd_cpp_extra_args>"
;
body = j_body.dump();
} catch (...) {
std::cerr << "SDB error: failed creating body json\n";
return nullptr;
}
//std::cout << "body: " << body << "\n";
try {
const std::string url {_conf.get_string("SDBot", "url_txt2img").value_or("/v1/images/generations")};
return std::make_shared<WebAPITask_sdcpp_openai>(url, body, getCl());
} catch (...) {
std::cerr << "SDB error: scheduling post to sd server failed!\n";
return nullptr;
}
return nullptr;
}
WebAPITask_sdcpp_openai::WebAPITask_sdcpp_openai(const std::string& url, const std::string& body, std::shared_ptr<httplib::Client> cl) {
_future = std::async(std::launch::async, [url, body, cl]() -> WebAPITaskI::Result {
if (!static_cast<bool>(cl)) {
return {};
}
try {
auto res = cl->Post(url, body, "application/json");
if (!static_cast<bool>(res)) {
std::cerr << "SDB error: post to sd server failed!\n";
return {};
}
std::cerr << "SDB: http complete " << res->status << " " << res->reason << "\n";
if (
res.error() != httplib::Error::Success ||
res->status != 200
) {
return {};
}
//std::cerr << "------ res body: " << res->body << "\n";
const auto j_res = nlohmann::json::parse(res->body);
// TODO: add support for multiple images
auto& j_image = j_res.at("data").at(0);
Result result;
auto& j_data = j_image.at("b64_json");
result.data.resize(j_data.get<std::string>().size()); // HACK: do a better estimate
size_t decoded_size {0};
sodium_base642bin(
result.data.data(), result.data.size(),
j_data.get<std::string>().data(), j_data.get<std::string>().size(),
" \n\t",
&decoded_size,
nullptr,
sodium_base64_VARIANT_ORIGINAL
);
result.data.resize(decoded_size);
result.file_name = std::string{"output_0."} + j_image.value("output_format", "png");
return result;
} catch (...) {
std::cerr << "SDB error: post to sd server failed!\n";
}
return {};
});
if (!_future.valid()) {
throw std::runtime_error("failed to create future");
std::cerr << "SDB error: scheduling post to sd server failed, invalid future returned.\n";
}
}
bool WebAPITask_sdcpp_openai::ready(void) {
return _future.wait_for(std::chrono::microseconds(10)) == std::future_status::ready;
}
std::optional<WebAPITaskI::Result> WebAPITask_sdcpp_openai::get(void) {
return _future.get();
}

View File

@@ -0,0 +1,47 @@
#pragma once
#include "./webapi_interface.hpp"
#include <cstdint>
#include <solanaceae/util/config_model.hpp>
#include <memory>
#include <future>
// fwd
namespace httplib {
class Client;
}
// this is supposedly an openai spec compatible one, in master
struct WebAPI_sdcpp_openai : public WebAPII {
// TODO: const config
ConfigModelI& _conf;
std::shared_ptr<httplib::Client> _cl;
std::shared_ptr<httplib::Client> getCl(void);
WebAPI_sdcpp_openai(ConfigModelI& conf);
~WebAPI_sdcpp_openai(void) override;
std::shared_ptr<WebAPITaskI> txt2img(
std::string_view prompt,
int16_t width,
int16_t height
// more
) override;
};
// discard after get() !!
struct WebAPITask_sdcpp_openai : public WebAPITaskI {
std::future<Result> _future;
WebAPITask_sdcpp_openai(const std::string& url, const std::string& body, std::shared_ptr<httplib::Client> cl);
~WebAPITask_sdcpp_openai(void) override {}
bool ready(void) override;
std::optional<Result> get(void) override;
};

View File

@@ -11,8 +11,8 @@
std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) { std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
if (_cl == nullptr) { if (_cl == nullptr) {
const std::string server_host {_conf.get_string("SDBot", "server_host").value()}; const std::string server_host {_conf.get_string("SDBot", "server_host").value_or("127.0.0.1")};
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value()); _cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value_or(8080));
_cl->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while _cl->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while
} }
@@ -22,6 +22,9 @@ std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
WebAPI_sdcpp_stduhpf_wip2::WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf) : WebAPI_sdcpp_stduhpf_wip2::WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf) :
_conf(conf) _conf(conf)
{ {
if (!_conf.has_int("SDBot", "server_port")) {
_conf.set("SDBot", "server_port", int64_t(8080));
}
} }
WebAPI_sdcpp_stduhpf_wip2::~WebAPI_sdcpp_stduhpf_wip2(void) { WebAPI_sdcpp_stduhpf_wip2::~WebAPI_sdcpp_stduhpf_wip2(void) {
@@ -67,9 +70,7 @@ std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img(
// "vae_tiling": true, // "vae_tiling": true,
// "tae_decode": true // "tae_decode": true
//} //}
//j_body["width"] = _conf.get_int("SDBot", "width").value_or(512);
j_body["width"] = width; j_body["width"] = width;
//j_body["height"] = _conf.get_int("SDBot", "height").value_or(512);
j_body["height"] = height; j_body["height"] = height;
j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + std::string{prompt}; j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + std::string{prompt};
@@ -92,7 +93,7 @@ std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img(
return nullptr; return nullptr;
} }
std::cerr << "SDB http complete " << res->status << " " << res->reason << "\n"; std::cerr << "SDB: http complete " << res->status << " " << res->reason << "\n";
if ( if (
res.error() != httplib::Error::Success || res.error() != httplib::Error::Success ||