From 83264db09de25c2311a5a5600c7f538a773b19ac Mon Sep 17 00:00:00 2001 From: Green Sky Date: Wed, 24 Jan 2024 22:55:18 +0100 Subject: [PATCH] rpbot should* work? --- plugins/plugin_rpbot.cpp | 6 +- src/CMakeLists.txt | 2 + .../llama-cpp-web/llama_cpp_web_impl.cpp | 2 + .../llama-cpp-web/llama_cpp_web_impl.hpp | 5 +- .../rpbot/message_prompt_builder.hpp | 112 ++++++ src/solanaceae/rpbot/rpbot.cpp | 330 +++++++++++++++++- src/solanaceae/rpbot/rpbot.hpp | 47 ++- 7 files changed, 498 insertions(+), 6 deletions(-) create mode 100644 src/solanaceae/rpbot/message_prompt_builder.hpp diff --git a/plugins/plugin_rpbot.cpp b/plugins/plugin_rpbot.cpp index fd7b2f6..3723ead 100644 --- a/plugins/plugin_rpbot.cpp +++ b/plugins/plugin_rpbot.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -32,10 +33,13 @@ SOLANA_PLUGIN_EXPORT uint32_t solana_plugin_start(struct SolanaAPI* solana_api) try { auto* completion = PLUG_RESOLVE_INSTANCE(TextCompletionI); auto* conf = PLUG_RESOLVE_INSTANCE(ConfigModelI); + auto* cr = PLUG_RESOLVE_INSTANCE_VERSIONED(Contact3Registry, "1"); + auto* rmm = PLUG_RESOLVE_INSTANCE(RegistryMessageModel); + auto* mcd = PLUG_RESOLVE_INSTANCE(MessageCommandDispatcher); // static store, could be anywhere tho // construct with fetched dependencies - g_rpbot = std::make_unique(*completion, *conf); + g_rpbot = std::make_unique(*completion, *conf, *cr, *rmm, mcd); // register types PLUG_PROVIDE_INSTANCE(RPBot, plugin_name, g_rpbot.get()); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5692458..a4bda44 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,6 +31,8 @@ target_link_libraries(test1 PUBLIC ######################################## add_library(solanaceae_rpbot + ./solanaceae/rpbot/message_prompt_builder.hpp + ./solanaceae/rpbot/rpbot.hpp ./solanaceae/rpbot/rpbot.cpp ) diff --git a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp index 64da408..b54c188 100644 --- a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp +++ b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp @@ -6,6 +6,8 @@ #include +std::minstd_rand thread_local LlamaCppWeb::_rng{std::random_device{}()}; + // TODO: variant that strips unicode? static std::string convertToSafeGrammarString(std::string_view input) { std::string res; diff --git a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp index 9f23ea7..7619130 100644 --- a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp +++ b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp @@ -8,8 +8,11 @@ #include struct LlamaCppWeb : public TextCompletionI { + // this mutex locks internally httplib::Client _cli{"http://localhost:8080"}; - std::minstd_rand _rng{std::random_device{}()}; + + // this is a bad idea + static std::minstd_rand thread_local _rng; ~LlamaCppWeb(void); diff --git a/src/solanaceae/rpbot/message_prompt_builder.hpp b/src/solanaceae/rpbot/message_prompt_builder.hpp new file mode 100644 index 0000000..290d4cf --- /dev/null +++ b/src/solanaceae/rpbot/message_prompt_builder.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include "./rpbot.hpp" + +#include +#include + +#include +#include + +// TODO: improve caching +struct MessagePromptBuilder { + Contact3Registry& _cr; + const Contact3 _c; + RegistryMessageModel& _rmm; + + // lookup table, string_view since no name-components are changed + entt::dense_map names; + + + bool buildNameLookup(void) { + if (_cr.all_of(_c)) { // group rpbot + const auto& subs = _cr.get(_c).subs; + // should include self + for (const auto sub_c : subs) { + if (_cr.all_of(sub_c)) { + names[sub_c] = _cr.get(sub_c).name; + } + } + } else { // pm rpbot + if (_cr.all_of(_c)) { + names[_c] = _cr.get(_c).name; + } else { + std::cerr << "RPBot error: other missing name\n"; + return false; + } + + if (_cr.all_of(_c)) { + const auto self = _cr.get(_c).self; + if (_cr.all_of(self)) { + names[self] = _cr.get(self).name; + } else { + std::cerr << "RPBot error: self missing name\n"; + return false; + } + } else { + std::cerr << "RPBot error: contact missing self\n"; + return false; + } + } + + return true; + } + + std::string buildPromptMessageHistory(void) { + auto* mr = _rmm.get(_c); + assert(mr); + + std::string prompt; + + auto view = mr->view(); + for (auto view_it = view.rbegin(), view_last = view.rend(); view_it != view_last; view_it++) { + const Message3 e = *view_it; + + // manually filter ("reverse" iteration <.<) + // TODO: add mesagetext? + if (!mr->all_of(e)) { + continue; + } + + //Message::Components::ContactFrom& c_from = mr->get(e); + //Message::Components::ContactTo& c_to = msg_reg.get(e); + //Message::Components::Timestamp ts = view.get(e); + + prompt += "\n"; + prompt += buildPromptMessage({*mr, e}); + } + + return prompt; + } + + // gets split across lines + std::string buildPromptMessage(const Message3Handle m) { + if (!m.all_of()) { + // TODO: case for transfers + return ""; + } + + // TODO: cache as comp + const auto line_prefix = buildPromptMessagePrefix(m); + + // TODO: trim + std::string message_lines = line_prefix + m.get().text; + for (auto nlpos = message_lines.find('\n'); nlpos != std::string::npos; nlpos = message_lines.find('\n', nlpos+1)) { + message_lines.insert(nlpos+1, line_prefix); + nlpos += line_prefix.size(); + } + + return message_lines; + } + + // generate prompt prefix (timestamp + name), configurable? + std::string buildPromptMessagePrefix(const Message3Handle m) { + const Contact3 from = m.get().c; + if (names.count(from)) { + return std::string{names[from]} + ": "; + } else { + return ": "; + } + } +}; + diff --git a/src/solanaceae/rpbot/rpbot.cpp b/src/solanaceae/rpbot/rpbot.cpp index 2c4add7..87a36c6 100644 --- a/src/solanaceae/rpbot/rpbot.cpp +++ b/src/solanaceae/rpbot/rpbot.cpp @@ -1,12 +1,336 @@ #include "./rpbot.hpp" +#include "./message_prompt_builder.hpp" +#include "solanaceae/contact/contact_model3.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// sleeps until onMsg or onTimer +struct StateIdle { + static constexpr const char* name {"StateIdle"}; + float timeout {0.f}; +}; + +// determines if self should generate a message +struct StateNext { + static constexpr const char* name {"StateNext"}; + + std::string prompt; + std::vector possible_names; + std::vector possible_contacts; + + std::future future; +}; + +// generate message +struct StateGenerateMsg { + static constexpr const char* name {"StateGenerateMsg"}; + + std::string prompt; + + // returns new line (single message) + std::future future; +}; + +// look if it took too long/too many new messages came in +// while also optionally sleeping to make message appear not too fast +// HACK: skip, just send for now +struct StateTimingCheck { + static constexpr const char* name {"StateTimingCheck"}; + int tmp; +}; + +template<> +void RPBot::stateTransition(const Contact3 c, const StateIdle& from, StateNext& to) { + // collect promp + + { // - system promp + to.prompt = system_prompt; + } + + MessagePromptBuilder mpb{_cr, c, _rmm, {}}; + { // - message history + mpb.buildNameLookup(); + to.prompt += mpb.buildPromptMessageHistory(); + } + + { // - next needs the beginning of the new message + // empty rn + to.prompt += "\n"; + } + + std::cout << "prompt for next: '" << to.prompt << "'\n"; + + { // get set of possible usernames + // copy mpb.names (contains string views, needs copies) + for (const auto& [name_c, name] : mpb.names) { + if (_cr.all_of(name_c)) { + if (_cr.get(name_c).state != Contact::Components::ConnectionState::disconnected) { + // online + to.possible_names.push_back(std::string{name}); + to.possible_contacts.push_back(name_c); + } + } else if (_cr.all_of(name_c)) { + to.possible_names.push_back(std::string{name}); + to.possible_contacts.push_back(name_c); + } + } + } + + { // launch async + // copy names for string view param (lol) + std::vector pnames; + for (const auto& n : to.possible_names) { + pnames.push_back(n); + } + + to.future = std::async(std::launch::async, [pnames, &to, this]() -> int64_t { + return _completion.completeSelect(to.prompt, pnames); + }); + } +} + +template<> +void RPBot::stateTransition(const Contact3, const StateNext&, StateIdle& to) { + to.timeout = std::uniform_real_distribution<>{15.f, 45.f}(_rng); +} + +template<> +void RPBot::stateTransition(const Contact3 c, const StateNext& from, StateGenerateMsg& to) { + to.prompt = from.prompt; // TODO: move from? + assert(_cr.all_of(c)); + const Contact3 self = _cr.get(c).self; + + to.prompt += _cr.get(self).name + ": "; // TODO: remove space + + { // launch async + to.future = std::async(std::launch::async, [&to, this]() -> std::string { + return _completion.completeLine(to.prompt); + }); + } +} + +template<> +void RPBot::stateTransition(const Contact3, const StateGenerateMsg&, StateIdle& to) { + to.timeout = std::uniform_real_distribution<>{10.f, 30.f}(_rng); +} + RPBot::RPBot( TextCompletionI& completion, - ConfigModelI& conf -) : _completion(completion), _conf(conf) { + ConfigModelI& conf, + Contact3Registry& cr, + RegistryMessageModel& rmm, + MessageCommandDispatcher* mcd +) : _completion(completion), _conf(conf), _cr(cr), _rmm(rmm), _mcd(mcd) { + //system_prompt = R"sys(Transcript of a group chat, where Bob talks to online strangers. +//)sys"; + system_prompt = "Transcript of a group chat, where "; + if (_conf.has_string("tox", "name")) { + system_prompt += _conf.get_string("tox", "name").value(); + } else { + system_prompt += std::string{"Bob"}; + } + system_prompt += std::string{" talks to online strangers.\n"}; + + registerCommands(); } float RPBot::tick(float time_delta) { - return 10.f; + float min_tick_interval = std::numeric_limits::max(); + + min_tick_interval = std::min(min_tick_interval, doAllIdle(time_delta)); + min_tick_interval = std::min(min_tick_interval, doAllNext(time_delta)); + min_tick_interval = std::min(min_tick_interval, doAllGenerateMsg(time_delta)); + min_tick_interval = std::min(min_tick_interval, doAllTimingCheck(time_delta)); + + return min_tick_interval; +} + +void RPBot::registerCommands(void) { + if (_mcd == nullptr) { + return; + } + + _mcd->registerCommand( + "RPBot", "rpbot", + "start", + [this](std::string_view params, Message3Handle m) -> bool { + const auto contact_from = m.get().c; + const auto contact_to = m.get().c; + + if (params.empty()) { + // contact_to should be the contact this is for + if (_cr.any_of(contact_to)) { + _rmm.sendText( + contact_from, + "error: already running" + ); + return true; + } + if (_cr.any_of(contact_from)) { + _rmm.sendText( + contact_from, + "error: already running" + ); + return true; + } + + if (_cr.all_of(contact_to)) { + // group + auto& new_state = _cr.emplace(contact_to); + new_state.timeout = 10.f; + } else { + // pm + auto& new_state = _cr.emplace(contact_from); + new_state.timeout = 10.f; + } + + _rmm.sendText( + contact_from, + "RPBot started" + ); + return true; + } else { + // id in params + if (params.size() % 2 != 0) { + _rmm.sendText( + contact_from, + "malformed hex id" + ); + return true; + } + + auto id_bin = hex2bin(params); + + auto view = _cr.view(); + for (auto it = view.begin(), it_end = view.end(); it != it_end; it++) { + if (view.get(*it).data == id_bin) { + auto& new_state = _cr.emplace(*it); + new_state.timeout = 10.f; + + _rmm.sendText( + contact_from, + "RPBot started" + ); + return true; + } + } + + _rmm.sendText( + contact_from, + "no contact found for id" + ); + return true; + } + }, + "Start RPBot in current contact.", + MessageCommandDispatcher::Perms::ADMIN // TODO: should proably be MODERATOR + ); + + std::cout << "RPBot: registered commands\n"; +} + +float RPBot::doAllIdle(float time_delta) { + float min_tick_interval = std::numeric_limits::max(); + std::vector to_remove_stateidle; + auto view = _cr.view(); + + view.each([this, time_delta, &to_remove_stateidle, &min_tick_interval](const Contact3 c, StateIdle& state) { + state.timeout -= time_delta; + if (state.timeout <= 0.f) { + std::cout << "RPBot: idle timed out\n"; + // TODO: use multiprompt and better system promp to start immediatly + if (auto* mreg = _rmm.get(c); mreg != nullptr && mreg->view().size() >= 4) { + to_remove_stateidle.push_back(c); + min_tick_interval = 0.1f; + + // transition to Next + emplaceStateTransition(_cr, c, state); + } else { + // check-in in another 15-45s + state.timeout = std::uniform_real_distribution<>{15.f, 45.f}(_rng); + std::cout << "RPBot: not ready yet, back to ideling\n"; + if (mreg == nullptr) { + std::cout << "mreg is null\n"; + } else { + std::cout << "size(): " << mreg->view().size() << "\n"; + } + } + } + }); + + _cr.remove(to_remove_stateidle.cbegin(), to_remove_stateidle.cend()); + return min_tick_interval; +} + +float RPBot::doAllNext(float) { + float min_tick_interval = std::numeric_limits::max(); + std::vector to_remove; + auto view = _cr.view(); + + view.each([this, &to_remove, &min_tick_interval](const Contact3 c, StateNext& state) { + // TODO: how to timeout? + if (state.future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { + to_remove.push_back(c); + min_tick_interval = 0.1f; + std::cout << "RPBot: next compute done!\n"; + + const auto selected = state.future.get(); + if (selected >= 0 && size_t(selected) < state.possible_names.size()) { + std::cout << "next is " << state.possible_names.at(selected) << "(" << selected << ")\n"; + if (_cr.all_of(state.possible_contacts.at(selected))) { + // we predicted ourselfs + emplaceStateTransition(_cr, c, state); + return; + } + } else { + std::cerr << "RPBot error: next was negative or too large (how?) " << selected << "\n"; + } + + // transition to Idle + emplaceStateTransition(_cr, c, state); + } + }); + + _cr.remove(to_remove.cbegin(), to_remove.cend()); + return min_tick_interval; +} + +float RPBot::doAllGenerateMsg(float) { + float min_tick_interval = std::numeric_limits::max(); + std::vector to_remove; + auto view = _cr.view(); + + _cr.remove(to_remove.cbegin(), to_remove.cend()); + view.each([this, &to_remove, &min_tick_interval](const Contact3 c, StateGenerateMsg& state) { + // TODO: how to timeout? + if (state.future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { + to_remove.push_back(c); + min_tick_interval = 0.1f; + std::cout << "RPBot: generatemessage compute done!\n"; + + std::string msg = state.future.get(); + _rmm.sendText(c, msg); + + // TODO: timing check? + // transition to Idle + emplaceStateTransition(_cr, c, state); + } + }); + return min_tick_interval; +} + +float RPBot::doAllTimingCheck(float time_delta) { + float min_tick_interval = std::numeric_limits::max(); + return min_tick_interval; } diff --git a/src/solanaceae/rpbot/rpbot.hpp b/src/solanaceae/rpbot/rpbot.hpp index ef469b7..594edd4 100644 --- a/src/solanaceae/rpbot/rpbot.hpp +++ b/src/solanaceae/rpbot/rpbot.hpp @@ -2,17 +2,62 @@ #include #include +#include +#include +#include + +#include +#include +#include + +// fwd +struct StateIdle; +struct StateNext; +struct StateGenerateMsg; +struct StateTimingCheck; struct RPBot { TextCompletionI& _completion; ConfigModelI& _conf; + Contact3Registry& _cr; + RegistryMessageModel& _rmm; + MessageCommandDispatcher* _mcd; + + std::minstd_rand _rng{std::random_device{}()}; + + std::string system_prompt; public: RPBot( TextCompletionI& completion, - ConfigModelI& conf + ConfigModelI& conf, + Contact3Registry& cr, + RegistryMessageModel& rmm, + MessageCommandDispatcher* mcd ); float tick(float time_delta); + + void registerCommands(void); + + protected: // state transitions + // all transitions need to be explicitly declared + template + void stateTransition(const Contact3 c, const From& from, To& to) = delete; + + // reg helper + template + To& emplaceStateTransition(Contact3Registry& cr, Contact3 c, const From& state) { + std::cout << "RPBot: transition from " << From::name << " to " << To::name << "\n"; + To& to = cr.emplace_or_replace(c); + stateTransition(c, state, to); + return to; + } + + protected: // systems + float doAllIdle(float time_delta); + float doAllNext(float time_delta); + float doAllGenerateMsg(float time_delta); + float doAllTimingCheck(float time_delta); };