general improvements, enable server prompt caching
This commit is contained in:
parent
f5650475c7
commit
784fea96d6
@ -32,6 +32,7 @@ target_link_libraries(test1 PUBLIC
|
|||||||
|
|
||||||
add_library(solanaceae_rpbot
|
add_library(solanaceae_rpbot
|
||||||
./solanaceae/rpbot/message_prompt_builder.hpp
|
./solanaceae/rpbot/message_prompt_builder.hpp
|
||||||
|
./solanaceae/rpbot/message_prompt_builder.cpp
|
||||||
|
|
||||||
./solanaceae/rpbot/rpbot.hpp
|
./solanaceae/rpbot/rpbot.hpp
|
||||||
./solanaceae/rpbot/rpbot.cpp
|
./solanaceae/rpbot/rpbot.cpp
|
||||||
|
@ -83,6 +83,7 @@ int64_t LlamaCppWeb::completeSelect(const std::string_view prompt, const std::ve
|
|||||||
{"top_p", 1.0}, // disable
|
{"top_p", 1.0}, // disable
|
||||||
{"n_predict", 256}, // unlikely to ever be so high
|
{"n_predict", 256}, // unlikely to ever be so high
|
||||||
{"seed", _rng()},
|
{"seed", _rng()},
|
||||||
|
{"cache_prompt", static_cast<bool>(_use_server_cache)},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (ret.empty()) {
|
if (ret.empty()) {
|
||||||
@ -119,6 +120,7 @@ std::string LlamaCppWeb::completeLine(const std::string_view prompt) {
|
|||||||
{"n_predict", 1000},
|
{"n_predict", 1000},
|
||||||
{"seed", _rng()},
|
{"seed", _rng()},
|
||||||
{"stop", {"\n"}},
|
{"stop", {"\n"}},
|
||||||
|
{"cache_prompt", static_cast<bool>(_use_server_cache)},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (ret.empty() || ret.count("content") == 0) {
|
if (ret.empty() || ret.count("content") == 0) {
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include <nlohmann/json_fwd.hpp>
|
#include <nlohmann/json_fwd.hpp>
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
struct LlamaCppWeb : public TextCompletionI {
|
struct LlamaCppWeb : public TextCompletionI {
|
||||||
// this mutex locks internally
|
// this mutex locks internally
|
||||||
@ -14,6 +15,8 @@ struct LlamaCppWeb : public TextCompletionI {
|
|||||||
// this is a bad idea
|
// this is a bad idea
|
||||||
static std::minstd_rand thread_local _rng;
|
static std::minstd_rand thread_local _rng;
|
||||||
|
|
||||||
|
std::atomic<bool> _use_server_cache {true};
|
||||||
|
|
||||||
~LlamaCppWeb(void);
|
~LlamaCppWeb(void);
|
||||||
|
|
||||||
bool isGood(void) override;
|
bool isGood(void) override;
|
||||||
|
96
src/solanaceae/rpbot/message_prompt_builder.cpp
Normal file
96
src/solanaceae/rpbot/message_prompt_builder.cpp
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
#include "./message_prompt_builder.hpp"
|
||||||
|
|
||||||
|
#include "./rpbot.hpp"
|
||||||
|
|
||||||
|
#include <solanaceae/contact/components.hpp>
|
||||||
|
#include <solanaceae/message3/components.hpp>
|
||||||
|
|
||||||
|
bool MessagePromptBuilder::buildNameLookup(void) {
|
||||||
|
if (_cr.all_of<Contact::Components::ParentOf>(_c)) { // group rpbot
|
||||||
|
const auto& subs = _cr.get<Contact::Components::ParentOf>(_c).subs;
|
||||||
|
// should include self
|
||||||
|
for (const auto sub_c : subs) {
|
||||||
|
if (_cr.all_of<Contact::Components::Name>(sub_c)) {
|
||||||
|
names[sub_c] = _cr.get<Contact::Components::Name>(sub_c).name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else { // pm rpbot
|
||||||
|
if (_cr.all_of<Contact::Components::Name>(_c)) {
|
||||||
|
names[_c] = _cr.get<Contact::Components::Name>(_c).name;
|
||||||
|
} else {
|
||||||
|
std::cerr << "RPBot error: other missing name\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_cr.all_of<Contact::Components::Self>(_c)) {
|
||||||
|
const auto self = _cr.get<Contact::Components::Self>(_c).self;
|
||||||
|
if (_cr.all_of<Contact::Components::Name>(self)) {
|
||||||
|
names[self] = _cr.get<Contact::Components::Name>(self).name;
|
||||||
|
} else {
|
||||||
|
std::cerr << "RPBot error: self missing name\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::cerr << "RPBot error: contact missing self\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string MessagePromptBuilder::buildPromptMessageHistory(void) {
|
||||||
|
auto* mr = _rmm.get(_c);
|
||||||
|
assert(mr);
|
||||||
|
|
||||||
|
std::string prompt;
|
||||||
|
|
||||||
|
auto view = mr->view<Message::Components::Timestamp>();
|
||||||
|
for (auto view_it = view.rbegin(), view_last = view.rend(); view_it != view_last; view_it++) {
|
||||||
|
const Message3 e = *view_it;
|
||||||
|
|
||||||
|
// manually filter ("reverse" iteration <.<)
|
||||||
|
// TODO: add mesagetext?
|
||||||
|
if (!mr->all_of<Message::Components::ContactFrom, Message::Components::ContactTo>(e)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
//Message::Components::ContactFrom& c_from = mr->get<Message::Components::ContactFrom>(e);
|
||||||
|
//Message::Components::ContactTo& c_to = msg_reg.get<Message::Components::ContactTo>(e);
|
||||||
|
//Message::Components::Timestamp ts = view.get<Message::Components::Timestamp>(e);
|
||||||
|
|
||||||
|
prompt += "\n";
|
||||||
|
prompt += buildPromptMessage({*mr, e});
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string MessagePromptBuilder::buildPromptMessage(const Message3Handle m) {
|
||||||
|
if (!m.all_of<Message::Components::ContactFrom, Message::Components::MessageText>()) {
|
||||||
|
// TODO: case for transfers
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: cache as comp
|
||||||
|
const auto line_prefix = promptMessagePrefixSimple(m);
|
||||||
|
|
||||||
|
// TODO: trim
|
||||||
|
std::string message_lines = line_prefix + m.get<Message::Components::MessageText>().text;
|
||||||
|
for (auto nlpos = message_lines.find('\n'); nlpos != std::string::npos; nlpos = message_lines.find('\n', nlpos+1)) {
|
||||||
|
message_lines.insert(nlpos+1, line_prefix);
|
||||||
|
nlpos += line_prefix.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return message_lines;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string MessagePromptBuilder::promptMessagePrefixSimple(const Message3Handle m) {
|
||||||
|
const Contact3 from = m.get<Message::Components::ContactFrom>().c;
|
||||||
|
if (names.count(from)) {
|
||||||
|
return std::string{names[from]} + ": ";
|
||||||
|
} else {
|
||||||
|
return "<unk-user>: ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,12 +1,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "./rpbot.hpp"
|
#include <solanaceae/util/config_model.hpp>
|
||||||
|
#include <solanaceae/contact/contact_model3.hpp>
|
||||||
#include <solanaceae/contact/components.hpp>
|
#include <solanaceae/message3/registry_message_model.hpp>
|
||||||
#include <solanaceae/message3/components.hpp>
|
|
||||||
|
|
||||||
#include <entt/container/dense_map.hpp>
|
#include <entt/container/dense_map.hpp>
|
||||||
#include <entt/container/dense_set.hpp>
|
|
||||||
|
|
||||||
// TODO: improve caching
|
// TODO: improve caching
|
||||||
struct MessagePromptBuilder {
|
struct MessagePromptBuilder {
|
||||||
@ -17,96 +15,14 @@ struct MessagePromptBuilder {
|
|||||||
// lookup table, string_view since no name-components are changed
|
// lookup table, string_view since no name-components are changed
|
||||||
entt::dense_map<Contact3, std::string_view> names;
|
entt::dense_map<Contact3, std::string_view> names;
|
||||||
|
|
||||||
|
bool buildNameLookup(void);
|
||||||
|
|
||||||
bool buildNameLookup(void) {
|
std::string buildPromptMessageHistory(void);
|
||||||
if (_cr.all_of<Contact::Components::ParentOf>(_c)) { // group rpbot
|
|
||||||
const auto& subs = _cr.get<Contact::Components::ParentOf>(_c).subs;
|
|
||||||
// should include self
|
|
||||||
for (const auto sub_c : subs) {
|
|
||||||
if (_cr.all_of<Contact::Components::Name>(sub_c)) {
|
|
||||||
names[sub_c] = _cr.get<Contact::Components::Name>(sub_c).name;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else { // pm rpbot
|
|
||||||
if (_cr.all_of<Contact::Components::Name>(_c)) {
|
|
||||||
names[_c] = _cr.get<Contact::Components::Name>(_c).name;
|
|
||||||
} else {
|
|
||||||
std::cerr << "RPBot error: other missing name\n";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (_cr.all_of<Contact::Components::Self>(_c)) {
|
|
||||||
const auto self = _cr.get<Contact::Components::Self>(_c).self;
|
|
||||||
if (_cr.all_of<Contact::Components::Name>(self)) {
|
|
||||||
names[self] = _cr.get<Contact::Components::Name>(self).name;
|
|
||||||
} else {
|
|
||||||
std::cerr << "RPBot error: self missing name\n";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
std::cerr << "RPBot error: contact missing self\n";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string buildPromptMessageHistory(void) {
|
|
||||||
auto* mr = _rmm.get(_c);
|
|
||||||
assert(mr);
|
|
||||||
|
|
||||||
std::string prompt;
|
|
||||||
|
|
||||||
auto view = mr->view<Message::Components::Timestamp>();
|
|
||||||
for (auto view_it = view.rbegin(), view_last = view.rend(); view_it != view_last; view_it++) {
|
|
||||||
const Message3 e = *view_it;
|
|
||||||
|
|
||||||
// manually filter ("reverse" iteration <.<)
|
|
||||||
// TODO: add mesagetext?
|
|
||||||
if (!mr->all_of<Message::Components::ContactFrom, Message::Components::ContactTo>(e)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
//Message::Components::ContactFrom& c_from = mr->get<Message::Components::ContactFrom>(e);
|
|
||||||
//Message::Components::ContactTo& c_to = msg_reg.get<Message::Components::ContactTo>(e);
|
|
||||||
//Message::Components::Timestamp ts = view.get<Message::Components::Timestamp>(e);
|
|
||||||
|
|
||||||
prompt += "\n";
|
|
||||||
prompt += buildPromptMessage({*mr, e});
|
|
||||||
}
|
|
||||||
|
|
||||||
return prompt;
|
|
||||||
}
|
|
||||||
|
|
||||||
// gets split across lines
|
// gets split across lines
|
||||||
std::string buildPromptMessage(const Message3Handle m) {
|
std::string buildPromptMessage(const Message3Handle m);
|
||||||
if (!m.all_of<Message::Components::ContactFrom, Message::Components::MessageText>()) {
|
|
||||||
// TODO: case for transfers
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: cache as comp
|
// generate prompt prefix (just "name:")
|
||||||
const auto line_prefix = buildPromptMessagePrefix(m);
|
std::string promptMessagePrefixSimple(const Message3Handle m);
|
||||||
|
|
||||||
// TODO: trim
|
|
||||||
std::string message_lines = line_prefix + m.get<Message::Components::MessageText>().text;
|
|
||||||
for (auto nlpos = message_lines.find('\n'); nlpos != std::string::npos; nlpos = message_lines.find('\n', nlpos+1)) {
|
|
||||||
message_lines.insert(nlpos+1, line_prefix);
|
|
||||||
nlpos += line_prefix.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
return message_lines;
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate prompt prefix (timestamp + name), configurable?
|
|
||||||
std::string buildPromptMessagePrefix(const Message3Handle m) {
|
|
||||||
const Contact3 from = m.get<Message::Components::ContactFrom>().c;
|
|
||||||
if (names.count(from)) {
|
|
||||||
return std::string{names[from]} + ": ";
|
|
||||||
} else {
|
|
||||||
return "<unk-user>: ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -122,7 +122,8 @@ void RPBot::stateTransition(const Contact3 c, const StateNext& from, StateGenera
|
|||||||
|
|
||||||
template<>
|
template<>
|
||||||
void RPBot::stateTransition(const Contact3, const StateGenerateMsg&, StateIdle& to) {
|
void RPBot::stateTransition(const Contact3, const StateGenerateMsg&, StateIdle& to) {
|
||||||
to.timeout = std::uniform_real_distribution<>{5.f, 20.f}(_rng);
|
// relativly slow delay for multi line messages
|
||||||
|
to.timeout = std::uniform_real_distribution<>{2.f, 15.f}(_rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
RPBot::RPBot(
|
RPBot::RPBot(
|
||||||
|
Loading…
Reference in New Issue
Block a user