general improvements, enable server prompt caching
This commit is contained in:
		| @@ -32,6 +32,7 @@ target_link_libraries(test1 PUBLIC | ||||
|  | ||||
| add_library(solanaceae_rpbot | ||||
| 	./solanaceae/rpbot/message_prompt_builder.hpp | ||||
| 	./solanaceae/rpbot/message_prompt_builder.cpp | ||||
|  | ||||
| 	./solanaceae/rpbot/rpbot.hpp | ||||
| 	./solanaceae/rpbot/rpbot.cpp | ||||
|   | ||||
| @@ -83,6 +83,7 @@ int64_t LlamaCppWeb::completeSelect(const std::string_view prompt, const std::ve | ||||
| 		{"top_p", 1.0}, // disable | ||||
| 		{"n_predict", 256}, // unlikely to ever be so high | ||||
| 		{"seed", _rng()}, | ||||
| 		{"cache_prompt", static_cast<bool>(_use_server_cache)}, | ||||
| 	}); | ||||
|  | ||||
| 	if (ret.empty()) { | ||||
| @@ -119,6 +120,7 @@ std::string LlamaCppWeb::completeLine(const std::string_view prompt) { | ||||
| 		{"n_predict", 1000}, | ||||
| 		{"seed", _rng()}, | ||||
| 		{"stop", {"\n"}}, | ||||
| 		{"cache_prompt", static_cast<bool>(_use_server_cache)}, | ||||
| 	}); | ||||
|  | ||||
| 	if (ret.empty() || ret.count("content") == 0) { | ||||
|   | ||||
| @@ -6,6 +6,7 @@ | ||||
| #include <nlohmann/json_fwd.hpp> | ||||
|  | ||||
| #include <random> | ||||
| #include <atomic> | ||||
|  | ||||
| struct LlamaCppWeb : public TextCompletionI { | ||||
| 	// this mutex locks internally | ||||
| @@ -14,6 +15,8 @@ struct LlamaCppWeb : public TextCompletionI { | ||||
| 	// this is a bad idea | ||||
| 	static std::minstd_rand thread_local _rng; | ||||
|  | ||||
| 	std::atomic<bool> _use_server_cache {true}; | ||||
|  | ||||
| 	~LlamaCppWeb(void); | ||||
|  | ||||
| 	bool isGood(void) override; | ||||
|   | ||||
							
								
								
									
										96
									
								
								src/solanaceae/rpbot/message_prompt_builder.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								src/solanaceae/rpbot/message_prompt_builder.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,96 @@ | ||||
| #include "./message_prompt_builder.hpp" | ||||
|  | ||||
| #include "./rpbot.hpp" | ||||
|  | ||||
| #include <solanaceae/contact/components.hpp> | ||||
| #include <solanaceae/message3/components.hpp> | ||||
|  | ||||
| bool MessagePromptBuilder::buildNameLookup(void) { | ||||
| 	if (_cr.all_of<Contact::Components::ParentOf>(_c)) { // group rpbot | ||||
| 		const auto& subs = _cr.get<Contact::Components::ParentOf>(_c).subs; | ||||
| 		// should include self | ||||
| 		for (const auto sub_c : subs) { | ||||
| 			if (_cr.all_of<Contact::Components::Name>(sub_c)) { | ||||
| 				names[sub_c] = _cr.get<Contact::Components::Name>(sub_c).name; | ||||
| 			} | ||||
| 		} | ||||
| 	} else { // pm rpbot | ||||
| 		if (_cr.all_of<Contact::Components::Name>(_c)) { | ||||
| 			names[_c] = _cr.get<Contact::Components::Name>(_c).name; | ||||
| 		} else { | ||||
| 			std::cerr << "RPBot error: other missing name\n"; | ||||
| 			return false; | ||||
| 		} | ||||
|  | ||||
| 		if (_cr.all_of<Contact::Components::Self>(_c)) { | ||||
| 			const auto self = _cr.get<Contact::Components::Self>(_c).self; | ||||
| 			if (_cr.all_of<Contact::Components::Name>(self)) { | ||||
| 				names[self] = _cr.get<Contact::Components::Name>(self).name; | ||||
| 			} else { | ||||
| 				std::cerr << "RPBot error: self missing name\n"; | ||||
| 				return false; | ||||
| 			} | ||||
| 		} else { | ||||
| 			std::cerr << "RPBot error: contact missing self\n"; | ||||
| 			return false; | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return true; | ||||
| } | ||||
|  | ||||
| std::string MessagePromptBuilder::buildPromptMessageHistory(void) { | ||||
| 	auto* mr =  _rmm.get(_c); | ||||
| 	assert(mr); | ||||
|  | ||||
| 	std::string prompt; | ||||
|  | ||||
| 	auto view = mr->view<Message::Components::Timestamp>(); | ||||
| 	for (auto view_it = view.rbegin(), view_last = view.rend(); view_it != view_last; view_it++) { | ||||
| 		const Message3 e = *view_it; | ||||
|  | ||||
| 		// manually filter ("reverse" iteration <.<) | ||||
| 		// TODO: add mesagetext? | ||||
| 		if (!mr->all_of<Message::Components::ContactFrom, Message::Components::ContactTo>(e)) { | ||||
| 			continue; | ||||
| 		} | ||||
|  | ||||
| 		//Message::Components::ContactFrom& c_from = mr->get<Message::Components::ContactFrom>(e); | ||||
| 		//Message::Components::ContactTo& c_to = msg_reg.get<Message::Components::ContactTo>(e); | ||||
| 		//Message::Components::Timestamp ts = view.get<Message::Components::Timestamp>(e); | ||||
|  | ||||
| 		prompt += "\n"; | ||||
| 		prompt += buildPromptMessage({*mr, e}); | ||||
| 	} | ||||
|  | ||||
| 	return prompt; | ||||
| } | ||||
|  | ||||
| std::string MessagePromptBuilder::buildPromptMessage(const Message3Handle m) { | ||||
| 	if (!m.all_of<Message::Components::ContactFrom, Message::Components::MessageText>()) { | ||||
| 		// TODO: case for transfers | ||||
| 		return ""; | ||||
| 	} | ||||
|  | ||||
| 	// TODO: cache as comp | ||||
| 	const auto line_prefix = promptMessagePrefixSimple(m); | ||||
|  | ||||
| 	// TODO: trim | ||||
| 	std::string message_lines = line_prefix + m.get<Message::Components::MessageText>().text; | ||||
| 	for (auto nlpos = message_lines.find('\n'); nlpos != std::string::npos; nlpos = message_lines.find('\n', nlpos+1)) { | ||||
| 		message_lines.insert(nlpos+1, line_prefix); | ||||
| 		nlpos += line_prefix.size(); | ||||
| 	} | ||||
|  | ||||
| 	return message_lines; | ||||
| } | ||||
|  | ||||
| std::string MessagePromptBuilder::promptMessagePrefixSimple(const Message3Handle m) { | ||||
| 	const Contact3 from = m.get<Message::Components::ContactFrom>().c; | ||||
| 	if (names.count(from)) { | ||||
| 		return std::string{names[from]} + ": "; | ||||
| 	} else { | ||||
| 		return "<unk-user>: "; | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -1,12 +1,10 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "./rpbot.hpp" | ||||
|  | ||||
| #include <solanaceae/contact/components.hpp> | ||||
| #include <solanaceae/message3/components.hpp> | ||||
| #include <solanaceae/util/config_model.hpp> | ||||
| #include <solanaceae/contact/contact_model3.hpp> | ||||
| #include <solanaceae/message3/registry_message_model.hpp> | ||||
|  | ||||
| #include <entt/container/dense_map.hpp> | ||||
| #include <entt/container/dense_set.hpp> | ||||
|  | ||||
| // TODO: improve caching | ||||
| struct MessagePromptBuilder { | ||||
| @@ -17,96 +15,14 @@ struct MessagePromptBuilder { | ||||
| 	// lookup table, string_view since no name-components are changed | ||||
| 	entt::dense_map<Contact3, std::string_view> names; | ||||
|  | ||||
| 	bool buildNameLookup(void); | ||||
|  | ||||
| 	bool buildNameLookup(void) { | ||||
| 		if (_cr.all_of<Contact::Components::ParentOf>(_c)) { // group rpbot | ||||
| 			const auto& subs = _cr.get<Contact::Components::ParentOf>(_c).subs; | ||||
| 			// should include self | ||||
| 			for (const auto sub_c : subs) { | ||||
| 				if (_cr.all_of<Contact::Components::Name>(sub_c)) { | ||||
| 					names[sub_c] = _cr.get<Contact::Components::Name>(sub_c).name; | ||||
| 				} | ||||
| 			} | ||||
| 		} else { // pm rpbot | ||||
| 			if (_cr.all_of<Contact::Components::Name>(_c)) { | ||||
| 				names[_c] = _cr.get<Contact::Components::Name>(_c).name; | ||||
| 			} else { | ||||
| 				std::cerr << "RPBot error: other missing name\n"; | ||||
| 				return false; | ||||
| 			} | ||||
|  | ||||
| 			if (_cr.all_of<Contact::Components::Self>(_c)) { | ||||
| 				const auto self = _cr.get<Contact::Components::Self>(_c).self; | ||||
| 				if (_cr.all_of<Contact::Components::Name>(self)) { | ||||
| 					names[self] = _cr.get<Contact::Components::Name>(self).name; | ||||
| 				} else { | ||||
| 					std::cerr << "RPBot error: self missing name\n"; | ||||
| 					return false; | ||||
| 				} | ||||
| 			} else { | ||||
| 				std::cerr << "RPBot error: contact missing self\n"; | ||||
| 				return false; | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return true; | ||||
| 	} | ||||
|  | ||||
| 	std::string buildPromptMessageHistory(void) { | ||||
| 		auto* mr =  _rmm.get(_c); | ||||
| 		assert(mr); | ||||
|  | ||||
| 		std::string prompt; | ||||
|  | ||||
| 		auto view = mr->view<Message::Components::Timestamp>(); | ||||
| 		for (auto view_it = view.rbegin(), view_last = view.rend(); view_it != view_last; view_it++) { | ||||
| 			const Message3 e = *view_it; | ||||
|  | ||||
| 			// manually filter ("reverse" iteration <.<) | ||||
| 			// TODO: add MessageText? | ||||
| 			if (!mr->all_of<Message::Components::ContactFrom, Message::Components::ContactTo>(e)) { | ||||
| 				continue; | ||||
| 			} | ||||
|  | ||||
| 			//Message::Components::ContactFrom& c_from = mr->get<Message::Components::ContactFrom>(e); | ||||
| 			//Message::Components::ContactTo& c_to = msg_reg.get<Message::Components::ContactTo>(e); | ||||
| 			//Message::Components::Timestamp ts = view.get<Message::Components::Timestamp>(e); | ||||
|  | ||||
| 			prompt += "\n"; | ||||
| 			prompt += buildPromptMessage({*mr, e}); | ||||
| 		} | ||||
|  | ||||
| 		return prompt; | ||||
| 	} | ||||
| 	std::string buildPromptMessageHistory(void); | ||||
|  | ||||
| 	// gets split across lines | ||||
| 	std::string buildPromptMessage(const Message3Handle m) { | ||||
| 		if (!m.all_of<Message::Components::ContactFrom, Message::Components::MessageText>()) { | ||||
| 			// TODO: case for transfers | ||||
| 			return ""; | ||||
| 		} | ||||
| 	std::string buildPromptMessage(const Message3Handle m); | ||||
|  | ||||
| 		// TODO: cache as comp | ||||
| 		const auto line_prefix = buildPromptMessagePrefix(m); | ||||
|  | ||||
| 		// TODO: trim | ||||
| 		std::string message_lines = line_prefix + m.get<Message::Components::MessageText>().text; | ||||
| 		for (auto nlpos = message_lines.find('\n'); nlpos != std::string::npos; nlpos = message_lines.find('\n', nlpos+1)) { | ||||
| 			message_lines.insert(nlpos+1, line_prefix); | ||||
| 			nlpos += line_prefix.size(); | ||||
| 		} | ||||
|  | ||||
| 		return message_lines; | ||||
| 	} | ||||
|  | ||||
| 	// generate prompt prefix (timestamp + name), configurable? | ||||
| 	std::string buildPromptMessagePrefix(const Message3Handle m) { | ||||
| 		const Contact3 from = m.get<Message::Components::ContactFrom>().c; | ||||
| 		if (names.count(from)) { | ||||
| 			return std::string{names[from]} + ": "; | ||||
| 		} else { | ||||
| 			return "<unk-user>: "; | ||||
| 		} | ||||
| 	} | ||||
| 	// generate prompt prefix (just "name:") | ||||
| 	std::string promptMessagePrefixSimple(const Message3Handle m); | ||||
| }; | ||||
|  | ||||
|   | ||||
| @@ -122,7 +122,8 @@ void RPBot::stateTransition(const Contact3 c, const StateNext& from, StateGenera | ||||
|  | ||||
| template<> | ||||
| void RPBot::stateTransition(const Contact3, const StateGenerateMsg&, StateIdle& to) { | ||||
| 	to.timeout = std::uniform_real_distribution<>{5.f, 20.f}(_rng); | ||||
| 	// relatively slow delay for multi line messages | ||||
| 	to.timeout = std::uniform_real_distribution<>{2.f, 15.f}(_rng); | ||||
| } | ||||
|  | ||||
| RPBot::RPBot( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user