Skip to content

Commit

Permalink
Fix index used by LocalDocs when tool calling/thinking is active (#3451)
Browse files Browse the repository at this point in the history
Signed-off-by: Jared Van Bortel <[email protected]>
  • Loading branch information
cebtenzzre authored Feb 3, 2025
1 parent 6bfa014 commit 9131f4c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 18 deletions.
6 changes: 6 additions & 0 deletions gpt4all-chat/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

## [Unreleased]

### Fixed
- Fix "index N is not a prompt" when using LocalDocs with reasoning ([#3451](https://github.com/nomic-ai/gpt4all/pull/3451)

## [3.8.0] - 2025-01-30

### Added
Expand Down Expand Up @@ -283,6 +288,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
- Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))

[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.8.0...HEAD
[3.8.0]: https://github.com/nomic-ai/gpt4all/compare/v3.7.0...v3.8.0
[3.7.0]: https://github.com/nomic-ai/gpt4all/compare/v3.6.1...v3.7.0
[3.6.1]: https://github.com/nomic-ai/gpt4all/compare/v3.6.0...v3.6.1
Expand Down
9 changes: 5 additions & 4 deletions gpt4all-chat/src/chatllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,8 @@ std::vector<MessageItem> ChatLLM::forkConversation(const QString &prompt) const
conversation.reserve(items.size() + 1);
conversation.assign(items.begin(), items.end());
}
conversation.emplace_back(MessageItem::Type::Prompt, prompt.toUtf8());
qsizetype nextIndex = conversation.empty() ? 0 : conversation.back().index().value() + 1;
conversation.emplace_back(nextIndex, MessageItem::Type::Prompt, prompt.toUtf8());
return conversation;
}

Expand Down Expand Up @@ -801,7 +802,7 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const MessageItem> items) cons
json::array_t messages;
messages.reserve(useSystem + items.size());
if (useSystem) {
systemItem = std::make_unique<MessageItem>(MessageItem::Type::System, systemMessage.toUtf8());
systemItem = std::make_unique<MessageItem>(MessageItem::system_tag, systemMessage.toUtf8());
messages.emplace_back(makeMap(*systemItem));
}
for (auto &item : items)
Expand Down Expand Up @@ -855,14 +856,14 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL
// Find the prompt that represents the query. Server chats are flexible and may not have one.
auto items = getChat();
if (auto peer = m_chatModel->getPeer(items, items.end() - 1)) // peer of response
query = { *peer - items.begin(), (*peer)->content() };
query = { (*peer)->index().value(), (*peer)->content() };
}

if (query) {
auto &[promptIndex, queryStr] = *query;
const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks
m_chatModel->updateSources(promptIndex + startOffset, databaseResults);
m_chatModel->updateSources(promptIndex, databaseResults);
emit databaseResultsChanged(databaseResults);
}
}
Expand Down
49 changes: 35 additions & 14 deletions gpt4all-chat/src/chatmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,28 @@ class MessageItem
public:
enum class Type { System, Prompt, Response, ToolResponse };

MessageItem(Type type, QString content)
: m_type(type), m_content(std::move(content)) {}
struct system_tag_t { explicit system_tag_t() = default; };
static inline constexpr system_tag_t system_tag = system_tag_t{};

MessageItem(Type type, QString content, const QList<ResultInfo> &sources, const QList<PromptAttachment> &promptAttachments)
: m_type(type), m_content(std::move(content)), m_sources(sources), m_promptAttachments(promptAttachments) {}
MessageItem(qsizetype index, Type type, QString content)
: m_index(index), m_type(type), m_content(std::move(content))
{
Q_ASSERT(type != Type::System); // use system_tag constructor
}

// Construct a system message with no index, since they are never stored in the chat
MessageItem(system_tag_t, QString content)
: m_type(Type::System), m_content(std::move(content)) {}

MessageItem(qsizetype index, Type type, QString content, const QList<ResultInfo> &sources, const QList<PromptAttachment> &promptAttachments)
: m_index(index)
, m_type(type)
, m_content(std::move(content))
, m_sources(sources)
, m_promptAttachments(promptAttachments) {}

// index of the parent ChatItem (system, prompt, response) in its container
std::optional<qsizetype> index() const { return m_index; }

Type type() const { return m_type; }
const QString &content() const { return m_content; }
Expand Down Expand Up @@ -126,10 +143,11 @@ class MessageItem
}

private:
Type m_type;
QString m_content;
QList<ResultInfo> m_sources;
QList<PromptAttachment> m_promptAttachments;
std::optional<qsizetype> m_index;
Type m_type;
QString m_content;
QList<ResultInfo> m_sources;
QList<PromptAttachment> m_promptAttachments;
};
Q_DECLARE_METATYPE(MessageItem)

Expand Down Expand Up @@ -399,7 +417,7 @@ class ChatItem : public QObject
Q_UNREACHABLE();
}

MessageItem asMessageItem() const
MessageItem asMessageItem(qsizetype index) const
{
MessageItem::Type msgType;
switch (auto typ = type()) {
Expand All @@ -413,7 +431,7 @@ class ChatItem : public QObject
case Think:
throw std::invalid_argument(fmt::format("cannot convert ChatItem type {} to message item", int(typ)));
}
return { msgType, flattenedContent(), sources, promptAttachments };
return { index, msgType, flattenedContent(), sources, promptAttachments };
}

static QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources);
Expand Down Expand Up @@ -537,6 +555,7 @@ class ChatModel : public QAbstractListModel
return std::nullopt;
}

// FIXME(jared): this should really be done at the parent level, not the sub-item level
static std::optional<qsizetype> getPeerInternal(const MessageItem *arr, qsizetype size, qsizetype index)
{
qsizetype peer;
Expand Down Expand Up @@ -1114,10 +1133,12 @@ class ChatModel : public QAbstractListModel
// A flattened version of the chat item tree used by the backend and jinja
QMutexLocker locker(&m_mutex);
std::vector<MessageItem> chatItems;
for (const ChatItem *item : m_chatItems) {
chatItems.reserve(chatItems.size() + item->subItems.size() + 1);
ranges::copy(item->subItems | views::transform(&ChatItem::asMessageItem), std::back_inserter(chatItems));
chatItems.push_back(item->asMessageItem());
for (qsizetype i : views::iota(0, m_chatItems.size())) {
auto *parent = m_chatItems.at(i);
chatItems.reserve(chatItems.size() + parent->subItems.size() + 1);
ranges::copy(parent->subItems | views::transform([&](auto *s) { return s->asMessageItem(i); }),
std::back_inserter(chatItems));
chatItems.push_back(parent->asMessageItem(i));
}
return chatItems;
}
Expand Down

0 comments on commit 9131f4c

Please sign in to comment.