make params configurable
This commit is contained in:
parent
070c573f3e
commit
08256285f6
@ -1,5 +1,7 @@
|
|||||||
#include "./sd_bot.hpp"
|
#include "./sd_bot.hpp"
|
||||||
|
|
||||||
|
#include <solanaceae/util/config_model.hpp>
|
||||||
|
|
||||||
#include <solanaceae/contact/components.hpp>
|
#include <solanaceae/contact/components.hpp>
|
||||||
#include <solanaceae/message3/components.hpp>
|
#include <solanaceae/message3/components.hpp>
|
||||||
|
|
||||||
@ -17,6 +19,22 @@ SDBot::SDBot(
|
|||||||
RegistryMessageModel& rmm,
|
RegistryMessageModel& rmm,
|
||||||
ConfigModelI& conf
|
ConfigModelI& conf
|
||||||
) : _cr(cr), _rmm(rmm), _conf(conf) {
|
) : _cr(cr), _rmm(rmm), _conf(conf) {
|
||||||
|
_rng.seed(std::random_device{}());
|
||||||
|
_rng.discard(3137);
|
||||||
|
|
||||||
|
if (!_conf.has_int("SDBot", "width")) {
|
||||||
|
_conf.set("SDBot", "width", int64_t(512));
|
||||||
|
}
|
||||||
|
if (!_conf.has_int("SDBot", "height")) {
|
||||||
|
_conf.set("SDBot", "height", int64_t(512));
|
||||||
|
}
|
||||||
|
if (!_conf.has_int("SDBot", "steps")) {
|
||||||
|
_conf.set("SDBot", "steps", int64_t(20));
|
||||||
|
}
|
||||||
|
if (!_conf.has_double("SDBot", "cfg_scale")) {
|
||||||
|
_conf.set("SDBot", "cfg_scale", 6.5);
|
||||||
|
}
|
||||||
|
|
||||||
_rmm.subscribe(this, RegistryMessageModel_Event::message_construct);
|
_rmm.subscribe(this, RegistryMessageModel_Event::message_construct);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,25 +66,28 @@ void SDBot::iterate(void) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
nlohmann::json j_body;
|
nlohmann::json j_body;
|
||||||
// TODO: read from config
|
j_body["width"] = _conf.get_int("SDBot", "width").value_or(512);
|
||||||
#if 1
|
j_body["height"] = _conf.get_int("SDBot", "height").value_or(512);
|
||||||
j_body["width"] = 512;
|
|
||||||
j_body["height"] = 512;
|
|
||||||
#elif 0
|
|
||||||
j_body["width"] = 768;
|
|
||||||
j_body["height"] = 768;
|
|
||||||
#else
|
|
||||||
j_body["width"] = 128;
|
|
||||||
j_body["height"] = 128;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
j_body["prompt"] = prompt;
|
//j_body["prompt"] = prompt;
|
||||||
|
//"<lora:lcm-lora-sdv1-5:1>"
|
||||||
|
j_body["prompt"] = std::string{_conf.get_string("SDBot", "prompt_prefix").value_or("")} + prompt;
|
||||||
|
// TODO: negative prompt
|
||||||
|
|
||||||
j_body["seed"] = -1;
|
j_body["seed"] = -1;
|
||||||
|
#if 0
|
||||||
j_body["steps"] = 20;
|
j_body["steps"] = 20;
|
||||||
//j_body["steps"] = 5;
|
//j_body["steps"] = 5;
|
||||||
j_body["cfg_scale"] = 6.5;
|
j_body["cfg_scale"] = 6.5;
|
||||||
j_body["sampler_index"] = "Euler a";
|
j_body["sampler_index"] = "Euler a";
|
||||||
|
#else
|
||||||
|
//j_body["steps"] = 4;
|
||||||
|
j_body["steps"] = _conf.get_int("SDBot", "steps").value_or(20);
|
||||||
|
//j_body["cfg_scale"] = 1;
|
||||||
|
j_body["cfg_scale"] = _conf.get_double("SDBot", "cfg_scale").value_or(6.5);
|
||||||
|
//j_body["sampler_index"] = "LCM Test";
|
||||||
|
j_body["sampler_index"] = std::string{_conf.get_string("SDBot", "sampler_index").value_or("Euler a")};
|
||||||
|
#endif
|
||||||
|
|
||||||
j_body["batch_size"] = 1;
|
j_body["batch_size"] = 1;
|
||||||
j_body["n_iter"] = 1;
|
j_body["n_iter"] = 1;
|
||||||
@ -206,7 +227,8 @@ void SDBot::onHttpComplete(const happyhttp::Response* r) {
|
|||||||
const auto& contact = _task_map.at(_current_task.value());
|
const auto& contact = _task_map.at(_current_task.value());
|
||||||
|
|
||||||
std::filesystem::create_directories("sdbot_img_send");
|
std::filesystem::create_directories("sdbot_img_send");
|
||||||
const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
|
//const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_current_task.value()) + ".png";
|
||||||
|
const std::string tmp_img_file_name = "sdbot_img_" + std::to_string(_rng()) + ".png";
|
||||||
const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
|
const std::string tmp_img_file_path = "sdbot_img_send/" + tmp_img_file_name;
|
||||||
|
|
||||||
std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
|
std::ofstream(tmp_img_file_path).write(reinterpret_cast<const char*>(png_data.data()), png_data.size());
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <random>
|
||||||
|
|
||||||
// fwd
|
// fwd
|
||||||
struct ConfigModelI;
|
struct ConfigModelI;
|
||||||
@ -30,6 +31,8 @@ class SDBot : public RegistryMessageModelEventI {
|
|||||||
std::unique_ptr<happyhttp::Connection> _con;
|
std::unique_ptr<happyhttp::Connection> _con;
|
||||||
std::vector<uint8_t> _con_data;
|
std::vector<uint8_t> _con_data;
|
||||||
|
|
||||||
|
std::default_random_engine _rng;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
SDBot(
|
SDBot(
|
||||||
Contact3Registry& cr,
|
Contact3Registry& cr,
|
||||||
|
Loading…
Reference in New Issue
Block a user