Skip to content

Commit

Permalink
find_stop_word fixed by triming string
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Nov 22, 2024
1 parent 9218aee commit ca71f17
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions llama_ros/src/llama_ros/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,22 +845,40 @@ Llama::find_stop(std::vector<struct CompletionOutput> completion_result_list,
return NO_STOP;
}

inline std::string trim(const std::string &str) {

// find the position of the first non-whitespace character
size_t start = str.find_first_not_of(" \t\n\r\f\v");

// if the string is all whitespace, return an empty string
if (start == std::string::npos) {
return "";
}

// find the position of the last non-whitespace character
size_t end = str.find_last_not_of(" \t\n\r\f\v");

// return the substring that excludes leading and trailing whitespace
return str.substr(start, end - start + 1);
}

StopType Llama::find_stop_word(
std::vector<struct CompletionOutput> completion_result_list,
std::string stopping_word) {

std::string completion_text = "";
for (auto c : completion_result_list) {
completion_text.append(this->detokenize({c.token}));
completion_text.append(trim(this->detokenize({c.token})));
}

for (size_t i = 0; i < completion_text.size(); i++) {
for (size_t i = 0; i < completion_text.size() && i < stopping_word.size();
i++) {
if (completion_text.at(i) != stopping_word.at(i)) {
return NO_STOP;
}
}

if (completion_text.size() == stopping_word.size()) {
if (completion_text.size() >= stopping_word.size()) {
return FULL_STOP;
} else {
return PARTIAL_STOP;
Expand Down

0 comments on commit ca71f17

Please sign in to comment.