From 784fea96d6a39337594041baf3481893c3c779bc Mon Sep 17 00:00:00 2001 From: Green Sky Date: Thu, 25 Jan 2024 13:45:58 +0100 Subject: [PATCH] general improvements, enable server prompt caching --- src/CMakeLists.txt | 1 + .../llama-cpp-web/llama_cpp_web_impl.cpp | 2 + .../llama-cpp-web/llama_cpp_web_impl.hpp | 3 + .../rpbot/message_prompt_builder.cpp | 96 +++++++++++++++++ .../rpbot/message_prompt_builder.hpp | 100 ++---------------- src/solanaceae/rpbot/rpbot.cpp | 3 +- 6 files changed, 112 insertions(+), 93 deletions(-) create mode 100644 src/solanaceae/rpbot/message_prompt_builder.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a4bda44..534427b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -32,6 +32,7 @@ target_link_libraries(test1 PUBLIC add_library(solanaceae_rpbot ./solanaceae/rpbot/message_prompt_builder.hpp + ./solanaceae/rpbot/message_prompt_builder.cpp ./solanaceae/rpbot/rpbot.hpp ./solanaceae/rpbot/rpbot.cpp diff --git a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp index ee30aff..b23cd59 100644 --- a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp +++ b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.cpp @@ -83,6 +83,7 @@ int64_t LlamaCppWeb::completeSelect(const std::string_view prompt, const std::ve {"top_p", 1.0}, // disable {"n_predict", 256}, // unlikely to ever be so high {"seed", _rng()}, + {"cache_prompt", static_cast(_use_server_cache)}, }); if (ret.empty()) { @@ -119,6 +120,7 @@ std::string LlamaCppWeb::completeLine(const std::string_view prompt) { {"n_predict", 1000}, {"seed", _rng()}, {"stop", {"\n"}}, + {"cache_prompt", static_cast(_use_server_cache)}, }); if (ret.empty() || ret.count("content") == 0) { diff --git a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp index 7619130..485d644 100644 --- a/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp +++ 
b/src/solanaceae/llama-cpp-web/llama_cpp_web_impl.hpp @@ -6,6 +6,7 @@ #include #include +#include struct LlamaCppWeb : public TextCompletionI { // this mutex locks internally @@ -14,6 +15,8 @@ struct LlamaCppWeb : public TextCompletionI { // this is a bad idea static std::minstd_rand thread_local _rng; + std::atomic _use_server_cache {true}; + ~LlamaCppWeb(void); bool isGood(void) override; diff --git a/src/solanaceae/rpbot/message_prompt_builder.cpp b/src/solanaceae/rpbot/message_prompt_builder.cpp new file mode 100644 index 0000000..dd091b2 --- /dev/null +++ b/src/solanaceae/rpbot/message_prompt_builder.cpp @@ -0,0 +1,96 @@ +#include "./message_prompt_builder.hpp" + +#include "./rpbot.hpp" + +#include +#include + +bool MessagePromptBuilder::buildNameLookup(void) { + if (_cr.all_of(_c)) { // group rpbot + const auto& subs = _cr.get(_c).subs; + // should include self + for (const auto sub_c : subs) { + if (_cr.all_of(sub_c)) { + names[sub_c] = _cr.get(sub_c).name; + } + } + } else { // pm rpbot + if (_cr.all_of(_c)) { + names[_c] = _cr.get(_c).name; + } else { + std::cerr << "RPBot error: other missing name\n"; + return false; + } + + if (_cr.all_of(_c)) { + const auto self = _cr.get(_c).self; + if (_cr.all_of(self)) { + names[self] = _cr.get(self).name; + } else { + std::cerr << "RPBot error: self missing name\n"; + return false; + } + } else { + std::cerr << "RPBot error: contact missing self\n"; + return false; + } + } + + return true; +} + +std::string MessagePromptBuilder::buildPromptMessageHistory(void) { + auto* mr = _rmm.get(_c); + assert(mr); + + std::string prompt; + + auto view = mr->view(); + for (auto view_it = view.rbegin(), view_last = view.rend(); view_it != view_last; view_it++) { + const Message3 e = *view_it; + + // manually filter ("reverse" iteration <.<) + // TODO: add mesagetext? 
+ if (!mr->all_of(e)) { + continue; + } + + //Message::Components::ContactFrom& c_from = mr->get(e); + //Message::Components::ContactTo& c_to = msg_reg.get(e); + //Message::Components::Timestamp ts = view.get(e); + + prompt += "\n"; + prompt += buildPromptMessage({*mr, e}); + } + + return prompt; +} + +std::string MessagePromptBuilder::buildPromptMessage(const Message3Handle m) { + if (!m.all_of()) { + // TODO: case for transfers + return ""; + } + + // TODO: cache as comp + const auto line_prefix = promptMessagePrefixSimple(m); + + // TODO: trim + std::string message_lines = line_prefix + m.get().text; + for (auto nlpos = message_lines.find('\n'); nlpos != std::string::npos; nlpos = message_lines.find('\n', nlpos+1)) { + message_lines.insert(nlpos+1, line_prefix); + nlpos += line_prefix.size(); + } + + return message_lines; +} + +std::string MessagePromptBuilder::promptMessagePrefixSimple(const Message3Handle m) { + const Contact3 from = m.get().c; + if (names.count(from)) { + return std::string{names[from]} + ": "; + } else { + return ": "; + } +} + diff --git a/src/solanaceae/rpbot/message_prompt_builder.hpp b/src/solanaceae/rpbot/message_prompt_builder.hpp index 290d4cf..b671007 100644 --- a/src/solanaceae/rpbot/message_prompt_builder.hpp +++ b/src/solanaceae/rpbot/message_prompt_builder.hpp @@ -1,12 +1,10 @@ #pragma once -#include "./rpbot.hpp" - -#include -#include +#include +#include +#include #include -#include // TODO: improve caching struct MessagePromptBuilder { @@ -17,96 +15,14 @@ struct MessagePromptBuilder { // lookup table, string_view since no name-components are changed entt::dense_map names; + bool buildNameLookup(void); - bool buildNameLookup(void) { - if (_cr.all_of(_c)) { // group rpbot - const auto& subs = _cr.get(_c).subs; - // should include self - for (const auto sub_c : subs) { - if (_cr.all_of(sub_c)) { - names[sub_c] = _cr.get(sub_c).name; - } - } - } else { // pm rpbot - if (_cr.all_of(_c)) { - names[_c] = _cr.get(_c).name; - } else { - 
std::cerr << "RPBot error: other missing name\n"; - return false; - } - - if (_cr.all_of(_c)) { - const auto self = _cr.get(_c).self; - if (_cr.all_of(self)) { - names[self] = _cr.get(self).name; - } else { - std::cerr << "RPBot error: self missing name\n"; - return false; - } - } else { - std::cerr << "RPBot error: contact missing self\n"; - return false; - } - } - - return true; - } - - std::string buildPromptMessageHistory(void) { - auto* mr = _rmm.get(_c); - assert(mr); - - std::string prompt; - - auto view = mr->view(); - for (auto view_it = view.rbegin(), view_last = view.rend(); view_it != view_last; view_it++) { - const Message3 e = *view_it; - - // manually filter ("reverse" iteration <.<) - // TODO: add mesagetext? - if (!mr->all_of(e)) { - continue; - } - - //Message::Components::ContactFrom& c_from = mr->get(e); - //Message::Components::ContactTo& c_to = msg_reg.get(e); - //Message::Components::Timestamp ts = view.get(e); - - prompt += "\n"; - prompt += buildPromptMessage({*mr, e}); - } - - return prompt; - } + std::string buildPromptMessageHistory(void); // gets split across lines - std::string buildPromptMessage(const Message3Handle m) { - if (!m.all_of()) { - // TODO: case for transfers - return ""; - } + std::string buildPromptMessage(const Message3Handle m); - // TODO: cache as comp - const auto line_prefix = buildPromptMessagePrefix(m); - - // TODO: trim - std::string message_lines = line_prefix + m.get().text; - for (auto nlpos = message_lines.find('\n'); nlpos != std::string::npos; nlpos = message_lines.find('\n', nlpos+1)) { - message_lines.insert(nlpos+1, line_prefix); - nlpos += line_prefix.size(); - } - - return message_lines; - } - - // generate prompt prefix (timestamp + name), configurable? 
- std::string buildPromptMessagePrefix(const Message3Handle m) { - const Contact3 from = m.get().c; - if (names.count(from)) { - return std::string{names[from]} + ": "; - } else { - return ": "; - } - } + // generate prompt prefix (just "name:") + std::string promptMessagePrefixSimple(const Message3Handle m); }; diff --git a/src/solanaceae/rpbot/rpbot.cpp b/src/solanaceae/rpbot/rpbot.cpp index 1d3b9ff..ed9e47d 100644 --- a/src/solanaceae/rpbot/rpbot.cpp +++ b/src/solanaceae/rpbot/rpbot.cpp @@ -122,7 +122,8 @@ void RPBot::stateTransition(const Contact3 c, const StateNext& from, StateGenera template<> void RPBot::stateTransition(const Contact3, const StateGenerateMsg&, StateIdle& to) { - to.timeout = std::uniform_real_distribution<>{5.f, 20.f}(_rng); + // relatively slow delay for multi line messages + to.timeout = std::uniform_real_distribution<>{2.f, 15.f}(_rng); } RPBot::RPBot(