Compare commits
	
		
			3 Commits
		
	
	
		
			b3315da1d9
			...
			2beb74eb5f
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					2beb74eb5f | ||
| 
						 | 
					bfd923e829 | ||
| 
						 | 
					24ae710c29 | 
							
								
								
									
										4
									
								
								.github/workflows/cd.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/cd.yml
									
									
									
									
										vendored
									
									
								
							@@ -14,7 +14,7 @@ jobs:
 | 
			
		||||
  linux-ubuntu:
 | 
			
		||||
    timeout-minutes: 10
 | 
			
		||||
 | 
			
		||||
    runs-on: ubuntu-20.04
 | 
			
		||||
    runs-on: ubuntu-24.04
 | 
			
		||||
 | 
			
		||||
    steps:
 | 
			
		||||
    - uses: actions/checkout@v4
 | 
			
		||||
@@ -32,7 +32,7 @@ jobs:
 | 
			
		||||
 | 
			
		||||
    - uses: actions/upload-artifact@v4
 | 
			
		||||
      with:
 | 
			
		||||
        name: ${{ github.event.repository.name }}-ubuntu20.04-x86_64
 | 
			
		||||
        name: ${{ github.event.repository.name }}-ubuntu24.04-x86_64
 | 
			
		||||
        # TODO: do propper packing
 | 
			
		||||
        path: |
 | 
			
		||||
          ${{github.workspace}}/build/bin/
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
## Solanaceae extention and plugin to serve StableDiffusion
 | 
			
		||||
 | 
			
		||||
!! currently only works with automatic1111's api !!
 | 
			
		||||
!! currently only works with [stduhpf's stablediffusion.cpp http server api](https://github.com/leejet/stable-diffusion.cpp/pull/367)  !!
 | 
			
		||||
 | 
			
		||||
### example config for `totato`
 | 
			
		||||
```json
 | 
			
		||||
@@ -15,15 +15,15 @@
 | 
			
		||||
	"SDBot": {
 | 
			
		||||
		"server_host": "127.0.0.1",
 | 
			
		||||
		"server_port": 8080,
 | 
			
		||||
		"url_txt2img": "/sdapi/v1/txt2img",
 | 
			
		||||
		"endpoint_type": "sdcpp_stduhpf_wip2",
 | 
			
		||||
		"url_txt2img": "/txt2img",
 | 
			
		||||
 | 
			
		||||
		"width": 512,
 | 
			
		||||
		"height": 512,
 | 
			
		||||
 | 
			
		||||
		"prompt_prefix": "<lora:lcm-lora-sdv1-5:1>",
 | 
			
		||||
		"steps": 8,
 | 
			
		||||
		"cfg_scale": 1.0,
 | 
			
		||||
		"sampler_index": "LCM Test"
 | 
			
		||||
		"cfg_scale": 1.0
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								external/CMakeLists.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								external/CMakeLists.txt
									
									
									
									
										vendored
									
									
								
							@@ -78,7 +78,7 @@ endif()
 | 
			
		||||
if (NOT TARGET httplib::httplib)
 | 
			
		||||
	FetchContent_Declare(httplib
 | 
			
		||||
		GIT_REPOSITORY https://github.com/yhirose/cpp-httplib.git
 | 
			
		||||
		GIT_TAG v0.19.0
 | 
			
		||||
		GIT_TAG v0.22.0
 | 
			
		||||
		EXCLUDE_FROM_ALL
 | 
			
		||||
	)
 | 
			
		||||
	FetchContent_MakeAvailable(httplib)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,10 @@
 | 
			
		||||
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
 | 
			
		||||
 | 
			
		||||
add_library(solanaceae_sdbot-webui STATIC
 | 
			
		||||
	./webapi_interface.hpp
 | 
			
		||||
	./webapi_sdcpp_stduhpf_wip2.hpp
 | 
			
		||||
	./webapi_sdcpp_stduhpf_wip2.cpp
 | 
			
		||||
 | 
			
		||||
	./sd_bot.hpp
 | 
			
		||||
	./sd_bot.cpp
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										328
									
								
								src/sd_bot.cpp
									
									
									
									
									
								
							
							
						
						
									
										328
									
								
								src/sd_bot.cpp
									
									
									
									
									
								
							@@ -6,200 +6,14 @@
 | 
			
		||||
#include <solanaceae/contact/components.hpp>
 | 
			
		||||
#include <solanaceae/message3/components.hpp>
 | 
			
		||||
 | 
			
		||||
#include <nlohmann/json.hpp>
 | 
			
		||||
#include <sodium.h>
 | 
			
		||||
#include "./webapi_sdcpp_stduhpf_wip2.hpp"
 | 
			
		||||
 | 
			
		||||
#include <fstream>
 | 
			
		||||
#include <filesystem>
 | 
			
		||||
#include <chrono>
 | 
			
		||||
 | 
			
		||||
#include <iostream>
 | 
			
		||||
 | 
			
		||||
struct Automatic1111_v1_Endpoint : public SDBot::EndpointI {
 | 
			
		||||
	Automatic1111_v1_Endpoint(RegistryMessageModelI& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {}
 | 
			
		||||
 | 
			
		||||
	bool handleResponse(Contact4 contact, ByteSpan data) override {
 | 
			
		||||
		//std::cout << std::string_view{reinterpret_cast<const char*>(data.ptr), data.size} << "\n";
 | 
			
		||||
 | 
			
		||||
		// extract json result
 | 
			
		||||
		const auto j = nlohmann::json::parse(
 | 
			
		||||
			std::string_view{reinterpret_cast<const char*>(data.ptr), data.size},
 | 
			
		||||
			nullptr,
 | 
			
		||||
			false
 | 
			
		||||
		);
 | 
			
		||||
		//std::cout << "json dump: " << j.dump() << "\n";
 | 
			
		||||
 | 
			
		||||
		if (j.count("images") && !j.at("images").empty() && j.at("images").is_array()) {
 | 
			
		||||
			for (const auto& i_j : j.at("images").items()) {
 | 
			
		||||
				// decode data (base64)
 | 
			
		||||
				std::vector<uint8_t> png_data(data.size); // just init to upper bound
 | 
			
		||||
				size_t decoded_size {0};
 | 
			
		||||
				sodium_base642bin(
 | 
			
		||||
					png_data.data(), png_data.size(),
 | 
			
		||||
					i_j.value().get<std::string>().data(), i_j.value().get<std::string>().size(),
 | 
			
		||||
					" \n\t",
 | 
			
		||||
					&decoded_size,
 | 
			
		||||
					nullptr,
 | 
			
		||||
					sodium_base64_VARIANT_ORIGINAL
 | 
			
		||||
				);
 | 
			
		||||
				png_data.resize(decoded_size);
 | 
			
		||||
 | 
			
		||||
				std::filesystem::create_directories("sdbot_img_send");
 | 
			
		||||
				//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
 | 
			
		||||
				const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
 | 
			
		||||
				const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
 | 
			
		||||
 | 
			
		||||
				std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
 | 
			
		||||
				_rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path);
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			std::cerr << "SDB json response did not contain images?\n";
 | 
			
		||||
			return false;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct SDcpp_wip1_Endpoint : public SDBot::EndpointI {
 | 
			
		||||
	SDcpp_wip1_Endpoint(RegistryMessageModelI& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {}
 | 
			
		||||
 | 
			
		||||
	bool handleResponse(Contact4 contact, ByteSpan data) override {
 | 
			
		||||
		//std::cout << std::string_view{reinterpret_cast<const char*>(data.ptr), data.size} << "\n";
 | 
			
		||||
 | 
			
		||||
		std::string_view data_str {reinterpret_cast<const char*>(data.ptr), data.size};
 | 
			
		||||
		auto nl_pos {std::string_view::npos};
 | 
			
		||||
		bool succ {false};
 | 
			
		||||
		do {
 | 
			
		||||
			// for each line, should be "data: <json>" or empty
 | 
			
		||||
			nl_pos = data_str.find_first_of('\n');
 | 
			
		||||
 | 
			
		||||
			// npos is also valid
 | 
			
		||||
			auto line = data_str.substr(0, nl_pos);
 | 
			
		||||
 | 
			
		||||
			// at least minimum viable
 | 
			
		||||
			if (line.size() >= std::string_view{"data: {}"}.size()) {
 | 
			
		||||
				//std::cout << "got data line!!!!!!!!!!!:\n";
 | 
			
		||||
				//std::cout << line << "\n";
 | 
			
		||||
				line.remove_prefix(6);
 | 
			
		||||
 | 
			
		||||
				const auto j = nlohmann::json::parse(
 | 
			
		||||
					line,
 | 
			
		||||
					nullptr,
 | 
			
		||||
					false
 | 
			
		||||
				);
 | 
			
		||||
 | 
			
		||||
				if (
 | 
			
		||||
					!j.empty() &&
 | 
			
		||||
					j.value("type", "notimag") == "image" &&
 | 
			
		||||
					j.contains("data") &&
 | 
			
		||||
					j.at("data").is_string()
 | 
			
		||||
				) {
 | 
			
		||||
					const auto& img_data_str = j.at("data").get<std::string>();
 | 
			
		||||
					// decode data (base64)
 | 
			
		||||
					std::vector<uint8_t> png_data(img_data_str.size()); // just init to upper bound
 | 
			
		||||
					size_t decoded_size {0};
 | 
			
		||||
					sodium_base642bin(
 | 
			
		||||
						png_data.data(), png_data.size(),
 | 
			
		||||
						img_data_str.data(), img_data_str.size(),
 | 
			
		||||
						" \n\t",
 | 
			
		||||
						&decoded_size,
 | 
			
		||||
						nullptr,
 | 
			
		||||
						sodium_base64_VARIANT_ORIGINAL
 | 
			
		||||
					);
 | 
			
		||||
					png_data.resize(decoded_size);
 | 
			
		||||
 | 
			
		||||
					std::filesystem::create_directories("sdbot_img_send");
 | 
			
		||||
					//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
 | 
			
		||||
					const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
 | 
			
		||||
					const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
 | 
			
		||||
 | 
			
		||||
					std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
 | 
			
		||||
					succ = _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path);
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if (nl_pos == std::string_view::npos || nl_pos+1 >= data_str.size()) {
 | 
			
		||||
				break;
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			data_str = data_str.substr(nl_pos+1);
 | 
			
		||||
		} while (nl_pos != std::string_view::npos);
 | 
			
		||||
 | 
			
		||||
		return succ;
 | 
			
		||||
	}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct SDcpp_stduhpf_wip1_Endpoint : public SDBot::EndpointI {
 | 
			
		||||
	SDcpp_stduhpf_wip1_Endpoint(RegistryMessageModelI& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {}
 | 
			
		||||
 | 
			
		||||
	bool handleResponse(Contact4 contact, ByteSpan data) override {
 | 
			
		||||
		bool succ = true;
 | 
			
		||||
 | 
			
		||||
		try {
 | 
			
		||||
			// extract json result
 | 
			
		||||
			const auto j = nlohmann::json::parse(
 | 
			
		||||
				std::string_view{reinterpret_cast<const char*>(data.ptr), data.size}
 | 
			
		||||
			);
 | 
			
		||||
 | 
			
		||||
			if (!j.is_array()) {
 | 
			
		||||
				std::cerr << "SDB: json response was not an array\n";
 | 
			
		||||
				return false;
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for (const auto& j_entry : j) {
 | 
			
		||||
				if (!j_entry.is_object()) {
 | 
			
		||||
					std::cerr << "SDB warning: non object entry, skipping\n";
 | 
			
		||||
					continue;
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for each returned image
 | 
			
		||||
				// "channel": 3, // rgb?
 | 
			
		||||
				// "data": base64 encoded image file
 | 
			
		||||
				// "encoding": "png",
 | 
			
		||||
				// "height": 512,
 | 
			
		||||
				// "width": 512
 | 
			
		||||
 | 
			
		||||
				if (j_entry.contains("encoding")) {
 | 
			
		||||
					if (!j_entry["encoding"].is_string() || j_entry["encoding"] != "png") {
 | 
			
		||||
						std::cerr << "SDB warning: unknown encoding '" << j_entry["encoding"] << "'\n";
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if (!j_entry.contains("data") || !j_entry.at("data").is_string()) {
 | 
			
		||||
					std::cerr << "SDB warning: non data entry, skipping\n";
 | 
			
		||||
					continue;
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				const auto& img_data_str = j_entry.at("data").get<std::string>();
 | 
			
		||||
				// decode data (base64)
 | 
			
		||||
				std::vector<uint8_t> png_data(img_data_str.size()); // just init to upper bound
 | 
			
		||||
				size_t decoded_size {0};
 | 
			
		||||
				sodium_base642bin(
 | 
			
		||||
					png_data.data(), png_data.size(),
 | 
			
		||||
					img_data_str.data(), img_data_str.size(),
 | 
			
		||||
					" \n\t",
 | 
			
		||||
					&decoded_size,
 | 
			
		||||
					nullptr,
 | 
			
		||||
					sodium_base64_VARIANT_ORIGINAL
 | 
			
		||||
				);
 | 
			
		||||
				png_data.resize(decoded_size);
 | 
			
		||||
 | 
			
		||||
				std::filesystem::create_directories("sdbot_img_send");
 | 
			
		||||
				//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
 | 
			
		||||
				const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
 | 
			
		||||
				const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
 | 
			
		||||
 | 
			
		||||
				std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
 | 
			
		||||
				succ = succ && _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path);
 | 
			
		||||
			}
 | 
			
		||||
		} catch (...) {
 | 
			
		||||
			return false;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return succ;
 | 
			
		||||
	}
 | 
			
		||||
};
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
 | 
			
		||||
SDBot::SDBot(
 | 
			
		||||
	ContactStore4I& cs,
 | 
			
		||||
@@ -210,22 +24,16 @@ SDBot::SDBot(
 | 
			
		||||
	_rng.discard(3137);
 | 
			
		||||
 | 
			
		||||
	if (!_conf.has_string("SDBot", "endpoint_type")) {
 | 
			
		||||
		_conf.set("SDBot", "endpoint_type", std::string_view{"automatic1111_v1"}); // automatic11 default
 | 
			
		||||
		_conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_stduhpf_wip2"}); // default
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//HACKy
 | 
			
		||||
	{ // construct endpoint
 | 
			
		||||
		const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value();
 | 
			
		||||
		if (endpoint_type == "automatic1111_v1") {
 | 
			
		||||
			_endpoint = std::make_unique<Automatic1111_v1_Endpoint>(_rmm, _rng);
 | 
			
		||||
		} else if (endpoint_type == "sdcpp_wip1") {
 | 
			
		||||
			_endpoint = std::make_unique<SDcpp_wip1_Endpoint>(_rmm, _rng);
 | 
			
		||||
		} else if (endpoint_type == "sdcpp_stduhpf_wip1") {
 | 
			
		||||
			_endpoint = std::make_unique<SDcpp_stduhpf_wip1_Endpoint>(_rmm, _rng);
 | 
			
		||||
		if (endpoint_type == "sdcpp_stduhpf_wip2") {
 | 
			
		||||
			_endpoint = std::make_unique<WebAPI_sdcpp_stduhpf_wip2>(_conf);
 | 
			
		||||
		} else {
 | 
			
		||||
			std::cerr << "SDB error: unknown endpoint type '" << endpoint_type << "'\n";
 | 
			
		||||
			// TODO: throw?
 | 
			
		||||
			_endpoint = std::make_unique<Automatic1111_v1_Endpoint>(_rmm, _rng);
 | 
			
		||||
			throw std::runtime_error("missing endpoint type in config, cant continue!");
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -260,117 +68,61 @@ SDBot::~SDBot(void) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
float SDBot::iterate(void) {
 | 
			
		||||
	if (_current_task.has_value() != _curr_future.has_value()) {
 | 
			
		||||
	if (_current_task_id.has_value() != (_current_task != nullptr)) {
 | 
			
		||||
		std::cerr << "SDB inconsistent state\n";
 | 
			
		||||
 | 
			
		||||
		if (_current_task.has_value()) {
 | 
			
		||||
			_task_map.erase(_current_task.value());
 | 
			
		||||
			_current_task = std::nullopt;
 | 
			
		||||
		if (_current_task_id.has_value()) {
 | 
			
		||||
			_task_map.erase(_current_task_id.value());
 | 
			
		||||
			_current_task_id = std::nullopt;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if (_curr_future.has_value()) {
 | 
			
		||||
			_curr_future.reset(); // might block and wait
 | 
			
		||||
		}
 | 
			
		||||
		_current_task.reset();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (!_prompt_queue.empty() && !_current_task.has_value()) { // dequeue new task
 | 
			
		||||
	if (!_prompt_queue.empty() && !_current_task_id.has_value()) { // dequeue new task
 | 
			
		||||
		const auto& [task_id, prompt] = _prompt_queue.front();
 | 
			
		||||
 | 
			
		||||
		_current_task = task_id;
 | 
			
		||||
 | 
			
		||||
		if (_cli == nullptr) {
 | 
			
		||||
			const std::string server_host {_conf.get_string("SDBot", "server_host").value()};
 | 
			
		||||
			_cli = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value());
 | 
			
		||||
			_cli->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		nlohmann::json j_body;
 | 
			
		||||
		j_body["width"] = _conf.get_int("SDBot", "width").value_or(512);
 | 
			
		||||
		j_body["height"] = _conf.get_int("SDBot", "height").value_or(512);
 | 
			
		||||
 | 
			
		||||
		//j_body["prompt"] = prompt;
 | 
			
		||||
		//"<lora:lcm-lora-sdv1-5:1>"
 | 
			
		||||
		j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + prompt;
 | 
			
		||||
		// TODO: negative prompt
 | 
			
		||||
 | 
			
		||||
		j_body["seed"] = -1;
 | 
			
		||||
#if 0
 | 
			
		||||
		j_body["steps"] = 20;
 | 
			
		||||
		//j_body["steps"] = 5;
 | 
			
		||||
		j_body["cfg_scale"] = 6.5;
 | 
			
		||||
		j_body["sampler_index"] = "Euler a";
 | 
			
		||||
#else
 | 
			
		||||
		//j_body["steps"] = 4;
 | 
			
		||||
		j_body["steps"] = _conf.get_int("SDBot", "steps").value_or(20);
 | 
			
		||||
		//j_body["cfg_scale"] = 1;
 | 
			
		||||
		j_body["cfg_scale"] = _conf.get_double("SDBot", "cfg_scale").value_or(6.5);
 | 
			
		||||
		//j_body["sampler_index"] = "LCM Test";
 | 
			
		||||
		j_body["sampler_index"] = std::string{_conf.get_string("SDBot", "sampler_index").value_or("Euler a")};
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
		j_body["batch_size"] = 1;
 | 
			
		||||
		j_body["n_iter"] = 1;
 | 
			
		||||
		j_body["restore_faces"] = false;
 | 
			
		||||
		j_body["tiling"] = false;
 | 
			
		||||
		j_body["enable_hr"] = false;
 | 
			
		||||
 | 
			
		||||
		std::string body = j_body.dump();
 | 
			
		||||
 | 
			
		||||
		try {
 | 
			
		||||
			const std::string url {_conf.get_string("SDBot", "url_txt2img").value()};
 | 
			
		||||
			_curr_future = std::async(std::launch::async, [url, body, cli = _cli]() -> std::vector<uint8_t> {
 | 
			
		||||
				if (!static_cast<bool>(cli)) {
 | 
			
		||||
					return {};
 | 
			
		||||
				}
 | 
			
		||||
				// TODO: move to endpoint
 | 
			
		||||
				auto res = cli->Post(url, body, "application/json");
 | 
			
		||||
				if (!static_cast<bool>(res)) {
 | 
			
		||||
					std::cerr << "SDB error: post to sd server failed!\n";
 | 
			
		||||
					return {};
 | 
			
		||||
				}
 | 
			
		||||
				std::cerr << "SDB http complete " << res->status << " " << res->reason << "\n";
 | 
			
		||||
				if (
 | 
			
		||||
					res.error() != httplib::Error::Success ||
 | 
			
		||||
					res->status != 200
 | 
			
		||||
				) {
 | 
			
		||||
					return {};
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				return std::vector<uint8_t>(res->body.cbegin(), res->body.cend());
 | 
			
		||||
			});
 | 
			
		||||
		} catch (...) {
 | 
			
		||||
			std::cerr << "SDB http request error\n";
 | 
			
		||||
			// cleanup
 | 
			
		||||
			_task_map.erase(_current_task.value());
 | 
			
		||||
			_current_task = std::nullopt;
 | 
			
		||||
			_curr_future.reset(); // might block and wait
 | 
			
		||||
		_current_task = _endpoint->txt2img(
 | 
			
		||||
			prompt,
 | 
			
		||||
			_conf.get_int("SDBot", "width").value_or(512),
 | 
			
		||||
			_conf.get_int("SDBot", "height").value_or(512)
 | 
			
		||||
		);
 | 
			
		||||
		if (_current_task == nullptr) {
 | 
			
		||||
			// ERROR
 | 
			
		||||
			std::cerr << "SDB error: call to txt2img failed!\n";
 | 
			
		||||
			_task_map.erase(task_id);
 | 
			
		||||
		} else {
 | 
			
		||||
			_current_task_id = task_id;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_prompt_queue.pop();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
	if (_curr_future.has_value() && _curr_future.value().wait_for(std::chrono::milliseconds(1)) == std::future_status::ready) {
 | 
			
		||||
		const auto& contact = _task_map.at(_current_task.value());
 | 
			
		||||
	if (_current_task_id && _current_task && _current_task->ready()) {
 | 
			
		||||
 | 
			
		||||
		const auto data = _curr_future.value().get();
 | 
			
		||||
		auto res_opt = _current_task->get();
 | 
			
		||||
		if (res_opt) {
 | 
			
		||||
			const auto& contact = _task_map.at(_current_task_id.value());
 | 
			
		||||
 | 
			
		||||
		if (_endpoint->handleResponse(contact, ByteSpan{data})) {
 | 
			
		||||
			// TODO: error handling
 | 
			
		||||
		}
 | 
			
		||||
			std::filesystem::create_directories("sdbot_img_send");
 | 
			
		||||
			const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + "." + res_opt.value().file_name;
 | 
			
		||||
			const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
 | 
			
		||||
 | 
			
		||||
		_task_map.erase(_current_task.value());
 | 
			
		||||
		_current_task = std::nullopt;
 | 
			
		||||
		_curr_future.reset();
 | 
			
		||||
			std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(res_opt.value().data.data()), res_opt.value().data.size());
 | 
			
		||||
			_rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path);
 | 
			
		||||
		} else {
 | 
			
		||||
			std::cerr << "SDB error: call to txt2img failed (task returned nothing)!\n";
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
	// if active web connection, 5ms
 | 
			
		||||
	//return static_cast<bool>(_con) ? 0.005f : 1.f;
 | 
			
		||||
		_current_task_id.reset();
 | 
			
		||||
		_current_task.reset();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// if active web connection, 50ms
 | 
			
		||||
	if (_curr_future.has_value() && _curr_future.value().valid()) {
 | 
			
		||||
		return 0.05f;
 | 
			
		||||
	// if active web connection, 100ms
 | 
			
		||||
	if (_current_task_id.has_value()) {
 | 
			
		||||
		return 0.1f;
 | 
			
		||||
	} else {
 | 
			
		||||
		return 1.f;
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,7 @@
 | 
			
		||||
#include <solanaceae/message3/registry_message_model.hpp>
 | 
			
		||||
#include <solanaceae/contact/fwd.hpp>
 | 
			
		||||
 | 
			
		||||
#include <httplib.h>
 | 
			
		||||
#include "./webapi_interface.hpp"
 | 
			
		||||
 | 
			
		||||
#include <map>
 | 
			
		||||
#include <vector>
 | 
			
		||||
@@ -24,31 +24,17 @@ class SDBot : public RegistryMessageModelEventI {
 | 
			
		||||
	RegistryMessageModelI::SubscriptionReference _rmm_sr;
 | 
			
		||||
	ConfigModelI& _conf;
 | 
			
		||||
 | 
			
		||||
	//TransferManager& _tm;
 | 
			
		||||
 | 
			
		||||
	//std::map<uint64_t, std::variant<ContactFriend, ContactConference, ContactGroupPeer>> _task_map;
 | 
			
		||||
	std::map<uint64_t, Contact4> _task_map;
 | 
			
		||||
	std::queue<std::pair<uint64_t, std::string>> _prompt_queue;
 | 
			
		||||
	uint64_t _last_task_counter = 0;
 | 
			
		||||
 | 
			
		||||
	std::optional<uint64_t> _current_task;
 | 
			
		||||
	std::shared_ptr<httplib::Client> _cli;
 | 
			
		||||
	std::optional<std::future<std::vector<uint8_t>>> _curr_future;
 | 
			
		||||
	std::optional<uint64_t> _current_task_id;
 | 
			
		||||
	std::shared_ptr<WebAPITaskI> _current_task;
 | 
			
		||||
 | 
			
		||||
	std::default_random_engine _rng;
 | 
			
		||||
 | 
			
		||||
	public:
 | 
			
		||||
		struct EndpointI {
 | 
			
		||||
			RegistryMessageModelI& _rmm;
 | 
			
		||||
			std::default_random_engine& _rng;
 | 
			
		||||
			EndpointI(RegistryMessageModelI& rmm, std::default_random_engine& rng) : _rmm(rmm), _rng(rng) {}
 | 
			
		||||
			virtual ~EndpointI(void) {}
 | 
			
		||||
 | 
			
		||||
			virtual bool handleResponse(Contact4 contact, ByteSpan data) = 0;
 | 
			
		||||
		};
 | 
			
		||||
 | 
			
		||||
	private:
 | 
			
		||||
		std::unique_ptr<EndpointI> _endpoint;
 | 
			
		||||
		std::unique_ptr<WebAPII> _endpoint;
 | 
			
		||||
 | 
			
		||||
	public:
 | 
			
		||||
		SDBot(
 | 
			
		||||
@@ -64,9 +50,6 @@ class SDBot : public RegistryMessageModelEventI {
 | 
			
		||||
		bool use_webp_for_friends = true;
 | 
			
		||||
		bool use_webp_for_groups = true;
 | 
			
		||||
 | 
			
		||||
	//protected: // tox events
 | 
			
		||||
		//bool onToxEvent(const Tox_Event_Friend_Message* e) override;
 | 
			
		||||
		//bool onToxEvent(const Tox_Event_Group_Message* e) override;
 | 
			
		||||
	protected: // mm
 | 
			
		||||
		bool onEvent(const Message::Events::MessageConstruct& e) override;
 | 
			
		||||
};
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										40
									
								
								src/webapi_interface.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								src/webapi_interface.hpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,40 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <vector>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
struct WebAPITaskI;
 | 
			
		||||
 | 
			
		||||
struct WebAPII {
 | 
			
		||||
	virtual ~WebAPII(void) {}
 | 
			
		||||
 | 
			
		||||
	virtual std::shared_ptr<WebAPITaskI> txt2img(
 | 
			
		||||
		std::string_view prompt,
 | 
			
		||||
		int16_t width,
 | 
			
		||||
		int16_t height
 | 
			
		||||
		// more
 | 
			
		||||
	) = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// only knows of a single prompt
 | 
			
		||||
struct WebAPITaskI {
 | 
			
		||||
	virtual ~WebAPITaskI(void) {}
 | 
			
		||||
 | 
			
		||||
	// true if done or failed
 | 
			
		||||
	virtual bool ready(void) = 0;
 | 
			
		||||
 | 
			
		||||
	struct Result {
 | 
			
		||||
		std::vector<uint8_t> data;
 | 
			
		||||
		int16_t width {};
 | 
			
		||||
		int16_t height {};
 | 
			
		||||
 | 
			
		||||
		std::string file_name;
 | 
			
		||||
	};
 | 
			
		||||
 | 
			
		||||
	virtual std::optional<Result> get(void) = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										228
									
								
								src/webapi_sdcpp_stduhpf_wip2.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								src/webapi_sdcpp_stduhpf_wip2.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,228 @@
 | 
			
		||||
#include "./webapi_sdcpp_stduhpf_wip2.hpp"
 | 
			
		||||
 | 
			
		||||
#include <httplib.h>
 | 
			
		||||
#include <nlohmann/json.hpp>
 | 
			
		||||
#include <sodium.h>
 | 
			
		||||
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <chrono>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) {
 | 
			
		||||
	if (_cl == nullptr) {
 | 
			
		||||
		const std::string server_host {_conf.get_string("SDBot", "server_host").value()};
 | 
			
		||||
		_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value());
 | 
			
		||||
		_cl->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return _cl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
WebAPI_sdcpp_stduhpf_wip2::WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf) :
 | 
			
		||||
	_conf(conf)
 | 
			
		||||
{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
WebAPI_sdcpp_stduhpf_wip2::~WebAPI_sdcpp_stduhpf_wip2(void) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img(
 | 
			
		||||
	std::string_view prompt,
 | 
			
		||||
	int16_t width,
 | 
			
		||||
	int16_t height
 | 
			
		||||
	// more
 | 
			
		||||
) {
 | 
			
		||||
	auto cl = getCl();
 | 
			
		||||
	if (!cl) {
 | 
			
		||||
		return nullptr;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nlohmann::json j_body;
 | 
			
		||||
	// eg
 | 
			
		||||
	//{
 | 
			
		||||
	//  "prompt": "a lovely schnitzel",
 | 
			
		||||
	//  "negative_prompt": "",
 | 
			
		||||
	//  "width": 512,
 | 
			
		||||
	//  "height": 512,
 | 
			
		||||
	//  "guidance_params": {
 | 
			
		||||
	//    "cfg_scale": 1,
 | 
			
		||||
	//    "guidance": 3.5
 | 
			
		||||
	//  },
 | 
			
		||||
	//  "sample_steps": 8,
 | 
			
		||||
	//  "sample_method": "euler_a",
 | 
			
		||||
	//  "seed": -1,
 | 
			
		||||
	//  "batch_count": 1,
 | 
			
		||||
	//  "schedule": "default",
 | 
			
		||||
	//  "model": -1,
 | 
			
		||||
	//  "diffusion_model": -1,
 | 
			
		||||
	//  "clip_l": -1,
 | 
			
		||||
	//  "clip_g": -1,
 | 
			
		||||
	//  "t5xxl": -1,
 | 
			
		||||
	//  "vae": -1,
 | 
			
		||||
	//  "tae": -1,
 | 
			
		||||
	//  "keep_vae_on_cpu": false,
 | 
			
		||||
	//  "kep_clip_on_cpu": false,
 | 
			
		||||
	//  "vae_tiling": true,
 | 
			
		||||
	//  "tae_decode": true
 | 
			
		||||
	//}
 | 
			
		||||
	//j_body["width"] = _conf.get_int("SDBot", "width").value_or(512);
 | 
			
		||||
	j_body["width"] = width;
 | 
			
		||||
	//j_body["height"] = _conf.get_int("SDBot", "height").value_or(512);
 | 
			
		||||
	j_body["height"] = height;
 | 
			
		||||
 | 
			
		||||
	j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + std::string{prompt};
 | 
			
		||||
	// TODO: negative prompt
 | 
			
		||||
 | 
			
		||||
	j_body["seed"] = -1;
 | 
			
		||||
	j_body["sample_steps"] = _conf.get_int("SDBot", "steps").value_or(20);
 | 
			
		||||
	j_body["guidance_params"]["cfg_scale"] = _conf.get_double("SDBot", "cfg_scale").value_or(6.5);
 | 
			
		||||
	//j_body["sampler_index"] = std::string{_conf.get_string("SDBot", "sampler_index").value_or("Euler a")};
 | 
			
		||||
 | 
			
		||||
	std::string body = j_body.dump();
 | 
			
		||||
 | 
			
		||||
	const std::string url {_conf.get_string("SDBot", "url_txt2img").value_or("/txt2img")};
 | 
			
		||||
 | 
			
		||||
	try {
 | 
			
		||||
		// not restful -> returns imediatly with task id
 | 
			
		||||
		auto res = cl->Post(url, body, "application/json");
 | 
			
		||||
		if (!static_cast<bool>(res)) {
 | 
			
		||||
			std::cerr << "SDB error: post to sd server failed!\n";
 | 
			
		||||
			return nullptr;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		std::cerr << "SDB http complete " << res->status << " " << res->reason << "\n";
 | 
			
		||||
 | 
			
		||||
		if (
 | 
			
		||||
			res.error() != httplib::Error::Success ||
 | 
			
		||||
			res->status != 200
 | 
			
		||||
		) {
 | 
			
		||||
			return nullptr;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		try {
 | 
			
		||||
			// {"task_id":"1753652554405483588"}
 | 
			
		||||
			auto j_res = nlohmann::json::parse(res->body);
 | 
			
		||||
 | 
			
		||||
			if (!j_res.contains("task_id")) {
 | 
			
		||||
				std::cerr << "SDB error: response not a task_id\n";
 | 
			
		||||
				return nullptr;
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			uint64_t task_id{};
 | 
			
		||||
			const auto& j_task_id = j_res.at("task_id");
 | 
			
		||||
			if (!j_task_id.is_number_unsigned()) {
 | 
			
		||||
				// meh, conversion time
 | 
			
		||||
				task_id = std::stoull(j_task_id.get<std::string>().c_str());
 | 
			
		||||
				//std::cout << "converted " << j_task_id << " to " << task_id << "\n";
 | 
			
		||||
			} else {
 | 
			
		||||
				task_id = j_task_id.get<uint64_t>();
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			std::cout << "SDB: sdcpp task id: " << task_id << "\n";
 | 
			
		||||
 | 
			
		||||
			return std::make_shared<WebAPITask_sdcpp_stduhpf_wip2>(
 | 
			
		||||
				cl,
 | 
			
		||||
				"/result", // TODO: from conf?
 | 
			
		||||
				task_id
 | 
			
		||||
			);
 | 
			
		||||
		} catch (...) {
 | 
			
		||||
			std::cerr << "SDB error: failed parsing response as json\n";
 | 
			
		||||
			return nullptr;
 | 
			
		||||
		}
 | 
			
		||||
	} catch (...) {
 | 
			
		||||
		std::cerr << "SDB http request error\n";
 | 
			
		||||
		return nullptr;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ???
 | 
			
		||||
	return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
WebAPITask_sdcpp_stduhpf_wip2::WebAPITask_sdcpp_stduhpf_wip2(std::shared_ptr<httplib::Client> cl, const std::string& url, uint64_t task_id)
 | 
			
		||||
	: _cl(cl), _url(url), _task_id(task_id)
 | 
			
		||||
{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//WebAPITask_sdcpp_stduhpf_wip2::~WebAPITask_sdcpp_stduhpf_wip2(void) {
 | 
			
		||||
//}
 | 
			
		||||
 | 
			
		||||
bool WebAPITask_sdcpp_stduhpf_wip2::ready(void) {
 | 
			
		||||
	// polling api
 | 
			
		||||
 | 
			
		||||
	try {
 | 
			
		||||
		auto res = _cl->Get(_url + "?task_id=" + std::to_string(_task_id));
 | 
			
		||||
		if (!static_cast<bool>(res)) {
 | 
			
		||||
			std::cerr << "SDB error: post to sd server failed! (in task)\n";
 | 
			
		||||
			_done = true;
 | 
			
		||||
			return true;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if (
 | 
			
		||||
			res.error() != httplib::Error::Success ||
 | 
			
		||||
			res->status != 200
 | 
			
		||||
		) {
 | 
			
		||||
			std::cerr << "SDB error: post to sd server failed! (in task)\n";
 | 
			
		||||
			_done = true;
 | 
			
		||||
			return true;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		auto j_res = nlohmann::json::parse(res->body);
 | 
			
		||||
		if (j_res.at("status") != "Done") {
 | 
			
		||||
			// not ready (likely path)
 | 
			
		||||
			return false;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// TODO: add support for multiple images
 | 
			
		||||
		auto& j_image = j_res.at("data").at(0); // ??
 | 
			
		||||
 | 
			
		||||
		_result.width = j_image.at("width");
 | 
			
		||||
		_result.height = j_image.at("height");
 | 
			
		||||
 | 
			
		||||
		{ // data
 | 
			
		||||
			auto& j_data = j_image.at("data");
 | 
			
		||||
 | 
			
		||||
			_result.data.resize(j_data.get<std::string>().size()); // HACK: do a better estimate
 | 
			
		||||
			size_t decoded_size {0};
 | 
			
		||||
			sodium_base642bin(
 | 
			
		||||
				_result.data.data(), _result.data.size(),
 | 
			
		||||
				j_data.get<std::string>().data(), j_data.get<std::string>().size(),
 | 
			
		||||
				" \n\t",
 | 
			
		||||
				&decoded_size,
 | 
			
		||||
				nullptr,
 | 
			
		||||
				sodium_base64_VARIANT_ORIGINAL
 | 
			
		||||
			);
 | 
			
		||||
			_result.data.resize(decoded_size);
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		{ // file name?
 | 
			
		||||
			// == "png" or jpeg or somehting
 | 
			
		||||
			_result.file_name = std::string{"output_0."} + j_image.value("encoding", "png");
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_done = true;
 | 
			
		||||
		return true;
 | 
			
		||||
	} catch (...) {
 | 
			
		||||
		// rip
 | 
			
		||||
		_done = true;
 | 
			
		||||
		return true;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::optional<WebAPITaskI::Result> WebAPITask_sdcpp_stduhpf_wip2::get(void) {
 | 
			
		||||
	if (!_done) {
 | 
			
		||||
		assert(false && "what ya doin?");
 | 
			
		||||
		return std::nullopt;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (_result.data.empty()) {
 | 
			
		||||
		return std::nullopt;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return std::move(_result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										52
									
								
								src/webapi_sdcpp_stduhpf_wip2.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								src/webapi_sdcpp_stduhpf_wip2.hpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include "./webapi_interface.hpp"
 | 
			
		||||
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <solanaceae/util/config_model.hpp>
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
// fwd
 | 
			
		||||
namespace httplib {
 | 
			
		||||
	class Client;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// in its second api iteration, stduhpf switched away from a rest api
 | 
			
		||||
struct WebAPI_sdcpp_stduhpf_wip2 : public WebAPII {
 | 
			
		||||
	// TODO: const config
 | 
			
		||||
	ConfigModelI& _conf;
 | 
			
		||||
 | 
			
		||||
	std::shared_ptr<httplib::Client> _cl;
 | 
			
		||||
	std::shared_ptr<httplib::Client> getCl(void);
 | 
			
		||||
 | 
			
		||||
	WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf);
 | 
			
		||||
 | 
			
		||||
	~WebAPI_sdcpp_stduhpf_wip2(void) override;
 | 
			
		||||
 | 
			
		||||
	std::shared_ptr<WebAPITaskI> txt2img(
 | 
			
		||||
		std::string_view prompt,
 | 
			
		||||
		int16_t width,
 | 
			
		||||
		int16_t height
 | 
			
		||||
		// more
 | 
			
		||||
	) override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// discard after get() !!
 | 
			
		||||
struct WebAPITask_sdcpp_stduhpf_wip2 : public WebAPITaskI {
 | 
			
		||||
	std::shared_ptr<httplib::Client> _cl;
 | 
			
		||||
	const std::string _url;
 | 
			
		||||
 | 
			
		||||
	uint64_t _task_id {};
 | 
			
		||||
	bool _done {false};
 | 
			
		||||
	Result _result;
 | 
			
		||||
 | 
			
		||||
	WebAPITask_sdcpp_stduhpf_wip2(std::shared_ptr<httplib::Client> cl, const std::string& url, uint64_t task_id);
 | 
			
		||||
 | 
			
		||||
	~WebAPITask_sdcpp_stduhpf_wip2(void) override {}
 | 
			
		||||
 | 
			
		||||
	bool ready(void) override;
 | 
			
		||||
	std::optional<Result> get(void) override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user