add support for sd.cpp new/merged openai compatible api
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.16 FATAL_ERROR)
|
||||||
|
|
||||||
# cmake setup begin
|
# cmake setup begin
|
||||||
project(solanaceae_sdbot-webui)
|
project(solanaceae_sdbot-webui)
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.16 FATAL_ERROR)
|
||||||
|
|
||||||
add_library(solanaceae_sdbot-webui STATIC
|
add_library(solanaceae_sdbot-webui STATIC
|
||||||
./webapi_interface.hpp
|
./webapi_interface.hpp
|
||||||
./webapi_sdcpp_stduhpf_wip2.hpp
|
./webapi_sdcpp_stduhpf_wip2.hpp
|
||||||
./webapi_sdcpp_stduhpf_wip2.cpp
|
./webapi_sdcpp_stduhpf_wip2.cpp
|
||||||
|
./webapi_sdcpp_openai.hpp
|
||||||
|
./webapi_sdcpp_openai.cpp
|
||||||
|
|
||||||
./sd_bot.hpp
|
./sd_bot.hpp
|
||||||
./sd_bot.cpp
|
./sd_bot.cpp
|
||||||
|
|||||||
@@ -7,10 +7,10 @@
|
|||||||
#include <solanaceae/message3/components.hpp>
|
#include <solanaceae/message3/components.hpp>
|
||||||
|
|
||||||
#include "./webapi_sdcpp_stduhpf_wip2.hpp"
|
#include "./webapi_sdcpp_stduhpf_wip2.hpp"
|
||||||
|
#include "./webapi_sdcpp_openai.hpp"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <chrono>
|
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
@@ -24,7 +24,7 @@ SDBot::SDBot(
|
|||||||
_rng.discard(3137);
|
_rng.discard(3137);
|
||||||
|
|
||||||
if (!_conf.has_string("SDBot", "endpoint_type")) {
|
if (!_conf.has_string("SDBot", "endpoint_type")) {
|
||||||
_conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_stduhpf_wip2"}); // default
|
_conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_openai"}); // default
|
||||||
}
|
}
|
||||||
|
|
||||||
//HACKy
|
//HACKy
|
||||||
@@ -32,6 +32,8 @@ SDBot::SDBot(
|
|||||||
const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value();
|
const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value();
|
||||||
if (endpoint_type == "sdcpp_stduhpf_wip2") {
|
if (endpoint_type == "sdcpp_stduhpf_wip2") {
|
||||||
_endpoint = std::make_unique<WebAPI_sdcpp_stduhpf_wip2>(_conf);
|
_endpoint = std::make_unique<WebAPI_sdcpp_stduhpf_wip2>(_conf);
|
||||||
|
} else if (endpoint_type == "sdcpp_openai") {
|
||||||
|
_endpoint = std::make_unique<WebAPI_sdcpp_openai>(_conf);
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("missing endpoint type in config, cant continue!");
|
throw std::runtime_error("missing endpoint type in config, cant continue!");
|
||||||
}
|
}
|
||||||
@@ -42,24 +44,13 @@ SDBot::SDBot(
|
|||||||
if (!_conf.has_string("SDBot", "server_host")) {
|
if (!_conf.has_string("SDBot", "server_host")) {
|
||||||
_conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"});
|
_conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"});
|
||||||
}
|
}
|
||||||
if (!_conf.has_int("SDBot", "server_port")) {
|
|
||||||
_conf.set("SDBot", "server_port", int64_t(7860)); // automatic11 default
|
|
||||||
}
|
|
||||||
if (!_conf.has_string("SDBot", "url_txt2img")) {
|
|
||||||
_conf.set("SDBot", "url_txt2img", std::string_view{"/sdapi/v1/txt2img"}); // automatic11 default
|
|
||||||
}
|
|
||||||
if (!_conf.has_int("SDBot", "width")) {
|
if (!_conf.has_int("SDBot", "width")) {
|
||||||
_conf.set("SDBot", "width", int64_t(512));
|
_conf.set("SDBot", "width", int64_t(512));
|
||||||
}
|
}
|
||||||
if (!_conf.has_int("SDBot", "height")) {
|
if (!_conf.has_int("SDBot", "height")) {
|
||||||
_conf.set("SDBot", "height", int64_t(512));
|
_conf.set("SDBot", "height", int64_t(512));
|
||||||
}
|
}
|
||||||
if (!_conf.has_int("SDBot", "steps")) {
|
|
||||||
_conf.set("SDBot", "steps", int64_t(20));
|
|
||||||
}
|
|
||||||
if (!_conf.has_double("SDBot", "cfg_scale") && !_conf.has_int("SDBot", "cfg_scale")) {
|
|
||||||
_conf.set("SDBot", "cfg_scale", 6.5);
|
|
||||||
}
|
|
||||||
|
|
||||||
_rmm_sr.subscribe(RegistryMessageModel_Event::message_construct);
|
_rmm_sr.subscribe(RegistryMessageModel_Event::message_construct);
|
||||||
}
|
}
|
||||||
|
|||||||
170
src/webapi_sdcpp_openai.cpp
Normal file
170
src/webapi_sdcpp_openai.cpp
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
#include "./webapi_sdcpp_openai.hpp"
|
||||||
|
|
||||||
|
#include <httplib.h>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include <sodium.h>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <chrono>
|
||||||
|
#include <future>
|
||||||
|
|
||||||
|
std::shared_ptr<httplib::Client> WebAPI_sdcpp_openai::getCl(void) {
|
||||||
|
if (_cl == nullptr) {
|
||||||
|
const std::string server_host {_conf.get_string("SDBot", "server_host").value_or("127.0.0.1")};
|
||||||
|
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value_or(1234));
|
||||||
|
_cl->set_read_timeout(std::chrono::minutes(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
return _cl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
WebAPI_sdcpp_openai::WebAPI_sdcpp_openai(ConfigModelI& conf) :
|
||||||
|
_conf(conf)
|
||||||
|
{
|
||||||
|
if (!_conf.has_int("SDBot", "server_port")) {
|
||||||
|
_conf.set("SDBot", "server_port", int64_t(1234));
|
||||||
|
}
|
||||||
|
if (!_conf.has_string("SDBot", "url_txt2img")) {
|
||||||
|
_conf.set("SDBot", "url_txt2img", std::string_view{"/v1/images/generations"});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
WebAPI_sdcpp_openai::~WebAPI_sdcpp_openai(void) {
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_openai::txt2img(
|
||||||
|
std::string_view prompt,
|
||||||
|
int16_t width,
|
||||||
|
int16_t height
|
||||||
|
// more
|
||||||
|
) {
|
||||||
|
auto cl = getCl();
|
||||||
|
if (!cl) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::string body;
|
||||||
|
nlohmann::json j_body;
|
||||||
|
// eg
|
||||||
|
//{
|
||||||
|
// "model": "unused",
|
||||||
|
// "prompt": "A lovely cat<sd_cpp_extra_args>{\"seed\": 357925}</sd_cpp_extra_args>",
|
||||||
|
// "n": 1,
|
||||||
|
// "size": "128x128",
|
||||||
|
// "response_format": "b64_json"
|
||||||
|
//}
|
||||||
|
try {
|
||||||
|
j_body["model"] = "unused";
|
||||||
|
j_body["size"] = std::to_string(width) + "x" + std::to_string(height);
|
||||||
|
j_body["n"] = 1;
|
||||||
|
j_body["response_format"] = "b64_json";
|
||||||
|
|
||||||
|
nlohmann::json j_sd_extra_args;
|
||||||
|
j_sd_extra_args["seed"] = -1;
|
||||||
|
if (_conf.has_int("SDBot", "steps")) {
|
||||||
|
j_sd_extra_args["steps"] = _conf.get_int("SDBot", "steps").value();
|
||||||
|
}
|
||||||
|
if (_conf.has_double("SDBot", "cfg_scale")) {
|
||||||
|
j_sd_extra_args["cfg_scale"] = _conf.get_double("SDBot", "steps").value();
|
||||||
|
} else if (_conf.has_int("SDBot", "cfg_scale")) {
|
||||||
|
j_sd_extra_args["cfg_scale"] = _conf.get_int("SDBot", "steps").value();
|
||||||
|
}
|
||||||
|
|
||||||
|
j_body["prompt"] =
|
||||||
|
std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")}
|
||||||
|
+ std::string{prompt}
|
||||||
|
+ "<sd_cpp_extra_args>"
|
||||||
|
+ j_sd_extra_args.dump(-1, ' ', true)
|
||||||
|
+ "</sd_cpp_extra_args>"
|
||||||
|
;
|
||||||
|
|
||||||
|
body = j_body.dump();
|
||||||
|
} catch (...) {
|
||||||
|
std::cerr << "SDB error: failed creating body json\n";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
//std::cout << "body: " << body << "\n";
|
||||||
|
|
||||||
|
try {
|
||||||
|
const std::string url {_conf.get_string("SDBot", "url_txt2img").value_or("/v1/images/generations")};
|
||||||
|
return std::make_shared<WebAPITask_sdcpp_openai>(url, body, getCl());
|
||||||
|
} catch (...) {
|
||||||
|
std::cerr << "SDB error: scheduling post to sd server failed!\n";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
WebAPITask_sdcpp_openai::WebAPITask_sdcpp_openai(const std::string& url, const std::string& body, std::shared_ptr<httplib::Client> cl) {
|
||||||
|
_future = std::async(std::launch::async, [url, body, cl]() -> WebAPITaskI::Result {
|
||||||
|
if (!static_cast<bool>(cl)) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
auto res = cl->Post(url, body, "application/json");
|
||||||
|
if (!static_cast<bool>(res)) {
|
||||||
|
std::cerr << "SDB error: post to sd server failed!\n";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
std::cerr << "SDB: http complete " << res->status << " " << res->reason << "\n";
|
||||||
|
|
||||||
|
if (
|
||||||
|
res.error() != httplib::Error::Success ||
|
||||||
|
res->status != 200
|
||||||
|
) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
//std::cerr << "------ res body: " << res->body << "\n";
|
||||||
|
|
||||||
|
const auto j_res = nlohmann::json::parse(res->body);
|
||||||
|
|
||||||
|
// TODO: add support for multiple images
|
||||||
|
auto& j_image = j_res.at("data").at(0);
|
||||||
|
|
||||||
|
Result result;
|
||||||
|
|
||||||
|
auto& j_data = j_image.at("b64_json");
|
||||||
|
result.data.resize(j_data.get<std::string>().size()); // HACK: do a better estimate
|
||||||
|
size_t decoded_size {0};
|
||||||
|
sodium_base642bin(
|
||||||
|
result.data.data(), result.data.size(),
|
||||||
|
j_data.get<std::string>().data(), j_data.get<std::string>().size(),
|
||||||
|
" \n\t",
|
||||||
|
&decoded_size,
|
||||||
|
nullptr,
|
||||||
|
sodium_base64_VARIANT_ORIGINAL
|
||||||
|
);
|
||||||
|
result.data.resize(decoded_size);
|
||||||
|
|
||||||
|
result.file_name = std::string{"output_0."} + j_image.value("output_format", "png");
|
||||||
|
|
||||||
|
return result;
|
||||||
|
} catch (...) {
|
||||||
|
std::cerr << "SDB error: post to sd server failed!\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!_future.valid()) {
|
||||||
|
throw std::runtime_error("failed to create future");
|
||||||
|
std::cerr << "SDB error: scheduling post to sd server failed, invalid future returned.\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WebAPITask_sdcpp_openai::ready(void) {
|
||||||
|
return _future.wait_for(std::chrono::microseconds(10)) == std::future_status::ready;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<WebAPITaskI::Result> WebAPITask_sdcpp_openai::get(void) {
|
||||||
|
return _future.get();
|
||||||
|
}
|
||||||
|
|
||||||
47
src/webapi_sdcpp_openai.hpp
Normal file
47
src/webapi_sdcpp_openai.hpp
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "./webapi_interface.hpp"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <solanaceae/util/config_model.hpp>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <future>
|
||||||
|
|
||||||
|
// fwd
|
||||||
|
namespace httplib {
|
||||||
|
class Client;
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is supposedly an openai spec compatible one, in master
|
||||||
|
struct WebAPI_sdcpp_openai : public WebAPII {
|
||||||
|
// TODO: const config
|
||||||
|
ConfigModelI& _conf;
|
||||||
|
|
||||||
|
std::shared_ptr<httplib::Client> _cl;
|
||||||
|
std::shared_ptr<httplib::Client> getCl(void);
|
||||||
|
|
||||||
|
WebAPI_sdcpp_openai(ConfigModelI& conf);
|
||||||
|
|
||||||
|
~WebAPI_sdcpp_openai(void) override;
|
||||||
|
|
||||||
|
std::shared_ptr<WebAPITaskI> txt2img(
|
||||||
|
std::string_view prompt,
|
||||||
|
int16_t width,
|
||||||
|
int16_t height
|
||||||
|
// more
|
||||||
|
) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
// discard after get() !!
|
||||||
|
struct WebAPITask_sdcpp_openai : public WebAPITaskI {
|
||||||
|
std::future<Result> _future;
|
||||||
|
|
||||||
|
WebAPITask_sdcpp_openai(const std::string& url, const std::string& body, std::shared_ptr<httplib::Client> cl);
|
||||||
|
~WebAPITask_sdcpp_openai(void) override {}
|
||||||
|
|
||||||
|
bool ready(void) override;
|
||||||
|
std::optional<Result> get(void) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -11,8 +11,8 @@
|
|||||||
|
|
||||||
std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
|
std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
|
||||||
if (_cl == nullptr) {
|
if (_cl == nullptr) {
|
||||||
const std::string server_host {_conf.get_string("SDBot", "server_host").value()};
|
const std::string server_host {_conf.get_string("SDBot", "server_host").value_or("127.0.0.1")};
|
||||||
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value());
|
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value_or(8080));
|
||||||
_cl->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while
|
_cl->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,6 +22,9 @@ std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
|
|||||||
WebAPI_sdcpp_stduhpf_wip2::WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf) :
|
WebAPI_sdcpp_stduhpf_wip2::WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf) :
|
||||||
_conf(conf)
|
_conf(conf)
|
||||||
{
|
{
|
||||||
|
if (!_conf.has_int("SDBot", "server_port")) {
|
||||||
|
_conf.set("SDBot", "server_port", int64_t(8080));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
WebAPI_sdcpp_stduhpf_wip2::~WebAPI_sdcpp_stduhpf_wip2(void) {
|
WebAPI_sdcpp_stduhpf_wip2::~WebAPI_sdcpp_stduhpf_wip2(void) {
|
||||||
@@ -67,9 +70,7 @@ std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img(
|
|||||||
// "vae_tiling": true,
|
// "vae_tiling": true,
|
||||||
// "tae_decode": true
|
// "tae_decode": true
|
||||||
//}
|
//}
|
||||||
//j_body["width"] = _conf.get_int("SDBot", "width").value_or(512);
|
|
||||||
j_body["width"] = width;
|
j_body["width"] = width;
|
||||||
//j_body["height"] = _conf.get_int("SDBot", "height").value_or(512);
|
|
||||||
j_body["height"] = height;
|
j_body["height"] = height;
|
||||||
|
|
||||||
j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + std::string{prompt};
|
j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + std::string{prompt};
|
||||||
@@ -92,7 +93,7 @@ std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cerr << "SDB http complete " << res->status << " " << res->reason << "\n";
|
std::cerr << "SDB: http complete " << res->status << " " << res->reason << "\n";
|
||||||
|
|
||||||
if (
|
if (
|
||||||
res.error() != httplib::Error::Success ||
|
res.error() != httplib::Error::Success ||
|
||||||
|
|||||||
Reference in New Issue
Block a user