diff --git a/src/sd_bot.cpp b/src/sd_bot.cpp index d4b4bf1..1155c37 100644 --- a/src/sd_bot.cpp +++ b/src/sd_bot.cpp @@ -1,5 +1,7 @@ #include "./sd_bot.hpp" +#include + #include #include @@ -17,6 +19,22 @@ SDBot::SDBot( RegistryMessageModel& rmm, ConfigModelI& conf ) : _cr(cr), _rmm(rmm), _conf(conf) { + _rng.seed(std::random_device{}()); + _rng.discard(3137); + + if (!_conf.has_int("SDBot", "width")) { + _conf.set("SDBot", "width", int64_t(512)); + } + if (!_conf.has_int("SDBot", "height")) { + _conf.set("SDBot", "height", int64_t(512)); + } + if (!_conf.has_int("SDBot", "steps")) { + _conf.set("SDBot", "steps", int64_t(20)); + } + if (!_conf.has_double("SDBot", "cfg_scale")) { + _conf.set("SDBot", "cfg_scale", 6.5); + } + _rmm.subscribe(this, RegistryMessageModel_Event::message_construct); } @@ -48,25 +66,28 @@ void SDBot::iterate(void) { }; nlohmann::json j_body; - // TODO: read from config -#if 1 - j_body["width"] = 512; - j_body["height"] = 512; -#elif 0 - j_body["width"] = 768; - j_body["height"] = 768; -#else - j_body["width"] = 128; - j_body["height"] = 128; -#endif + j_body["width"] = _conf.get_int("SDBot", "width").value_or(512); + j_body["height"] = _conf.get_int("SDBot", "height").value_or(512); - j_body["prompt"] = prompt; + //j_body["prompt"] = prompt; + //"" + j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + prompt; + // TODO: negative prompt j_body["seed"] = -1; +#if 0 j_body["steps"] = 20; //j_body["steps"] = 5; j_body["cfg_scale"] = 6.5; j_body["sampler_index"] = "Euler a"; +#else + //j_body["steps"] = 4; + j_body["steps"] = _conf.get_int("SDBot", "steps").value_or(20); + //j_body["cfg_scale"] = 1; + j_body["cfg_scale"] = _conf.get_double("SDBot", "cfg_scale").value_or(6.5); + //j_body["sampler_index"] = "LCM Test"; + j_body["sampler_index"] = std::string{_conf.get_string("SDBot", "sampler_index").value_or("Euler a")}; +#endif j_body["batch_size"] = 1; j_body["n_iter"] = 1; @@ -206,7 +227,8 @@ void SDBot::onHttpComplete(const happyhttp::Response* r) { const auto& contact = _task_map.at(_current_task.value()); std::filesystem::create_directories("sdbot_img_send"); - const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; + //const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png"; + const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png"; const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name; std::ofstream(tmp_img_file_path).write(reinterpret_cast(png_data.data()), png_data.size()); diff --git a/src/sd_bot.hpp b/src/sd_bot.hpp index 1f2d1f5..7f01ea2 100644 --- a/src/sd_bot.hpp +++ b/src/sd_bot.hpp @@ -10,6 +10,7 @@ #include #include #include +#include // fwd struct ConfigModelI; @@ -30,6 +31,8 @@ class SDBot : public RegistryMessageModelEventI { std::unique_ptr _con; std::vector _con_data; + std::default_random_engine _rng; + public: SDBot( Contact3Registry& cr,