Commit 8521739

Store conversation using llama.cpp format
1 parent 2d2b4a8 commit 8521739

File tree

3 files changed: +32 -7 lines changed


libraries/YarpPlugins/LlamaGPT/DeviceDriverImpl.cpp

Lines changed: 3 additions & 1 deletion
@@ -106,13 +106,15 @@ bool LlamaGPT::open(yarp::os::Searchable & config)
 
 bool LlamaGPT::close()
 {
+    bool ret = deleteConversation();
+
     if (model)
     {
         llama_model_free(model);
         model = nullptr;
     }
 
-    return true;
+    return ret;
 }
 
 // -----------------------------------------------------------------------------
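
Context for the change above: close() now delegates to deleteConversation(), which releases the heap-owned message contents (see ILLMImpl.cpp below) before the model itself is freed, and propagates its result to the caller. A minimal standalone sketch of the ownership pattern being adopted, with hypothetical names rather than code from this repo:

#include <cstdlib> // std::free
#include <cstring> // ::strdup (POSIX)
#include <vector>

struct chat_message // mirrors llama.cpp's llama_chat_message layout
{
    const char * role;
    const char * content;
};

int main()
{
    std::vector<chat_message> conversation;

    // The role points at a string literal (static storage, never freed);
    // the content is duplicated onto the heap, so the vector owns it.
    conversation.push_back({"user", ::strdup("hello")});

    // Counterpart of deleteConversation(): free only what ::strdup allocated.
    for (auto & msg : conversation)
    {
        std::free(const_cast<char *>(msg.content));
    }

    conversation.clear();
}

Since ::strdup allocates with malloc, std::free is the matching deallocation; releasing the string through a C++ mechanism such as delete[] would be undefined behavior.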

libraries/YarpPlugins/LlamaGPT/ILLMImpl.cpp

Lines changed: 28 additions & 5 deletions
@@ -3,8 +3,11 @@
 #include "LlamaGPT.hpp"
 
 #include <cctype> // std::isspace
+#include <cstdlib> // std::free
+#include <cstring> // ::strdup (POSIX standard, but not C standard)
 
-#include <algorithm> // std::find_if
+#include <algorithm> // std::find_if, std::transform
+#include <iterator> // std::back_inserter
 
 #include <yarp/os/LogStream.h>
 
@@ -62,7 +65,7 @@ bool LlamaGPT::setPrompt(const std::string & prompt)
 
     yCInfo(LLAMA) << "Setting prompt:" << temp;
     m_prompt = temp;
-    conversation.push_back(yarp::dev::LLM_Message("system", m_prompt, {}, {}));
+    conversation.push_back({"system", ::strdup(m_prompt.c_str())});
 
 #if YARP_VERSION_COMPARE(>=, 3, 12, 0)
     return yarp::dev::ReturnValue::return_code::return_value_ok;
@@ -96,7 +99,7 @@ bool LlamaGPT::ask(const std::string & question, yarp::dev::LLM_Message & answer)
 #endif
 {
     yCInfo(LLAMA) << "Asking:" << question;
-    conversation.push_back(yarp::dev::LLM_Message("user", question, {}, {}));
+    conversation.push_back({"user", ::strdup(question.c_str())});
 
     auto prompt = m_prompt;
 
@@ -224,7 +227,7 @@ bool LlamaGPT::ask(const std::string & question, yarp::dev::LLM_Message & answer)
     }
 
     answer = {"assistant", out, {}, {}};
-    conversation.push_back(answer);
+    conversation.push_back({"assistant", ::strdup(out.c_str())});
 
 #if YARP_VERSION_COMPARE(>=, 3, 12, 0)
     return yarp::dev::ReturnValue::return_code::return_value_ok;
@@ -241,7 +244,15 @@ yarp::dev::ReturnValue LlamaGPT::getConversation(std::vector<yarp::dev::LLM_Message> & conversation)
 bool LlamaGPT::getConversation(std::vector<yarp::dev::LLM_Message> & conversation)
 #endif
 {
-    conversation = std::vector(this->conversation);
+    conversation.clear();
+    conversation.reserve(this->conversation.size());
+
+    std::transform(this->conversation.cbegin(), this->conversation.cend(), std::back_inserter(conversation),
+        [](const llama_chat_message & msg)
+        {
+            return yarp::dev::LLM_Message(msg.role, msg.content, {}, {});
+        });
+
 #if YARP_VERSION_COMPARE(>=, 3, 12, 0)
     return yarp::dev::ReturnValue::return_code::return_value_ok;
 #else
@@ -258,6 +269,12 @@ bool LlamaGPT::deleteConversation()
 #endif
 {
     yCInfo(LLAMA) << "Deleting conversation and prompt";
+
+    for (auto & msg : conversation)
+    {
+        std::free(const_cast<char *>(msg.content));
+    }
+
     conversation.clear();
     m_prompt.clear();
 #if YARP_VERSION_COMPARE(>=, 3, 12, 0)
@@ -276,6 +293,12 @@ bool LlamaGPT::refreshConversation()
 #endif
 {
     yCInfo(LLAMA) << "Deleting conversation while keeping the prompt";
+
+    for (auto & msg : conversation)
+    {
+        std::free(const_cast<char *>(msg.content));
+    }
+
     conversation.clear();
 #if YARP_VERSION_COMPARE(>=, 3, 12, 0)
     return yarp::dev::ReturnValue::return_code::return_value_ok;
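
The practical payoff of storing messages in llama.cpp's own type is that the vector can be handed to the library's chat-template API without per-message conversion. A rough sketch follows; the exact llama_chat_apply_template signature has changed across llama.cpp releases, so both the parameter list and the retry convention below are assumptions to verify against the vendored headers, and renderPrompt is a hypothetical helper, not code from this repo:

#include <llama.h>

#include <string>
#include <vector>

// Render the stored conversation into a single prompt string using a chat
// template (e.g. one obtained from the model's metadata). A negative result
// from llama_chat_apply_template signals an error, handled here by returning
// an empty string.
std::string renderPrompt(const char * tmpl, const std::vector<llama_chat_message> & conversation)
{
    std::vector<char> buf(4096);

    int32_t len = llama_chat_apply_template(tmpl, conversation.data(), conversation.size(),
                                            true, // append the assistant prefix for generation
                                            buf.data(), static_cast<int32_t>(buf.size()));

    if (len > static_cast<int32_t>(buf.size()))
    {
        buf.resize(len); // the call reports the size it needed; retry once
        len = llama_chat_apply_template(tmpl, conversation.data(), conversation.size(),
                                        true, buf.data(), static_cast<int32_t>(buf.size()));
    }

    return len > 0 ? std::string(buf.data(), len) : std::string();
}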

libraries/YarpPlugins/LlamaGPT/LlamaGPT.hpp

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ class LlamaGPT : public yarp::dev::DeviceDriver,
 
 private:
     llama_model * model {nullptr};
-    std::vector<yarp::dev::LLM_Message> conversation;
+    std::vector<llama_chat_message> conversation;
 };
 
 #endif // __LLAMA_GPT_HPP__
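
For reference, llama_chat_message is the plain C struct declared in llama.cpp's llama.h (the field layout is the library's; the comments are added here). It holds non-owning raw pointers, which is why every content string in this commit is duplicated with ::strdup on insertion and released with std::free on deletion:

typedef struct llama_chat_message {
    const char * role;    // e.g. "system", "user", "assistant"
    const char * content; // message text; lifetime is the caller's responsibility
} llama_chat_message;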
