add support for sd.cpp new/merged openai compatible api
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
|
||||
cmake_minimum_required(VERSION 3.16 FATAL_ERROR)
|
||||
|
||||
# cmake setup begin
|
||||
project(solanaceae_sdbot-webui)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
|
||||
cmake_minimum_required(VERSION 3.16 FATAL_ERROR)
|
||||
|
||||
add_library(solanaceae_sdbot-webui STATIC
|
||||
./webapi_interface.hpp
|
||||
./webapi_sdcpp_stduhpf_wip2.hpp
|
||||
./webapi_sdcpp_stduhpf_wip2.cpp
|
||||
./webapi_sdcpp_openai.hpp
|
||||
./webapi_sdcpp_openai.cpp
|
||||
|
||||
./sd_bot.hpp
|
||||
./sd_bot.cpp
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
#include <solanaceae/message3/components.hpp>
|
||||
|
||||
#include "./webapi_sdcpp_stduhpf_wip2.hpp"
|
||||
#include "./webapi_sdcpp_openai.hpp"
|
||||
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
#include <chrono>
|
||||
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
@@ -24,7 +24,7 @@ SDBot::SDBot(
|
||||
_rng.discard(3137);
|
||||
|
||||
if (!_conf.has_string("SDBot", "endpoint_type")) {
|
||||
_conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_stduhpf_wip2"}); // default
|
||||
_conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_openai"}); // default
|
||||
}
|
||||
|
||||
//HACKy
|
||||
@@ -32,6 +32,8 @@ SDBot::SDBot(
|
||||
const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value();
|
||||
if (endpoint_type == "sdcpp_stduhpf_wip2") {
|
||||
_endpoint = std::make_unique<WebAPI_sdcpp_stduhpf_wip2>(_conf);
|
||||
} else if (endpoint_type == "sdcpp_openai") {
|
||||
_endpoint = std::make_unique<WebAPI_sdcpp_openai>(_conf);
|
||||
} else {
|
||||
throw std::runtime_error("missing endpoint type in config, cant continue!");
|
||||
}
|
||||
@@ -42,24 +44,13 @@ SDBot::SDBot(
|
||||
if (!_conf.has_string("SDBot", "server_host")) {
|
||||
_conf.set("SDBot", "server_host", std::string_view{"127.0.0.1"});
|
||||
}
|
||||
if (!_conf.has_int("SDBot", "server_port")) {
|
||||
_conf.set("SDBot", "server_port", int64_t(7860)); // automatic11 default
|
||||
}
|
||||
if (!_conf.has_string("SDBot", "url_txt2img")) {
|
||||
_conf.set("SDBot", "url_txt2img", std::string_view{"/sdapi/v1/txt2img"}); // automatic11 default
|
||||
}
|
||||
|
||||
if (!_conf.has_int("SDBot", "width")) {
|
||||
_conf.set("SDBot", "width", int64_t(512));
|
||||
}
|
||||
if (!_conf.has_int("SDBot", "height")) {
|
||||
_conf.set("SDBot", "height", int64_t(512));
|
||||
}
|
||||
if (!_conf.has_int("SDBot", "steps")) {
|
||||
_conf.set("SDBot", "steps", int64_t(20));
|
||||
}
|
||||
if (!_conf.has_double("SDBot", "cfg_scale") && !_conf.has_int("SDBot", "cfg_scale")) {
|
||||
_conf.set("SDBot", "cfg_scale", 6.5);
|
||||
}
|
||||
|
||||
_rmm_sr.subscribe(RegistryMessageModel_Event::message_construct);
|
||||
}
|
||||
|
||||
170
src/webapi_sdcpp_openai.cpp
Normal file
170
src/webapi_sdcpp_openai.cpp
Normal file
@@ -0,0 +1,170 @@
|
||||
#include "./webapi_sdcpp_openai.hpp"
|
||||
|
||||
#include <httplib.h>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <sodium.h>
|
||||
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <chrono>
|
||||
#include <future>
|
||||
|
||||
std::shared_ptr<httplib::Client> WebAPI_sdcpp_openai::getCl(void) {
|
||||
if (_cl == nullptr) {
|
||||
const std::string server_host {_conf.get_string("SDBot", "server_host").value_or("127.0.0.1")};
|
||||
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value_or(1234));
|
||||
_cl->set_read_timeout(std::chrono::minutes(10));
|
||||
}
|
||||
|
||||
return _cl;
|
||||
}
|
||||
|
||||
|
||||
WebAPI_sdcpp_openai::WebAPI_sdcpp_openai(ConfigModelI& conf) :
|
||||
_conf(conf)
|
||||
{
|
||||
if (!_conf.has_int("SDBot", "server_port")) {
|
||||
_conf.set("SDBot", "server_port", int64_t(1234));
|
||||
}
|
||||
if (!_conf.has_string("SDBot", "url_txt2img")) {
|
||||
_conf.set("SDBot", "url_txt2img", std::string_view{"/v1/images/generations"});
|
||||
}
|
||||
}
|
||||
|
||||
WebAPI_sdcpp_openai::~WebAPI_sdcpp_openai(void) {
|
||||
}
|
||||
|
||||
std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_openai::txt2img(
|
||||
std::string_view prompt,
|
||||
int16_t width,
|
||||
int16_t height
|
||||
// more
|
||||
) {
|
||||
auto cl = getCl();
|
||||
if (!cl) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
std::string body;
|
||||
nlohmann::json j_body;
|
||||
// eg
|
||||
//{
|
||||
// "model": "unused",
|
||||
// "prompt": "A lovely cat<sd_cpp_extra_args>{\"seed\": 357925}</sd_cpp_extra_args>",
|
||||
// "n": 1,
|
||||
// "size": "128x128",
|
||||
// "response_format": "b64_json"
|
||||
//}
|
||||
try {
|
||||
j_body["model"] = "unused";
|
||||
j_body["size"] = std::to_string(width) + "x" + std::to_string(height);
|
||||
j_body["n"] = 1;
|
||||
j_body["response_format"] = "b64_json";
|
||||
|
||||
nlohmann::json j_sd_extra_args;
|
||||
j_sd_extra_args["seed"] = -1;
|
||||
if (_conf.has_int("SDBot", "steps")) {
|
||||
j_sd_extra_args["steps"] = _conf.get_int("SDBot", "steps").value();
|
||||
}
|
||||
if (_conf.has_double("SDBot", "cfg_scale")) {
|
||||
j_sd_extra_args["cfg_scale"] = _conf.get_double("SDBot", "steps").value();
|
||||
} else if (_conf.has_int("SDBot", "cfg_scale")) {
|
||||
j_sd_extra_args["cfg_scale"] = _conf.get_int("SDBot", "steps").value();
|
||||
}
|
||||
|
||||
j_body["prompt"] =
|
||||
std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")}
|
||||
+ std::string{prompt}
|
||||
+ "<sd_cpp_extra_args>"
|
||||
+ j_sd_extra_args.dump(-1, ' ', true)
|
||||
+ "</sd_cpp_extra_args>"
|
||||
;
|
||||
|
||||
body = j_body.dump();
|
||||
} catch (...) {
|
||||
std::cerr << "SDB error: failed creating body json\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//std::cout << "body: " << body << "\n";
|
||||
|
||||
try {
|
||||
const std::string url {_conf.get_string("SDBot", "url_txt2img").value_or("/v1/images/generations")};
|
||||
return std::make_shared<WebAPITask_sdcpp_openai>(url, body, getCl());
|
||||
} catch (...) {
|
||||
std::cerr << "SDB error: scheduling post to sd server failed!\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
WebAPITask_sdcpp_openai::WebAPITask_sdcpp_openai(const std::string& url, const std::string& body, std::shared_ptr<httplib::Client> cl) {
|
||||
_future = std::async(std::launch::async, [url, body, cl]() -> WebAPITaskI::Result {
|
||||
if (!static_cast<bool>(cl)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
try {
|
||||
auto res = cl->Post(url, body, "application/json");
|
||||
if (!static_cast<bool>(res)) {
|
||||
std::cerr << "SDB error: post to sd server failed!\n";
|
||||
return {};
|
||||
}
|
||||
std::cerr << "SDB: http complete " << res->status << " " << res->reason << "\n";
|
||||
|
||||
if (
|
||||
res.error() != httplib::Error::Success ||
|
||||
res->status != 200
|
||||
) {
|
||||
return {};
|
||||
}
|
||||
|
||||
//std::cerr << "------ res body: " << res->body << "\n";
|
||||
|
||||
const auto j_res = nlohmann::json::parse(res->body);
|
||||
|
||||
// TODO: add support for multiple images
|
||||
auto& j_image = j_res.at("data").at(0);
|
||||
|
||||
Result result;
|
||||
|
||||
auto& j_data = j_image.at("b64_json");
|
||||
result.data.resize(j_data.get<std::string>().size()); // HACK: do a better estimate
|
||||
size_t decoded_size {0};
|
||||
sodium_base642bin(
|
||||
result.data.data(), result.data.size(),
|
||||
j_data.get<std::string>().data(), j_data.get<std::string>().size(),
|
||||
" \n\t",
|
||||
&decoded_size,
|
||||
nullptr,
|
||||
sodium_base64_VARIANT_ORIGINAL
|
||||
);
|
||||
result.data.resize(decoded_size);
|
||||
|
||||
result.file_name = std::string{"output_0."} + j_image.value("output_format", "png");
|
||||
|
||||
return result;
|
||||
} catch (...) {
|
||||
std::cerr << "SDB error: post to sd server failed!\n";
|
||||
}
|
||||
|
||||
return {};
|
||||
});
|
||||
|
||||
if (!_future.valid()) {
|
||||
throw std::runtime_error("failed to create future");
|
||||
std::cerr << "SDB error: scheduling post to sd server failed, invalid future returned.\n";
|
||||
}
|
||||
}
|
||||
|
||||
bool WebAPITask_sdcpp_openai::ready(void) {
|
||||
return _future.wait_for(std::chrono::microseconds(10)) == std::future_status::ready;
|
||||
}
|
||||
|
||||
std::optional<WebAPITaskI::Result> WebAPITask_sdcpp_openai::get(void) {
|
||||
return _future.get();
|
||||
}
|
||||
|
||||
47
src/webapi_sdcpp_openai.hpp
Normal file
47
src/webapi_sdcpp_openai.hpp
Normal file
@@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
#include "./webapi_interface.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <solanaceae/util/config_model.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <future>
|
||||
|
||||
// fwd
|
||||
namespace httplib {
|
||||
class Client;
|
||||
}
|
||||
|
||||
// this is supposedly an openai spec compatible one, in master
|
||||
struct WebAPI_sdcpp_openai : public WebAPII {
|
||||
// TODO: const config
|
||||
ConfigModelI& _conf;
|
||||
|
||||
std::shared_ptr<httplib::Client> _cl;
|
||||
std::shared_ptr<httplib::Client> getCl(void);
|
||||
|
||||
WebAPI_sdcpp_openai(ConfigModelI& conf);
|
||||
|
||||
~WebAPI_sdcpp_openai(void) override;
|
||||
|
||||
std::shared_ptr<WebAPITaskI> txt2img(
|
||||
std::string_view prompt,
|
||||
int16_t width,
|
||||
int16_t height
|
||||
// more
|
||||
) override;
|
||||
};
|
||||
|
||||
// discard after get() !!
|
||||
struct WebAPITask_sdcpp_openai : public WebAPITaskI {
|
||||
std::future<Result> _future;
|
||||
|
||||
WebAPITask_sdcpp_openai(const std::string& url, const std::string& body, std::shared_ptr<httplib::Client> cl);
|
||||
~WebAPITask_sdcpp_openai(void) override {}
|
||||
|
||||
bool ready(void) override;
|
||||
std::optional<Result> get(void) override;
|
||||
};
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
|
||||
std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
|
||||
if (_cl == nullptr) {
|
||||
const std::string server_host {_conf.get_string("SDBot", "server_host").value()};
|
||||
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value());
|
||||
const std::string server_host {_conf.get_string("SDBot", "server_host").value_or("127.0.0.1")};
|
||||
_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value_or(8080));
|
||||
_cl->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while
|
||||
}
|
||||
|
||||
@@ -22,6 +22,9 @@ std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
|
||||
WebAPI_sdcpp_stduhpf_wip2::WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf) :
|
||||
_conf(conf)
|
||||
{
|
||||
if (!_conf.has_int("SDBot", "server_port")) {
|
||||
_conf.set("SDBot", "server_port", int64_t(8080));
|
||||
}
|
||||
}
|
||||
|
||||
WebAPI_sdcpp_stduhpf_wip2::~WebAPI_sdcpp_stduhpf_wip2(void) {
|
||||
@@ -67,9 +70,7 @@ std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img(
|
||||
// "vae_tiling": true,
|
||||
// "tae_decode": true
|
||||
//}
|
||||
//j_body["width"] = _conf.get_int("SDBot", "width").value_or(512);
|
||||
j_body["width"] = width;
|
||||
//j_body["height"] = _conf.get_int("SDBot", "height").value_or(512);
|
||||
j_body["height"] = height;
|
||||
|
||||
j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + std::string{prompt};
|
||||
@@ -92,7 +93,7 @@ std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::cerr << "SDB http complete " << res->status << " " << res->reason << "\n";
|
||||
std::cerr << "SDB: http complete " << res->status << " " << res->reason << "\n";
|
||||
|
||||
if (
|
||||
res.error() != httplib::Error::Success ||
|
||||
|
||||
Reference in New Issue
Block a user