adding get metadata functions, msgs and service
mgonzs13 committed Nov 20, 2024
1 parent f3aca51 commit 1ab99c4
Showing 14 changed files with 345 additions and 83 deletions.
4 changes: 4 additions & 0 deletions llama_msgs/CMakeLists.txt
@@ -20,6 +20,9 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"msg/SamplingConfig.msg"
"msg/Message.msg"
"msg/LoRA.msg"
"msg/GeneralInfo.msg"
"msg/TokenizerInfo.msg"
"msg/Metadata.msg"
"action/GenerateResponse.action"
"srv/GenerateEmbeddings.srv"
"srv/Tokenize.srv"
@@ -28,6 +31,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"srv/ListLoRAs.srv"
"srv/UpdateLoRAs.srv"
"srv/RerankDocuments.srv"
"srv/GetMetadata.srv"
DEPENDENCIES sensor_msgs
)

14 changes: 14 additions & 0 deletions llama_msgs/msg/GeneralInfo.msg
@@ -0,0 +1,14 @@
string architecture # Model architecture name
string description # Description field
string name # Model name
string basename # Base name of the model
string size_label # Size label of the model (e.g. 7B)
string file_type # GGUF file type
string license # License of the model
string license_link # Link to the license
string url # URL to the model page
string repo_url # URL to the model repository
string[] tags # Tags of the model
string[] languages # Languages supported by the model
string quantized_by # Name of the user who quantized this model
int32 quantization_version # Quantization version
2 changes: 2 additions & 0 deletions llama_msgs/msg/Metadata.msg
@@ -0,0 +1,2 @@
GeneralInfo general # General info
TokenizerInfo tokenizer # Tokenizer info
6 changes: 6 additions & 0 deletions llama_msgs/msg/TokenizerInfo.msg
@@ -0,0 +1,6 @@
string model # Tokenizer model name
int32 eos_token_id # EOS token ID of the tokenizer
int32 padding_token_id # Padding token ID of the tokenizer
int32 bos_token_id # BOS token ID of the tokenizer
bool add_bos_token # Whether to add the BOS token
string chat_template # Chat template
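
The chat_template field typically carries the model's Jinja-format chat template as stored in the GGUF metadata. Below is a minimal sketch of rendering it with jinja2, assuming the template only uses the common messages / add_generation_prompt / bos_token / eos_token variables (not every template does) and that llama_msgs has been built:

from jinja2 import Template

from llama_msgs.msg import TokenizerInfo


def render_chat(info: TokenizerInfo, messages: list) -> str:
    # Render the GGUF chat template; the token strings below are placeholders,
    # since TokenizerInfo only exposes token IDs, not their text.
    return Template(info.chat_template).render(
        messages=messages,
        add_generation_prompt=True,
        bos_token="<s>",
        eos_token="</s>",
    )


info = TokenizerInfo(
    chat_template="{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
)
print(render_chat(info, [{"role": "user", "content": "Hello!"}]))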
2 changes: 2 additions & 0 deletions llama_msgs/srv/GetMetadata.srv
@@ -0,0 +1,2 @@
---
Metadata metadata # Metadata info
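
Together, these interfaces expose the loaded model's metadata at runtime: GetMetadata has an empty request and returns a Metadata message composed of the GeneralInfo and TokenizerInfo messages above. A minimal rclpy sketch of calling the service follows; the /llama/get_metadata name assumes the default "llama" namespace used by the client node later in this commit:

import rclpy
from rclpy.node import Node

from llama_msgs.srv import GetMetadata

# Minimal sketch: query the metadata service exposed by the llama node.
rclpy.init()
node = Node("metadata_client")
client = node.create_client(GetMetadata, "/llama/get_metadata")
client.wait_for_service()

future = client.call_async(GetMetadata.Request())
rclpy.spin_until_future_complete(node, future)
metadata = future.result().metadata

print(metadata.general.name, metadata.general.size_label)
print(metadata.tokenizer.model, metadata.tokenizer.eos_token_id)

node.destroy_node()
rclpy.shutdown()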
106 changes: 75 additions & 31 deletions llama_ros/include/llama_ros/llama.hpp
@@ -27,13 +27,16 @@
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "common.h"
#include "json.hpp"
#include "llama.h"
#include "llama_utils/spinner.hpp"
#include "sampling.h"

#include "llama_utils/spinner.hpp"

// llama logs
#define LLAMA_LOG_ERROR(text, ...) \
fprintf(stderr, "[ERROR] " text "\n", ##__VA_ARGS__)
@@ -42,44 +45,78 @@
#define LLAMA_LOG_INFO(text, ...) \
fprintf(stderr, "[INFO] " text "\n", ##__VA_ARGS__)

namespace llama_ros {

// llama structs
struct token_prob {
struct TokenProb {
llama_token token;
float probability;
};

struct lora {
struct LoRA {
int id;
std::string path;
float scale;
};

struct completion_output {
std::vector<token_prob> probs;
struct CompletionOutput {
std::vector<TokenProb> probs;
llama_token token;
};

enum stop_type {
enum StopType {
NO_STOP,
FULL_STOP,
PARTIAL_STOP,
CANCEL,
ABORT,
};

struct response_output {
std::vector<completion_output> completions;
stop_type stop;
struct ResponseOutput {
std::vector<CompletionOutput> completions;
StopType stop;
};

struct embeddings_ouput {
struct EmbeddingsOuput {
std::vector<float> embeddings;
int32_t n_tokens;
};

namespace llama_ros {
struct Metadata {
struct GeneralInfo {
std::string architecture;
std::string description;
std::string name;
std::string basename;
std::string size_label;
std::string file_type;
std::string license;
std::string license_link;
std::string url;
std::string repo_url;
std::vector<std::string> tags;
std::vector<std::string> languages;
std::string quantized_by;
int quantization_version;
};

struct TokenizerInfo {
std::string model;
int eos_token_id;
int padding_token_id;
int bos_token_id;
bool add_bos_token;
std::string chat_template;
};

int version;
int tensor_count;
int kv_count;
GeneralInfo general;
TokenizerInfo tokenizer;
};

using GenerateResponseCallback = std::function<void(struct completion_output)>;
using GenerateResponseCallback = std::function<void(struct CompletionOutput)>;

class Llama {

@@ -97,34 +134,41 @@ class Llama {

std::string format_chat_prompt(std::vector<struct common_chat_msg> chat_msgs,
bool add_ass);
std::vector<struct lora> list_loras();
void update_loras(std::vector<struct lora> loras);
std::vector<struct LoRA> list_loras();
void update_loras(std::vector<struct LoRA> loras);

std::vector<llama_token>
truncate_tokens(const std::vector<llama_token> &tokens, int limit_size,
bool add_eos = true);
embeddings_ouput generate_embeddings(const std::string &input_prompt,
int normalization = 2);
embeddings_ouput generate_embeddings(const std::vector<llama_token> &tokens,
int normalization = 2);
struct EmbeddingsOuput generate_embeddings(const std::string &input_prompt,
int normalization = 2);
struct EmbeddingsOuput
generate_embeddings(const std::vector<llama_token> &tokens,
int normalization = 2);
float rank_document(const std::string &query, const std::string &document);
std::vector<float> rank_documents(const std::string &query,
const std::vector<std::string> &documents);

response_output generate_response(const std::string &input_prompt,
struct common_sampler_params sparams,
GenerateResponseCallback callback = nullptr,
std::vector<std::string> stop = {});
response_output generate_response(const std::string &input_prompt,
GenerateResponseCallback callback = nullptr,
std::vector<std::string> stop = {});
struct ResponseOutput
generate_response(const std::string &input_prompt,
struct common_sampler_params sparams,
GenerateResponseCallback callback = nullptr,
std::vector<std::string> stop = {});
struct ResponseOutput
generate_response(const std::string &input_prompt,
GenerateResponseCallback callback = nullptr,
std::vector<std::string> stop = {});

const struct llama_context *get_ctx() { return this->ctx; }
const struct llama_model *get_model() { return this->model; }
int get_n_ctx() { return llama_n_ctx(this->ctx); }
int get_n_ctx_train() { return llama_n_ctx_train(this->model); }
int get_n_embd() { return llama_n_embd(this->model); }
int get_n_vocab() { return llama_n_vocab(this->model); }

std::string get_metada(const std::string &key, size_t size);
struct Metadata get_metada();

bool is_embedding() { return this->params.embedding; }
bool is_reranking() { return this->params.reranking; }
bool add_bos_token() { return llama_add_bos_token(this->model); }
@@ -156,11 +200,11 @@ class Llama {
virtual void load_prompt(const std::string &input_prompt, bool add_pfx,
bool add_sfx);

stop_type
find_stop(std::vector<struct completion_output> completion_result_list,
StopType
find_stop(std::vector<struct CompletionOutput> completion_result_list,
std::vector<std::string> stopping_words);
stop_type
find_stop_word(std::vector<struct completion_output> completion_result_list,
StopType
find_stop_word(std::vector<struct CompletionOutput> completion_result_list,
std::string stopping_word);

bool eval_system_prompt();
@@ -170,8 +214,8 @@
bool eval(std::vector<llama_token> tokens);
bool eval(struct llama_batch batch);

std::vector<token_prob> get_probs();
struct completion_output sample();
std::vector<struct TokenProb> get_probs();
struct CompletionOutput sample();

private:
// lock
11 changes: 9 additions & 2 deletions llama_ros/include/llama_ros/llama_node.hpp
@@ -36,6 +36,7 @@
#include "llama_msgs/srv/detokenize.hpp"
#include "llama_msgs/srv/format_chat_messages.hpp"
#include "llama_msgs/srv/generate_embeddings.hpp"
#include "llama_msgs/srv/get_metadata.hpp"
#include "llama_msgs/srv/list_lo_r_as.hpp"
#include "llama_msgs/srv/rerank_documents.hpp"
#include "llama_msgs/srv/tokenize.hpp"
@@ -71,7 +72,7 @@ class LlamaNode : public rclcpp_lifecycle::LifecycleNode {
protected:
std::unique_ptr<Llama> llama;
bool params_declared;
struct llama_utils::llama_params params;
struct llama_utils::LlamaParams params;
std::shared_ptr<GoalHandleGenerateResponse> goal_handle_;

virtual void create_llama();
@@ -80,10 +81,12 @@ class LlamaNode : public rclcpp_lifecycle::LifecycleNode {
virtual bool goal_empty(std::shared_ptr<const GenerateResponse::Goal> goal);
virtual void
execute(const std::shared_ptr<GoalHandleGenerateResponse> goal_handle);
void send_text(const struct completion_output &completion);
void send_text(const struct CompletionOutput &completion);

private:
// ros2
rclcpp::Service<llama_msgs::srv::GetMetadata>::SharedPtr
get_metadata_service_;
rclcpp::Service<llama_msgs::srv::Tokenize>::SharedPtr tokenize_service_;
rclcpp::Service<llama_msgs::srv::Detokenize>::SharedPtr detokenize_service_;
rclcpp::Service<llama_msgs::srv::GenerateEmbeddings>::SharedPtr
@@ -99,6 +102,10 @@
generate_response_action_server_;

// methods
void get_metadata_service_callback(
const std::shared_ptr<llama_msgs::srv::GetMetadata::Request> request,
std::shared_ptr<llama_msgs::srv::GetMetadata::Response> response);

void tokenize_service_callback(
const std::shared_ptr<llama_msgs::srv::Tokenize::Request> request,
std::shared_ptr<llama_msgs::srv::Tokenize::Response> response);
6 changes: 3 additions & 3 deletions llama_ros/include/llama_utils/llama_params.hpp
@@ -34,17 +34,17 @@

namespace llama_utils {

struct llama_params {
struct LlamaParams {
bool debug;
std::string system_prompt;
struct common_params params;
struct llava_ros::llava_params llava_params;
struct llava_ros::LlavaParams llava_params;
};

void declare_llama_params(
const rclcpp_lifecycle::LifecycleNode::SharedPtr &node);

struct llama_params
struct LlamaParams
get_llama_params(const rclcpp_lifecycle::LifecycleNode::SharedPtr &node);

enum ggml_sched_priority parse_priority(std::string priority);
6 changes: 3 additions & 3 deletions llama_ros/include/llava_ros/llava.hpp
@@ -38,7 +38,7 @@

namespace llava_ros {

struct llava_params {
struct LlavaParams {
std::string image_text = "";
std::string image_prefix = "";
std::string image_suffix = "";
@@ -48,7 +48,7 @@ class Llava : public llama_ros::Llama {

public:
Llava(const struct common_params &params,
const struct llava_params &llava_params, std::string system_prompt = "",
const struct LlavaParams &llava_params, std::string system_prompt = "",
bool debug = false);
~Llava();

@@ -64,7 +64,7 @@

struct llava_image_embed *image_embed;
struct clip_ctx *ctx_clip;
struct llava_params llava_params;
struct LlavaParams llava_params;

private:
void free_image();
8 changes: 8 additions & 0 deletions llama_ros/llama_ros/llama_client_node.py
@@ -33,6 +33,7 @@
from rclpy.executors import MultiThreadedExecutor

from action_msgs.msg import GoalStatus
from llama_msgs.srv import GetMetadata
from llama_msgs.srv import Tokenize
from llama_msgs.srv import Detokenize
from llama_msgs.srv import GenerateEmbeddings
@@ -89,6 +90,10 @@ def __init__(self, namespace: str = "llama") -> None:
callback_group=self._callback_group,
)

self._get_metadata_srv_client = self.create_client(
GetMetadata, "get_metadata", callback_group=self._callback_group
)

self._tokenize_srv_client = self.create_client(
Tokenize, "tokenize", callback_group=self._callback_group
)
@@ -119,6 +124,9 @@ def __init__(self, namespace: str = "llama") -> None:
self._spin_thread = Thread(target=self._executor.spin)
self._spin_thread.start()

def get_metadata(self, req: GetMetadata.Request) -> GetMetadata.Response:
return self._get_metadata_srv_client.call(req)

def tokenize(self, req: Tokenize.Request) -> Tokenize.Response:
self._tokenize_srv_client.wait_for_service()
return self._tokenize_srv_client.call(req)
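
A usage sketch for the new client wrapper follows; the LlamaClientNode class name is assumed from llama_client_node.py, and the llama node must be running under the default "llama" namespace:

import rclpy

from llama_msgs.srv import GetMetadata
from llama_ros.llama_client_node import LlamaClientNode  # class name assumed

rclpy.init()
client = LlamaClientNode(namespace="llama")

# The wrapper issues a synchronous service call and returns the full response.
metadata = client.get_metadata(GetMetadata.Request()).metadata
print(metadata.general.architecture, metadata.general.quantization_version)
print(metadata.tokenizer.bos_token_id, metadata.tokenizer.chat_template[:80])

rclpy.shutdown()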