#include "./llama_cpp_web_impl.hpp"
|
|
|
|
#include <solanaceae/util/utils.hpp>
|
|
|
|
#include <nlohmann/json.hpp>
|
|
|
|
#include <chrono>
|
|
|
|
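// thread-local rng, used to provide a fresh sampling seed for every request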
std::minstd_rand thread_local LlamaCppWeb::_rng{std::random_device{}()};

// TODO: variant that strips unicode?
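// hex-escapes every byte of the input (as "\xNN") so arbitrary option strings
// can be embedded as string literals in a GBNF grammar without colliding with
// grammar syntax or quoting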
static std::string convertToSafeGrammarString(std::string_view input) {
	std::string res;
	for (const char c : input) {
		res += "\\x";
		res += bin2hex({static_cast<uint8_t>(c)});
	}
	return res;
}

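// the http client connects to the server configured as "server" in the
// "LlamaCppWeb" config section, defaulting to localhost:8080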
LlamaCppWeb::LlamaCppWeb(
	ConfigModelI& conf
) : _conf(conf), _cli(_conf.get_string("LlamaCppWeb", "server").value_or("localhost:8080")) {
}

LlamaCppWeb::~LlamaCppWeb(void) {
}

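// queries the server's /health endpoint; only reports healthy on a 200 json
// response whose "status" field is "ok"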
bool LlamaCppWeb::isGood(void) {
	//return true; // HACK: test old llama server
	auto res = _cli.Get("/health");
	if (
		res.error() != httplib::Error::Success ||
		res->status != 200 ||
		res->body.empty() ||
		res->get_header_value("Content-Type") != "application/json"
	) {
		return false;
	}

	//std::cout << "/health code: " << res->status << " body: " << res->body << "\n";
	//std::cout << "Content-Type: " << res->get_header_value("Content-Type") << "\n";

	const auto response_body_j = nlohmann::json::parse(res->body, nullptr, false);
	if (response_body_j.is_discarded()) {
		// body was not valid json despite the content type
		return false;
	}

	const std::string status = response_body_j.value("status", std::string{"value-not-found"});
	if (status != "ok") {
		std::cerr << "status not ok: " << status << "\n";
		return false;
	}

	return true; // healthy endpoint
}

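// asks the server to pick exactly one of the given strings by constraining
// generation with a generated GBNF grammar of the form
//   root ::= "<escaped option 0>" | "<escaped option 1>" | ...
// returns the index of the option the model produced, or a negative error code
// hypothetical usage: completeSelect("Answer yes or no: ", {"yes", "no"})
// would return 0 for "yes", 1 for "no", and a negative value on failure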
int64_t LlamaCppWeb::completeSelect(const std::string_view prompt, const std::vector<std::string_view>& possible) {
	if (possible.empty()) {
		return -1;
	}
	if (possible.size() == 1) {
		return 0;
	}

	// see
	// https://github.com/ggerganov/llama.cpp/tree/master/grammars#example
	std::string grammar {"root ::= "};
	bool first = true;
	for (const auto& it : possible) {
		if (first) {
			first = false;
		} else {
			grammar += "| ";
		}
		grammar += "\"";
		//grammar += it;
		grammar += convertToSafeGrammarString(it);
		grammar += "\" ";
	}
	//grammar += ")";

	std::cerr << "generated grammar:\n" << grammar << "\n";

	auto ret = complete(nlohmann::json{
		{"prompt", prompt},
		{"grammar", grammar},
		{"min_p", 0.1}, // model dependent
		{"repeat_penalty", 1.0}, // deactivate
		{"temperature", 0.9}, // model dependent, ~1.0 for chat models
		{"top_k", 60},
		{"top_p", 1.0}, // disable
		{"n_predict", 256}, // unlikely to ever be so high
		{"seed", _rng()},
		{"ignore_eos", true},
		{"cache_prompt", static_cast<bool>(_use_server_cache)},
	});

	if (ret.empty()) {
		assert("ret empty" && false);
		return -2;
	}

	if (!ret.count("content")) {
		assert("no content" && false);
		return -3;
	}

	std::string selected = ret.at("content");
	if (selected.empty()) {
		assert("content empty" && false);
		return -4;
	}

	for (int64_t i = 0; i < (int64_t)possible.size(); i++) {
		if (selected == possible[i]) {
			return i;
		}
	}

	std::cerr << "content does not contain match\n";
	std::cerr << "complete failed j:'" << ret.dump() << "'\n";
	return -5;
}

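// free-form completion that stops at the first newline ("\n" stop string),
// returns the generated content, or an empty string on error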
std::string LlamaCppWeb::completeLine(const std::string_view prompt) {
	auto ret = complete(nlohmann::json{
		{"prompt", prompt},
		{"min_p", 0.1}, // model dependent
		{"repeat_penalty", 1.0}, // deactivate
		{"temperature", 0.9}, // model dependent, ~1.0 for chat models
		{"top_k", 60},
		{"top_p", 1.0}, // disable
		{"n_predict", 400},
		{"seed", _rng()},
		{"stop", nlohmann::json::array({"\n"})},
		{"cache_prompt", static_cast<bool>(_use_server_cache)},
	});

	if (ret.empty() || ret.count("content") == 0) {
		// error
		return "";
	}

	return ret.at("content");
}

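// low level helper: checks server health, POSTs the request json verbatim to
// /completion with a long read timeout, and returns the parsed response
// (an empty json object on any error)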
nlohmann::json LlamaCppWeb::complete(const nlohmann::json& request_j) {
	// TODO: don't check ourselves
	if (!isGood()) {
		return {};
	}

	// completions can take very long
	// streaming instead would be better
	_cli.set_read_timeout(std::chrono::minutes(10));

	std::cerr << "j dump: '" << request_j.dump(-1, ' ', true) << "'\n";

	auto res = _cli.Post("/completion", request_j.dump(-1, ' ', true), "application/json");

	//std::cerr << "res.error():" << res.error() << "\n";

	if (
		res.error() != httplib::Error::Success ||
		res->status != 200
		//res->body.empty() ||
		//res->get_header_value("Content-Type") != "application/json"
	) {
		std::cerr << "error posting";
		if (res) { // only dereference the result if a response actually arrived
			std::cerr << ": '" << res->body << "'";
		}
		std::cerr << "\n";
		return {};
	}

	return nlohmann::json::parse(res->body, nullptr, false);
}