Compare commits
	
		
			3 Commits
		
	
	
		
			b3315da1d9
			...
			2beb74eb5f
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 2beb74eb5f | ||
|  | bfd923e829 | ||
|  | 24ae710c29 | 
							
								
								
									
										4
									
								
								.github/workflows/cd.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/cd.yml
									
									
									
									
										vendored
									
									
								
							| @@ -14,7 +14,7 @@ jobs: | |||||||
|   linux-ubuntu: |   linux-ubuntu: | ||||||
|     timeout-minutes: 10 |     timeout-minutes: 10 | ||||||
|  |  | ||||||
|     runs-on: ubuntu-20.04 |     runs-on: ubuntu-24.04 | ||||||
|  |  | ||||||
|     steps: |     steps: | ||||||
|     - uses: actions/checkout@v4 |     - uses: actions/checkout@v4 | ||||||
| @@ -32,7 +32,7 @@ jobs: | |||||||
|  |  | ||||||
|     - uses: actions/upload-artifact@v4 |     - uses: actions/upload-artifact@v4 | ||||||
|       with: |       with: | ||||||
|         name: ${{ github.event.repository.name }}-ubuntu20.04-x86_64 |         name: ${{ github.event.repository.name }}-ubuntu24.04-x86_64 | ||||||
|         # TODO: do propper packing |         # TODO: do propper packing | ||||||
|         path: | |         path: | | ||||||
|           ${{github.workspace}}/build/bin/ |           ${{github.workspace}}/build/bin/ | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| ## Solanaceae extention and plugin to serve StableDiffusion | ## Solanaceae extention and plugin to serve StableDiffusion | ||||||
|  |  | ||||||
| !! currently only works with automatic1111's api !! | !! currently only works with [stduhpf's stablediffusion.cpp http server api](https://github.com/leejet/stable-diffusion.cpp/pull/367)  !! | ||||||
|  |  | ||||||
| ### example config for `totato` | ### example config for `totato` | ||||||
| ```json | ```json | ||||||
| @@ -15,15 +15,15 @@ | |||||||
| 	"SDBot": { | 	"SDBot": { | ||||||
| 		"server_host": "127.0.0.1", | 		"server_host": "127.0.0.1", | ||||||
| 		"server_port": 8080, | 		"server_port": 8080, | ||||||
| 		"url_txt2img": "/sdapi/v1/txt2img", | 		"endpoint_type": "sdcpp_stduhpf_wip2", | ||||||
|  | 		"url_txt2img": "/txt2img", | ||||||
|  |  | ||||||
| 		"width": 512, | 		"width": 512, | ||||||
| 		"height": 512, | 		"height": 512, | ||||||
|  |  | ||||||
| 		"prompt_prefix": "<lora:lcm-lora-sdv1-5:1>", | 		"prompt_prefix": "<lora:lcm-lora-sdv1-5:1>", | ||||||
| 		"steps": 8, | 		"steps": 8, | ||||||
| 		"cfg_scale": 1.0, | 		"cfg_scale": 1.0 | ||||||
| 		"sampler_index": "LCM Test" |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| ``` | ``` | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								external/CMakeLists.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								external/CMakeLists.txt
									
									
									
									
										vendored
									
									
								
							| @@ -78,7 +78,7 @@ endif() | |||||||
| if (NOT TARGET httplib::httplib) | if (NOT TARGET httplib::httplib) | ||||||
| 	FetchContent_Declare(httplib | 	FetchContent_Declare(httplib | ||||||
| 		GIT_REPOSITORY https://github.com/yhirose/cpp-httplib.git | 		GIT_REPOSITORY https://github.com/yhirose/cpp-httplib.git | ||||||
| 		GIT_TAG v0.19.0 | 		GIT_TAG v0.22.0 | ||||||
| 		EXCLUDE_FROM_ALL | 		EXCLUDE_FROM_ALL | ||||||
| 	) | 	) | ||||||
| 	FetchContent_MakeAvailable(httplib) | 	FetchContent_MakeAvailable(httplib) | ||||||
|   | |||||||
| @@ -1,6 +1,10 @@ | |||||||
| cmake_minimum_required(VERSION 3.9 FATAL_ERROR) | cmake_minimum_required(VERSION 3.9 FATAL_ERROR) | ||||||
|  |  | ||||||
| add_library(solanaceae_sdbot-webui STATIC | add_library(solanaceae_sdbot-webui STATIC | ||||||
|  | 	./webapi_interface.hpp | ||||||
|  | 	./webapi_sdcpp_stduhpf_wip2.hpp | ||||||
|  | 	./webapi_sdcpp_stduhpf_wip2.cpp | ||||||
|  |  | ||||||
| 	./sd_bot.hpp | 	./sd_bot.hpp | ||||||
| 	./sd_bot.cpp | 	./sd_bot.cpp | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										330
									
								
								src/sd_bot.cpp
									
									
									
									
									
								
							
							
						
						
									
										330
									
								
								src/sd_bot.cpp
									
									
									
									
									
								
							| @@ -6,200 +6,14 @@ | |||||||
| #include <solanaceae/contact/components.hpp> | #include <solanaceae/contact/components.hpp> | ||||||
| #include <solanaceae/message3/components.hpp> | #include <solanaceae/message3/components.hpp> | ||||||
|  |  | ||||||
| #include <nlohmann/json.hpp> | #include "./webapi_sdcpp_stduhpf_wip2.hpp" | ||||||
| #include <sodium.h> |  | ||||||
|  |  | ||||||
| #include <fstream> | #include <fstream> | ||||||
| #include <filesystem> | #include <filesystem> | ||||||
| #include <chrono> | #include <chrono> | ||||||
|  |  | ||||||
| #include <iostream> | #include <iostream> | ||||||
|  | #include <stdexcept> | ||||||
| struct Automatic1111_v1_Endpoint : public SDBot::EndpointI { |  | ||||||
| 	Automatic1111_v1_Endpoint(RegistryMessageModelI& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {} |  | ||||||
|  |  | ||||||
| 	bool handleResponse(Contact4 contact, ByteSpan data) override { |  | ||||||
| 		//std::cout << std::string_view{reinterpret_cast<const char*>(data.ptr), data.size} << "\n"; |  | ||||||
|  |  | ||||||
| 		// extract json result |  | ||||||
| 		const auto j = nlohmann::json::parse( |  | ||||||
| 			std::string_view{reinterpret_cast<const char*>(data.ptr), data.size}, |  | ||||||
| 			nullptr, |  | ||||||
| 			false |  | ||||||
| 		); |  | ||||||
| 		//std::cout << "json dump: " << j.dump() << "\n"; |  | ||||||
|  |  | ||||||
| 		if (j.count("images") && !j.at("images").empty() && j.at("images").is_array()) { |  | ||||||
| 			for (const auto& i_j : j.at("images").items()) { |  | ||||||
| 				// decode data (base64) |  | ||||||
| 				std::vector<uint8_t> png_data(data.size); // just init to upper bound |  | ||||||
| 				size_t decoded_size {0}; |  | ||||||
| 				sodium_base642bin( |  | ||||||
| 					png_data.data(), png_data.size(), |  | ||||||
| 					i_j.value().get<std::string>().data(), i_j.value().get<std::string>().size(), |  | ||||||
| 					" \n\t", |  | ||||||
| 					&decoded_size, |  | ||||||
| 					nullptr, |  | ||||||
| 					sodium_base64_VARIANT_ORIGINAL |  | ||||||
| 				); |  | ||||||
| 				png_data.resize(decoded_size); |  | ||||||
|  |  | ||||||
| 				std::filesystem::create_directories("sdbot_img_send"); |  | ||||||
| 				//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; |  | ||||||
| 				const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; |  | ||||||
| 				const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; |  | ||||||
|  |  | ||||||
| 				std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size()); |  | ||||||
| 				_rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			std::cerr << "SDB json response did not contain images?\n"; |  | ||||||
| 			return false; |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		return true; |  | ||||||
| 	} |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| struct SDcpp_wip1_Endpoint : public SDBot::EndpointI { |  | ||||||
| 	SDcpp_wip1_Endpoint(RegistryMessageModelI& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {} |  | ||||||
|  |  | ||||||
| 	bool handleResponse(Contact4 contact, ByteSpan data) override { |  | ||||||
| 		//std::cout << std::string_view{reinterpret_cast<const char*>(data.ptr), data.size} << "\n"; |  | ||||||
|  |  | ||||||
| 		std::string_view data_str {reinterpret_cast<const char*>(data.ptr), data.size}; |  | ||||||
| 		auto nl_pos {std::string_view::npos}; |  | ||||||
| 		bool succ {false}; |  | ||||||
| 		do { |  | ||||||
| 			// for each line, should be "data: <json>" or empty |  | ||||||
| 			nl_pos = data_str.find_first_of('\n'); |  | ||||||
|  |  | ||||||
| 			// npos is also valid |  | ||||||
| 			auto line = data_str.substr(0, nl_pos); |  | ||||||
|  |  | ||||||
| 			// at least minimum viable |  | ||||||
| 			if (line.size() >= std::string_view{"data: {}"}.size()) { |  | ||||||
| 				//std::cout << "got data line!!!!!!!!!!!:\n"; |  | ||||||
| 				//std::cout << line << "\n"; |  | ||||||
| 				line.remove_prefix(6); |  | ||||||
|  |  | ||||||
| 				const auto j = nlohmann::json::parse( |  | ||||||
| 					line, |  | ||||||
| 					nullptr, |  | ||||||
| 					false |  | ||||||
| 				); |  | ||||||
|  |  | ||||||
| 				if ( |  | ||||||
| 					!j.empty() && |  | ||||||
| 					j.value("type", "notimag") == "image" && |  | ||||||
| 					j.contains("data") && |  | ||||||
| 					j.at("data").is_string() |  | ||||||
| 				) { |  | ||||||
| 					const auto& img_data_str = j.at("data").get<std::string>(); |  | ||||||
| 					// decode data (base64) |  | ||||||
| 					std::vector<uint8_t> png_data(img_data_str.size()); // just init to upper bound |  | ||||||
| 					size_t decoded_size {0}; |  | ||||||
| 					sodium_base642bin( |  | ||||||
| 						png_data.data(), png_data.size(), |  | ||||||
| 						img_data_str.data(), img_data_str.size(), |  | ||||||
| 						" \n\t", |  | ||||||
| 						&decoded_size, |  | ||||||
| 						nullptr, |  | ||||||
| 						sodium_base64_VARIANT_ORIGINAL |  | ||||||
| 					); |  | ||||||
| 					png_data.resize(decoded_size); |  | ||||||
|  |  | ||||||
| 					std::filesystem::create_directories("sdbot_img_send"); |  | ||||||
| 					//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; |  | ||||||
| 					const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; |  | ||||||
| 					const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; |  | ||||||
|  |  | ||||||
| 					std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size()); |  | ||||||
| 					succ = _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if (nl_pos == std::string_view::npos || nl_pos+1 >= data_str.size()) { |  | ||||||
| 				break; |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			data_str = data_str.substr(nl_pos+1); |  | ||||||
| 		} while (nl_pos != std::string_view::npos); |  | ||||||
|  |  | ||||||
| 		return succ; |  | ||||||
| 	} |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| struct SDcpp_stduhpf_wip1_Endpoint : public SDBot::EndpointI { |  | ||||||
| 	SDcpp_stduhpf_wip1_Endpoint(RegistryMessageModelI& rmm, std::default_random_engine& rng) : SDBot::EndpointI(rmm, rng) {} |  | ||||||
|  |  | ||||||
| 	bool handleResponse(Contact4 contact, ByteSpan data) override { |  | ||||||
| 		bool succ = true; |  | ||||||
|  |  | ||||||
| 		try { |  | ||||||
| 			// extract json result |  | ||||||
| 			const auto j = nlohmann::json::parse( |  | ||||||
| 				std::string_view{reinterpret_cast<const char*>(data.ptr), data.size} |  | ||||||
| 			); |  | ||||||
|  |  | ||||||
| 			if (!j.is_array()) { |  | ||||||
| 				std::cerr << "SDB: json response was not an array\n"; |  | ||||||
| 				return false; |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			for (const auto& j_entry : j) { |  | ||||||
| 				if (!j_entry.is_object()) { |  | ||||||
| 					std::cerr << "SDB warning: non object entry, skipping\n"; |  | ||||||
| 					continue; |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// for each returned image |  | ||||||
| 				// "channel": 3, // rgb? |  | ||||||
| 				// "data": base64 encoded image file |  | ||||||
| 				// "encoding": "png", |  | ||||||
| 				// "height": 512, |  | ||||||
| 				// "width": 512 |  | ||||||
|  |  | ||||||
| 				if (j_entry.contains("encoding")) { |  | ||||||
| 					if (!j_entry["encoding"].is_string() || j_entry["encoding"] != "png") { |  | ||||||
| 						std::cerr << "SDB warning: unknown encoding '" << j_entry["encoding"] << "'\n"; |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				if (!j_entry.contains("data") || !j_entry.at("data").is_string()) { |  | ||||||
| 					std::cerr << "SDB warning: non data entry, skipping\n"; |  | ||||||
| 					continue; |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				const auto& img_data_str = j_entry.at("data").get<std::string>(); |  | ||||||
| 				// decode data (base64) |  | ||||||
| 				std::vector<uint8_t> png_data(img_data_str.size()); // just init to upper bound |  | ||||||
| 				size_t decoded_size {0}; |  | ||||||
| 				sodium_base642bin( |  | ||||||
| 					png_data.data(), png_data.size(), |  | ||||||
| 					img_data_str.data(), img_data_str.size(), |  | ||||||
| 					" \n\t", |  | ||||||
| 					&decoded_size, |  | ||||||
| 					nullptr, |  | ||||||
| 					sodium_base64_VARIANT_ORIGINAL |  | ||||||
| 				); |  | ||||||
| 				png_data.resize(decoded_size); |  | ||||||
|  |  | ||||||
| 				std::filesystem::create_directories("sdbot_img_send"); |  | ||||||
| 				//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; |  | ||||||
| 				const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; |  | ||||||
| 				const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; |  | ||||||
|  |  | ||||||
| 				std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size()); |  | ||||||
| 				succ = succ && _rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); |  | ||||||
| 			} |  | ||||||
| 		} catch (...) { |  | ||||||
| 			return false; |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		return succ; |  | ||||||
| 	} |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| SDBot::SDBot( | SDBot::SDBot( | ||||||
| 	ContactStore4I& cs, | 	ContactStore4I& cs, | ||||||
| @@ -210,22 +24,16 @@ SDBot::SDBot( | |||||||
| 	_rng.discard(3137); | 	_rng.discard(3137); | ||||||
|  |  | ||||||
| 	if (!_conf.has_string("SDBot", "endpoint_type")) { | 	if (!_conf.has_string("SDBot", "endpoint_type")) { | ||||||
| 		_conf.set("SDBot", "endpoint_type", std::string_view{"automatic1111_v1"}); // automatic11 default | 		_conf.set("SDBot", "endpoint_type", std::string_view{"sdcpp_stduhpf_wip2"}); // default | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	//HACKy | 	//HACKy | ||||||
| 	{ // construct endpoint | 	{ // construct endpoint | ||||||
| 		const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value(); | 		const std::string_view endpoint_type = _conf.get_string("SDBot", "endpoint_type").value(); | ||||||
| 		if (endpoint_type == "automatic1111_v1") { | 		if (endpoint_type == "sdcpp_stduhpf_wip2") { | ||||||
| 			_endpoint = std::make_unique<Automatic1111_v1_Endpoint>(_rmm, _rng); | 			_endpoint = std::make_unique<WebAPI_sdcpp_stduhpf_wip2>(_conf); | ||||||
| 		} else if (endpoint_type == "sdcpp_wip1") { |  | ||||||
| 			_endpoint = std::make_unique<SDcpp_wip1_Endpoint>(_rmm, _rng); |  | ||||||
| 		} else if (endpoint_type == "sdcpp_stduhpf_wip1") { |  | ||||||
| 			_endpoint = std::make_unique<SDcpp_stduhpf_wip1_Endpoint>(_rmm, _rng); |  | ||||||
| 		} else { | 		} else { | ||||||
| 			std::cerr << "SDB error: unknown endpoint type '" << endpoint_type << "'\n"; | 			throw std::runtime_error("missing endpoint type in config, cant continue!"); | ||||||
| 			// TODO: throw? |  | ||||||
| 			_endpoint = std::make_unique<Automatic1111_v1_Endpoint>(_rmm, _rng); |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -260,117 +68,61 @@ SDBot::~SDBot(void) { | |||||||
| } | } | ||||||
|  |  | ||||||
| float SDBot::iterate(void) { | float SDBot::iterate(void) { | ||||||
| 	if (_current_task.has_value() != _curr_future.has_value()) { | 	if (_current_task_id.has_value() != (_current_task != nullptr)) { | ||||||
| 		std::cerr << "SDB inconsistent state\n"; | 		std::cerr << "SDB inconsistent state\n"; | ||||||
|  |  | ||||||
| 		if (_current_task.has_value()) { | 		if (_current_task_id.has_value()) { | ||||||
| 			_task_map.erase(_current_task.value()); | 			_task_map.erase(_current_task_id.value()); | ||||||
| 			_current_task = std::nullopt; | 			_current_task_id = std::nullopt; | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if (_curr_future.has_value()) { | 		_current_task.reset(); | ||||||
| 			_curr_future.reset(); // might block and wait |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if (!_prompt_queue.empty() && !_current_task.has_value()) { // dequeue new task | 	if (!_prompt_queue.empty() && !_current_task_id.has_value()) { // dequeue new task | ||||||
| 		const auto& [task_id, prompt] = _prompt_queue.front(); | 		const auto& [task_id, prompt] = _prompt_queue.front(); | ||||||
|  |  | ||||||
| 		_current_task = task_id; | 		_current_task = _endpoint->txt2img( | ||||||
|  | 			prompt, | ||||||
| 		if (_cli == nullptr) { | 			_conf.get_int("SDBot", "width").value_or(512), | ||||||
| 			const std::string server_host {_conf.get_string("SDBot", "server_host").value()}; | 			_conf.get_int("SDBot", "height").value_or(512) | ||||||
| 			_cli = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value()); | 		); | ||||||
| 			_cli->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while | 		if (_current_task == nullptr) { | ||||||
| 		} | 			// ERROR | ||||||
|  | 			std::cerr << "SDB error: call to txt2img failed!\n"; | ||||||
| 		nlohmann::json j_body; | 			_task_map.erase(task_id); | ||||||
| 		j_body["width"] = _conf.get_int("SDBot", "width").value_or(512); | 		} else { | ||||||
| 		j_body["height"] = _conf.get_int("SDBot", "height").value_or(512); | 			_current_task_id = task_id; | ||||||
|  |  | ||||||
| 		//j_body["prompt"] = prompt; |  | ||||||
| 		//"<lora:lcm-lora-sdv1-5:1>" |  | ||||||
| 		j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + prompt; |  | ||||||
| 		// TODO: negative prompt |  | ||||||
|  |  | ||||||
| 		j_body["seed"] = -1; |  | ||||||
| #if 0 |  | ||||||
| 		j_body["steps"] = 20; |  | ||||||
| 		//j_body["steps"] = 5; |  | ||||||
| 		j_body["cfg_scale"] = 6.5; |  | ||||||
| 		j_body["sampler_index"] = "Euler a"; |  | ||||||
| #else |  | ||||||
| 		//j_body["steps"] = 4; |  | ||||||
| 		j_body["steps"] = _conf.get_int("SDBot", "steps").value_or(20); |  | ||||||
| 		//j_body["cfg_scale"] = 1; |  | ||||||
| 		j_body["cfg_scale"] = _conf.get_double("SDBot", "cfg_scale").value_or(6.5); |  | ||||||
| 		//j_body["sampler_index"] = "LCM Test"; |  | ||||||
| 		j_body["sampler_index"] = std::string{_conf.get_string("SDBot", "sampler_index").value_or("Euler a")}; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| 		j_body["batch_size"] = 1; |  | ||||||
| 		j_body["n_iter"] = 1; |  | ||||||
| 		j_body["restore_faces"] = false; |  | ||||||
| 		j_body["tiling"] = false; |  | ||||||
| 		j_body["enable_hr"] = false; |  | ||||||
|  |  | ||||||
| 		std::string body = j_body.dump(); |  | ||||||
|  |  | ||||||
| 		try { |  | ||||||
| 			const std::string url {_conf.get_string("SDBot", "url_txt2img").value()}; |  | ||||||
| 			_curr_future = std::async(std::launch::async, [url, body, cli = _cli]() -> std::vector<uint8_t> { |  | ||||||
| 				if (!static_cast<bool>(cli)) { |  | ||||||
| 					return {}; |  | ||||||
| 				} |  | ||||||
| 				// TODO: move to endpoint |  | ||||||
| 				auto res = cli->Post(url, body, "application/json"); |  | ||||||
| 				if (!static_cast<bool>(res)) { |  | ||||||
| 					std::cerr << "SDB error: post to sd server failed!\n"; |  | ||||||
| 					return {}; |  | ||||||
| 				} |  | ||||||
| 				std::cerr << "SDB http complete " << res->status << " " << res->reason << "\n"; |  | ||||||
| 				if ( |  | ||||||
| 					res.error() != httplib::Error::Success || |  | ||||||
| 					res->status != 200 |  | ||||||
| 				) { |  | ||||||
| 					return {}; |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				return std::vector<uint8_t>(res->body.cbegin(), res->body.cend()); |  | ||||||
| 			}); |  | ||||||
| 		} catch (...) { |  | ||||||
| 			std::cerr << "SDB http request error\n"; |  | ||||||
| 			// cleanup |  | ||||||
| 			_task_map.erase(_current_task.value()); |  | ||||||
| 			_current_task = std::nullopt; |  | ||||||
| 			_curr_future.reset(); // might block and wait |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		_prompt_queue.pop(); | 		_prompt_queue.pop(); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  |  | ||||||
| 	if (_curr_future.has_value() && _curr_future.value().wait_for(std::chrono::milliseconds(1)) == std::future_status::ready) { | 	if (_current_task_id && _current_task && _current_task->ready()) { | ||||||
| 		const auto& contact = _task_map.at(_current_task.value()); |  | ||||||
|  |  | ||||||
| 		const auto data = _curr_future.value().get(); | 		auto res_opt = _current_task->get(); | ||||||
|  | 		if (res_opt) { | ||||||
|  | 			const auto& contact = _task_map.at(_current_task_id.value()); | ||||||
|  |  | ||||||
| 		if (_endpoint->handleResponse(contact, ByteSpan{data})) { | 			std::filesystem::create_directories("sdbot_img_send"); | ||||||
| 			// TODO: error handling | 			const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + "." + res_opt.value().file_name; | ||||||
|  | 			const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; | ||||||
|  |  | ||||||
|  | 			std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(res_opt.value().data.data()), res_opt.value().data.size()); | ||||||
|  | 			_rmm.sendFilePath(contact, tmp_img_file_name, tmp_img_file_path); | ||||||
|  | 		} else { | ||||||
|  | 			std::cerr << "SDB error: call to txt2img failed (task returned nothing)!\n"; | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		_task_map.erase(_current_task.value()); |  | ||||||
| 		_current_task = std::nullopt; | 		_current_task_id.reset(); | ||||||
| 		_curr_future.reset(); | 		_current_task.reset(); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// if active web connection, 100ms | ||||||
| 	// if active web connection, 5ms | 	if (_current_task_id.has_value()) { | ||||||
| 	//return static_cast<bool>(_con) ? 0.005f : 1.f; | 		return 0.1f; | ||||||
|  |  | ||||||
| 	// if active web connection, 50ms |  | ||||||
| 	if (_curr_future.has_value() && _curr_future.value().valid()) { |  | ||||||
| 		return 0.05f; |  | ||||||
| 	} else { | 	} else { | ||||||
| 		return 1.f; | 		return 1.f; | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ | |||||||
| #include <solanaceae/message3/registry_message_model.hpp> | #include <solanaceae/message3/registry_message_model.hpp> | ||||||
| #include <solanaceae/contact/fwd.hpp> | #include <solanaceae/contact/fwd.hpp> | ||||||
|  |  | ||||||
| #include <httplib.h> | #include "./webapi_interface.hpp" | ||||||
|  |  | ||||||
| #include <map> | #include <map> | ||||||
| #include <vector> | #include <vector> | ||||||
| @@ -24,31 +24,17 @@ class SDBot : public RegistryMessageModelEventI { | |||||||
| 	RegistryMessageModelI::SubscriptionReference _rmm_sr; | 	RegistryMessageModelI::SubscriptionReference _rmm_sr; | ||||||
| 	ConfigModelI& _conf; | 	ConfigModelI& _conf; | ||||||
|  |  | ||||||
| 	//TransferManager& _tm; |  | ||||||
|  |  | ||||||
| 	//std::map<uint64_t, std::variant<ContactFriend, ContactConference, ContactGroupPeer>> _task_map; |  | ||||||
| 	std::map<uint64_t, Contact4> _task_map; | 	std::map<uint64_t, Contact4> _task_map; | ||||||
| 	std::queue<std::pair<uint64_t, std::string>> _prompt_queue; | 	std::queue<std::pair<uint64_t, std::string>> _prompt_queue; | ||||||
| 	uint64_t _last_task_counter = 0; | 	uint64_t _last_task_counter = 0; | ||||||
|  |  | ||||||
| 	std::optional<uint64_t> _current_task; | 	std::optional<uint64_t> _current_task_id; | ||||||
| 	std::shared_ptr<httplib::Client> _cli; | 	std::shared_ptr<WebAPITaskI> _current_task; | ||||||
| 	std::optional<std::future<std::vector<uint8_t>>> _curr_future; |  | ||||||
|  |  | ||||||
| 	std::default_random_engine _rng; | 	std::default_random_engine _rng; | ||||||
|  |  | ||||||
| 	public: |  | ||||||
| 		struct EndpointI { |  | ||||||
| 			RegistryMessageModelI& _rmm; |  | ||||||
| 			std::default_random_engine& _rng; |  | ||||||
| 			EndpointI(RegistryMessageModelI& rmm, std::default_random_engine& rng) : _rmm(rmm), _rng(rng) {} |  | ||||||
| 			virtual ~EndpointI(void) {} |  | ||||||
|  |  | ||||||
| 			virtual bool handleResponse(Contact4 contact, ByteSpan data) = 0; |  | ||||||
| 		}; |  | ||||||
|  |  | ||||||
| 	private: | 	private: | ||||||
| 		std::unique_ptr<EndpointI> _endpoint; | 		std::unique_ptr<WebAPII> _endpoint; | ||||||
|  |  | ||||||
| 	public: | 	public: | ||||||
| 		SDBot( | 		SDBot( | ||||||
| @@ -64,9 +50,6 @@ class SDBot : public RegistryMessageModelEventI { | |||||||
| 		bool use_webp_for_friends = true; | 		bool use_webp_for_friends = true; | ||||||
| 		bool use_webp_for_groups = true; | 		bool use_webp_for_groups = true; | ||||||
|  |  | ||||||
| 	//protected: // tox events |  | ||||||
| 		//bool onToxEvent(const Tox_Event_Friend_Message* e) override; |  | ||||||
| 		//bool onToxEvent(const Tox_Event_Group_Message* e) override; |  | ||||||
| 	protected: // mm | 	protected: // mm | ||||||
| 		bool onEvent(const Message::Events::MessageConstruct& e) override; | 		bool onEvent(const Message::Events::MessageConstruct& e) override; | ||||||
| }; | }; | ||||||
|   | |||||||
							
								
								
									
										40
									
								
								src/webapi_interface.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								src/webapi_interface.hpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <string_view> | ||||||
|  | #include <cstdint> | ||||||
|  | #include <memory> | ||||||
|  | #include <optional> | ||||||
|  | #include <vector> | ||||||
|  | #include <string> | ||||||
|  |  | ||||||
|  | struct WebAPITaskI; | ||||||
|  |  | ||||||
|  | struct WebAPII { | ||||||
|  | 	virtual ~WebAPII(void) {} | ||||||
|  |  | ||||||
|  | 	virtual std::shared_ptr<WebAPITaskI> txt2img( | ||||||
|  | 		std::string_view prompt, | ||||||
|  | 		int16_t width, | ||||||
|  | 		int16_t height | ||||||
|  | 		// more | ||||||
|  | 	) = 0; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // only knows of a single prompt | ||||||
|  | struct WebAPITaskI { | ||||||
|  | 	virtual ~WebAPITaskI(void) {} | ||||||
|  |  | ||||||
|  | 	// true if done or failed | ||||||
|  | 	virtual bool ready(void) = 0; | ||||||
|  |  | ||||||
|  | 	struct Result { | ||||||
|  | 		std::vector<uint8_t> data; | ||||||
|  | 		int16_t width {}; | ||||||
|  | 		int16_t height {}; | ||||||
|  |  | ||||||
|  | 		std::string file_name; | ||||||
|  | 	}; | ||||||
|  |  | ||||||
|  | 	virtual std::optional<Result> get(void) = 0; | ||||||
|  | }; | ||||||
|  |  | ||||||
							
								
								
									
										228
									
								
								src/webapi_sdcpp_stduhpf_wip2.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								src/webapi_sdcpp_stduhpf_wip2.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,228 @@ | |||||||
|  | #include "./webapi_sdcpp_stduhpf_wip2.hpp" | ||||||
|  |  | ||||||
|  | #include <httplib.h> | ||||||
|  | #include <nlohmann/json.hpp> | ||||||
|  | #include <sodium.h> | ||||||
|  |  | ||||||
|  | #include <cstdint> | ||||||
|  | #include <memory> | ||||||
|  | #include <chrono> | ||||||
|  | #include <string> | ||||||
|  |  | ||||||
|  | std::shared_ptr<httplib::Client> WebAPI_sdcpp_stduhpf_wip2::getCl(void) { | ||||||
|  | 	if (_cl == nullptr) { | ||||||
|  | 		const std::string server_host {_conf.get_string("SDBot", "server_host").value()}; | ||||||
|  | 		_cl = std::make_shared<httplib::Client>(server_host, _conf.get_int("SDBot", "server_port").value()); | ||||||
|  | 		_cl->set_read_timeout(std::chrono::minutes(2)); // because of discarding futures, it can block main for a while | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return _cl; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | WebAPI_sdcpp_stduhpf_wip2::WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf) : | ||||||
|  | 	_conf(conf) | ||||||
|  | { | ||||||
|  | } | ||||||
|  |  | ||||||
|  | WebAPI_sdcpp_stduhpf_wip2::~WebAPI_sdcpp_stduhpf_wip2(void) { | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | std::shared_ptr<WebAPITaskI> WebAPI_sdcpp_stduhpf_wip2::txt2img( | ||||||
|  | 	std::string_view prompt, | ||||||
|  | 	int16_t width, | ||||||
|  | 	int16_t height | ||||||
|  | 	// more | ||||||
|  | ) { | ||||||
|  | 	auto cl = getCl(); | ||||||
|  | 	if (!cl) { | ||||||
|  | 		return nullptr; | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	nlohmann::json j_body; | ||||||
|  | 	// eg | ||||||
|  | 	//{ | ||||||
|  | 	//  "prompt": "a lovely schnitzel", | ||||||
|  | 	//  "negative_prompt": "", | ||||||
|  | 	//  "width": 512, | ||||||
|  | 	//  "height": 512, | ||||||
|  | 	//  "guidance_params": { | ||||||
|  | 	//    "cfg_scale": 1, | ||||||
|  | 	//    "guidance": 3.5 | ||||||
|  | 	//  }, | ||||||
|  | 	//  "sample_steps": 8, | ||||||
|  | 	//  "sample_method": "euler_a", | ||||||
|  | 	//  "seed": -1, | ||||||
|  | 	//  "batch_count": 1, | ||||||
|  | 	//  "schedule": "default", | ||||||
|  | 	//  "model": -1, | ||||||
|  | 	//  "diffusion_model": -1, | ||||||
|  | 	//  "clip_l": -1, | ||||||
|  | 	//  "clip_g": -1, | ||||||
|  | 	//  "t5xxl": -1, | ||||||
|  | 	//  "vae": -1, | ||||||
|  | 	//  "tae": -1, | ||||||
|  | 	//  "keep_vae_on_cpu": false, | ||||||
|  | 	//  "kep_clip_on_cpu": false, | ||||||
|  | 	//  "vae_tiling": true, | ||||||
|  | 	//  "tae_decode": true | ||||||
|  | 	//} | ||||||
|  | 	//j_body["width"] = _conf.get_int("SDBot", "width").value_or(512); | ||||||
|  | 	j_body["width"] = width; | ||||||
|  | 	//j_body["height"] = _conf.get_int("SDBot", "height").value_or(512); | ||||||
|  | 	j_body["height"] = height; | ||||||
|  |  | ||||||
|  | 	j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + std::string{prompt}; | ||||||
|  | 	// TODO: negative prompt | ||||||
|  |  | ||||||
|  | 	j_body["seed"] = -1; | ||||||
|  | 	j_body["sample_steps"] = _conf.get_int("SDBot", "steps").value_or(20); | ||||||
|  | 	j_body["guidance_params"]["cfg_scale"] = _conf.get_double("SDBot", "cfg_scale").value_or(6.5); | ||||||
|  | 	//j_body["sampler_index"] = std::string{_conf.get_string("SDBot", "sampler_index").value_or("Euler a")}; | ||||||
|  |  | ||||||
|  | 	std::string body = j_body.dump(); | ||||||
|  |  | ||||||
|  | 	const std::string url {_conf.get_string("SDBot", "url_txt2img").value_or("/txt2img")}; | ||||||
|  |  | ||||||
|  | 	try { | ||||||
|  | 		// not restful -> returns imediatly with task id | ||||||
|  | 		auto res = cl->Post(url, body, "application/json"); | ||||||
|  | 		if (!static_cast<bool>(res)) { | ||||||
|  | 			std::cerr << "SDB error: post to sd server failed!\n"; | ||||||
|  | 			return nullptr; | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		std::cerr << "SDB http complete " << res->status << " " << res->reason << "\n"; | ||||||
|  |  | ||||||
|  | 		if ( | ||||||
|  | 			res.error() != httplib::Error::Success || | ||||||
|  | 			res->status != 200 | ||||||
|  | 		) { | ||||||
|  | 			return nullptr; | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | 		try { | ||||||
|  | 			// {"task_id":"1753652554405483588"} | ||||||
|  | 			auto j_res = nlohmann::json::parse(res->body); | ||||||
|  |  | ||||||
|  | 			if (!j_res.contains("task_id")) { | ||||||
|  | 				std::cerr << "SDB error: response not a task_id\n"; | ||||||
|  | 				return nullptr; | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			uint64_t task_id{}; | ||||||
|  | 			const auto& j_task_id = j_res.at("task_id"); | ||||||
|  | 			if (!j_task_id.is_number_unsigned()) { | ||||||
|  | 				// meh, conversion time | ||||||
|  | 				task_id = std::stoull(j_task_id.get<std::string>().c_str()); | ||||||
|  | 				//std::cout << "converted " << j_task_id << " to " << task_id << "\n"; | ||||||
|  | 			} else { | ||||||
|  | 				task_id = j_task_id.get<uint64_t>(); | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			std::cout << "SDB: sdcpp task id: " << task_id << "\n"; | ||||||
|  |  | ||||||
|  | 			return std::make_shared<WebAPITask_sdcpp_stduhpf_wip2>( | ||||||
|  | 				cl, | ||||||
|  | 				"/result", // TODO: from conf? | ||||||
|  | 				task_id | ||||||
|  | 			); | ||||||
|  | 		} catch (...) { | ||||||
|  | 			std::cerr << "SDB error: failed parsing response as json\n"; | ||||||
|  | 			return nullptr; | ||||||
|  | 		} | ||||||
|  | 	} catch (...) { | ||||||
|  | 		std::cerr << "SDB http request error\n"; | ||||||
|  | 		return nullptr; | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ??? | ||||||
|  | 	return nullptr; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | WebAPITask_sdcpp_stduhpf_wip2::WebAPITask_sdcpp_stduhpf_wip2(std::shared_ptr<httplib::Client> cl, const std::string& url, uint64_t task_id) | ||||||
|  | 	: _cl(cl), _url(url), _task_id(task_id) | ||||||
|  | { | ||||||
|  | } | ||||||
|  |  | ||||||
|  | //WebAPITask_sdcpp_stduhpf_wip2::~WebAPITask_sdcpp_stduhpf_wip2(void) { | ||||||
|  | //} | ||||||
|  |  | ||||||
|  | bool WebAPITask_sdcpp_stduhpf_wip2::ready(void) { | ||||||
|  | 	// polling api | ||||||
|  |  | ||||||
|  | 	try { | ||||||
|  | 		auto res = _cl->Get(_url + "?task_id=" + std::to_string(_task_id)); | ||||||
|  | 		if (!static_cast<bool>(res)) { | ||||||
|  | 			std::cerr << "SDB error: post to sd server failed! (in task)\n"; | ||||||
|  | 			_done = true; | ||||||
|  | 			return true; | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if ( | ||||||
|  | 			res.error() != httplib::Error::Success || | ||||||
|  | 			res->status != 200 | ||||||
|  | 		) { | ||||||
|  | 			std::cerr << "SDB error: post to sd server failed! (in task)\n"; | ||||||
|  | 			_done = true; | ||||||
|  | 			return true; | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		auto j_res = nlohmann::json::parse(res->body); | ||||||
|  | 		if (j_res.at("status") != "Done") { | ||||||
|  | 			// not ready (likely path) | ||||||
|  | 			return false; | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// TODO: add support for multiple images | ||||||
|  | 		auto& j_image = j_res.at("data").at(0); // ?? | ||||||
|  |  | ||||||
|  | 		_result.width = j_image.at("width"); | ||||||
|  | 		_result.height = j_image.at("height"); | ||||||
|  |  | ||||||
|  | 		{ // data | ||||||
|  | 			auto& j_data = j_image.at("data"); | ||||||
|  |  | ||||||
|  | 			_result.data.resize(j_data.get<std::string>().size()); // HACK: do a better estimate | ||||||
|  | 			size_t decoded_size {0}; | ||||||
|  | 			sodium_base642bin( | ||||||
|  | 				_result.data.data(), _result.data.size(), | ||||||
|  | 				j_data.get<std::string>().data(), j_data.get<std::string>().size(), | ||||||
|  | 				" \n\t", | ||||||
|  | 				&decoded_size, | ||||||
|  | 				nullptr, | ||||||
|  | 				sodium_base64_VARIANT_ORIGINAL | ||||||
|  | 			); | ||||||
|  | 			_result.data.resize(decoded_size); | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		{ // file name? | ||||||
|  | 			// == "png" or jpeg or somehting | ||||||
|  | 			_result.file_name = std::string{"output_0."} + j_image.value("encoding", "png"); | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		_done = true; | ||||||
|  | 		return true; | ||||||
|  | 	} catch (...) { | ||||||
|  | 		// rip | ||||||
|  | 		_done = true; | ||||||
|  | 		return true; | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return false; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | std::optional<WebAPITaskI::Result> WebAPITask_sdcpp_stduhpf_wip2::get(void) { | ||||||
|  | 	if (!_done) { | ||||||
|  | 		assert(false && "what ya doin?"); | ||||||
|  | 		return std::nullopt; | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if (_result.data.empty()) { | ||||||
|  | 		return std::nullopt; | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return std::move(_result); | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										52
									
								
								src/webapi_sdcpp_stduhpf_wip2.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								src/webapi_sdcpp_stduhpf_wip2.hpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "./webapi_interface.hpp" | ||||||
|  |  | ||||||
|  | #include <cstdint> | ||||||
|  | #include <solanaceae/util/config_model.hpp> | ||||||
|  |  | ||||||
|  | #include <memory> | ||||||
|  |  | ||||||
|  | // fwd | ||||||
|  | namespace httplib { | ||||||
|  | 	class Client; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // in its second api iteration, stduhpf switched away from a rest api | ||||||
|  | struct WebAPI_sdcpp_stduhpf_wip2 : public WebAPII { | ||||||
|  | 	// TODO: const config | ||||||
|  | 	ConfigModelI& _conf; | ||||||
|  |  | ||||||
|  | 	std::shared_ptr<httplib::Client> _cl; | ||||||
|  | 	std::shared_ptr<httplib::Client> getCl(void); | ||||||
|  |  | ||||||
|  | 	WebAPI_sdcpp_stduhpf_wip2(ConfigModelI& conf); | ||||||
|  |  | ||||||
|  | 	~WebAPI_sdcpp_stduhpf_wip2(void) override; | ||||||
|  |  | ||||||
|  | 	std::shared_ptr<WebAPITaskI> txt2img( | ||||||
|  | 		std::string_view prompt, | ||||||
|  | 		int16_t width, | ||||||
|  | 		int16_t height | ||||||
|  | 		// more | ||||||
|  | 	) override; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // discard after get() !! | ||||||
|  | struct WebAPITask_sdcpp_stduhpf_wip2 : public WebAPITaskI { | ||||||
|  | 	std::shared_ptr<httplib::Client> _cl; | ||||||
|  | 	const std::string _url; | ||||||
|  |  | ||||||
|  | 	uint64_t _task_id {}; | ||||||
|  | 	bool _done {false}; | ||||||
|  | 	Result _result; | ||||||
|  |  | ||||||
|  | 	WebAPITask_sdcpp_stduhpf_wip2(std::shared_ptr<httplib::Client> cl, const std::string& url, uint64_t task_id); | ||||||
|  |  | ||||||
|  | 	~WebAPITask_sdcpp_stduhpf_wip2(void) override {} | ||||||
|  |  | ||||||
|  | 	bool ready(void) override; | ||||||
|  | 	std::optional<Result> get(void) override; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  |  | ||||||
		Reference in New Issue
	
	Block a user