make params configurable

This commit is contained in:
Green Sky 2023-12-06 20:29:18 +01:00
parent 070c573f3e
commit 08256285f6
No known key found for this signature in database
2 changed files with 38 additions and 13 deletions

View File

@ -1,5 +1,7 @@
#include "./sd_bot.hpp"
#include <solanaceae/util/config_model.hpp>
#include <solanaceae/contact/components.hpp>
#include <solanaceae/message3/components.hpp>
@ -17,6 +19,22 @@ SDBot::SDBot(
RegistryMessageModel& rmm,
ConfigModelI& conf
) : _cr(cr), _rmm(rmm), _conf(conf) {
_rng.seed(std::random_device{}());
_rng.discard(3137);
if (!_conf.has_int("SDBot", "width")) {
_conf.set("SDBot", "width", int64_t(512));
}
if (!_conf.has_int("SDBot", "height")) {
_conf.set("SDBot", "height", int64_t(512));
}
if (!_conf.has_int("SDBot", "steps")) {
_conf.set("SDBot", "steps", int64_t(20));
}
if (!_conf.has_double("SDBot", "cfg_scale")) {
_conf.set("SDBot", "cfg_scale", 6.5);
}
_rmm.subscribe(this, RegistryMessageModel_Event::message_construct);
}
@ -48,25 +66,28 @@ void SDBot::iterate(void) {
};
nlohmann::json j_body;
// TODO: read from config
#if 1
j_body["width"] = 512;
j_body["height"] = 512;
#elif 0
j_body["width"] = 768;
j_body["height"] = 768;
#else
j_body["width"] = 128;
j_body["height"] = 128;
#endif
j_body["width"] = _conf.get_int("SDBot", "width").value_or(512);
j_body["height"] = _conf.get_int("SDBot", "height").value_or(512);
j_body["prompt"] = prompt;
//j_body["prompt"] = prompt;
//"<lora:lcm-lora-sdv1-5:1>"
j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + prompt;
// TODO: negative prompt
j_body["seed"] = -1;
#if 0
j_body["steps"] = 20;
//j_body["steps"] = 5;
j_body["cfg_scale"] = 6.5;
j_body["sampler_index"] = "Euler a";
#else
//j_body["steps"] = 4;
j_body["steps"] = _conf.get_int("SDBot", "steps").value_or(20);
//j_body["cfg_scale"] = 1;
j_body["cfg_scale"] = _conf.get_double("SDBot", "cfg_scale").value_or(6.5);
//j_body["sampler_index"] = "LCM Test";
j_body["sampler_index"] = std::string{_conf.get_string("SDBot", "sampler_index").value_or("Euler a")};
#endif
j_body["batch_size"] = 1;
j_body["n_iter"] = 1;
@ -206,7 +227,8 @@ void SDBot::onHttpComplete(const happyhttp::Response* r) {
const auto& contact = _task_map.at(_current_task.value());
std::filesystem::create_directories("sdbot_img_send");
const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());

View File

@ -10,6 +10,7 @@
#include <string>
#include <memory>
#include <optional>
#include <random>
// fwd
struct ConfigModelI;
@ -30,6 +31,8 @@ class SDBot : public RegistryMessageModelEventI {
std::unique_ptr<happyhttp::Connection> _con;
std::vector<uint8_t> _con_data;
std::default_random_engine _rng;
public:
SDBot(
Contact3Registry& cr,