Skip to content

Commit

Permalink
Add support for backfilling messages
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jan 24, 2024
1 parent 4a57a0b commit 435fde2
Show file tree
Hide file tree
Showing 17 changed files with 882 additions and 40 deletions.
2 changes: 1 addition & 1 deletion ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
* [x] Message unsend
* [x] Message reactions
* [x] Message edits
* [ ] Message history
* [x] Message history
* [ ] Presence
* [x] Typing notifications
* [x] Read receipts
Expand Down
491 changes: 491 additions & 0 deletions backfill.go

Large diffs are not rendered by default.

15 changes: 14 additions & 1 deletion config/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,20 @@ type BridgeConfig struct {
Deadline time.Duration `yaml:"-"`
} `yaml:"message_handling_timeout"`

CommandPrefix string `yaml:"command_prefix"`
CommandPrefix string `yaml:"command_prefix"`

Backfill struct {
Enabled bool `yaml:"enabled"`
HistoryFetchCount int `yaml:"history_fetch_count"`
CatchupFetchCount int `yaml:"catchup_fetch_count"`
Queue struct {
PagesAtOnce int `yaml:"pages_at_once"`
MaxPages int `yaml:"max_pages"`
SleepBetweenTasks time.Duration `yaml:"sleep_between_tasks"`
DontFetchXMA bool `yaml:"dont_fetch_xma"`
} `yaml:"queue"`
} `yaml:"backfill"`

ManagementRoomText bridgeconfig.ManagementRoomTexts `yaml:"management_room_text"`

Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
Expand Down
8 changes: 8 additions & 0 deletions config/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Bool, "bridge", "double_puppet_allow_discovery")
helper.Copy(up.Map, "bridge", "login_shared_secret_map")
helper.Copy(up.Str, "bridge", "command_prefix")
helper.Copy(up.Bool, "bridge", "backfill", "enabled")
helper.Copy(up.Int, "bridge", "backfill", "history_fetch_count")
helper.Copy(up.Int, "bridge", "backfill", "catchup_fetch_count")
helper.Copy(up.Int, "bridge", "backfill", "queue", "pages_at_once")
helper.Copy(up.Int, "bridge", "backfill", "queue", "max_pages")
helper.Copy(up.Str, "bridge", "backfill", "queue", "sleep_between_tasks")
helper.Copy(up.Bool, "bridge", "backfill", "queue", "dont_fetch_xma")
helper.Copy(up.Str, "bridge", "management_room_text", "welcome")
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_connected")
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_unconnected")
Expand Down Expand Up @@ -153,6 +160,7 @@ var SpacedBlocks = [][]string{
{"bridge"},
{"bridge", "personal_filtering_spaces"},
{"bridge", "command_prefix"},
{"bridge", "backfill"},
{"bridge", "management_room_text"},
{"bridge", "encryption"},
{"bridge", "provisioning"},
Expand Down
66 changes: 65 additions & 1 deletion database/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"

"go.mau.fi/util/dbutil"
Expand Down Expand Up @@ -57,7 +59,9 @@ const (
INSERT INTO message (id, part_index, thread_id, thread_receiver, msg_sender, otid, mxid, mx_room, timestamp, edit_count)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`
deleteMessageQuery = `
insertQueryValuePlaceholder = `($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`
bulkInsertPlaceholderTemplate = `($%d, $%d, $1, $2, $%d, $%d, $%d, $3, $%d, $%d)`
deleteMessageQuery = `
DELETE FROM message
WHERE id=$1 AND thread_receiver=$2 AND part_index=$3
`
Expand All @@ -66,6 +70,12 @@ const (
`
)

func init() {
if strings.ReplaceAll(insertMessageQuery, insertQueryValuePlaceholder, "meow") == insertMessageQuery {
panic("Bulk insert query placeholder not found")
}
}

type MessageQuery struct {
*dbutil.QueryHelper[*Message]
}
Expand Down Expand Up @@ -119,6 +129,60 @@ func (mq *MessageQuery) FindEditTargetPortal(ctx context.Context, id string, rec
return
}

type bulkInserter[T any] interface {
GetDB() *dbutil.Database
BulkInsertChunk(context.Context, PortalKey, id.RoomID, []T) error
}

const BulkInsertChunkSize = 100

func doBulkInsert[T any](q bulkInserter[T], ctx context.Context, thread PortalKey, roomID id.RoomID, entries []T) error {
if len(entries) == 0 {
return nil
}
return q.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
for i := 0; i < len(entries); i += BulkInsertChunkSize {
messageChunk := entries[i:]
if len(messageChunk) > BulkInsertChunkSize {
messageChunk = messageChunk[:BulkInsertChunkSize]
}
err := q.BulkInsertChunk(ctx, thread, roomID, messageChunk)
if err != nil {
return err
}
}
return nil
})
}

func (mq *MessageQuery) BulkInsert(ctx context.Context, thread PortalKey, roomID id.RoomID, messages []*Message) error {
return doBulkInsert[*Message](mq, ctx, thread, roomID, messages)
}

func (mq *MessageQuery) BulkInsertChunk(ctx context.Context, thread PortalKey, roomID id.RoomID, messages []*Message) error {
if len(messages) == 0 {
return nil
}
placeholders := make([]string, len(messages))
values := make([]any, 3+len(messages)*7)
values[0] = thread.ThreadID
values[1] = thread.Receiver
values[2] = roomID
for i, msg := range messages {
baseIndex := 3 + i*7
placeholders[i] = fmt.Sprintf(bulkInsertPlaceholderTemplate, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
values[baseIndex] = msg.ID
values[baseIndex+1] = msg.PartIndex
values[baseIndex+2] = msg.Sender
values[baseIndex+3] = msg.OTID
values[baseIndex+4] = msg.MXID
values[baseIndex+5] = msg.Timestamp.UnixMilli()
values[baseIndex+6] = msg.EditCount
}
query := strings.ReplaceAll(insertMessageQuery, insertQueryValuePlaceholder, strings.Join(placeholders, ","))
return mq.Exec(ctx, query, values...)
}

func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
var timestamp int64
err := row.Scan(
Expand Down
18 changes: 14 additions & 4 deletions database/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const (
portalBaseSelect = `
SELECT thread_id, receiver, thread_type, mxid,
name, avatar_id, avatar_url, name_set, avatar_set,
encrypted, relay_user_id
encrypted, relay_user_id, oldest_message_id, oldest_message_ts, more_to_backfill
FROM portal
`
getPortalByMXIDQuery = portalBaseSelect + `WHERE mxid=$1`
Expand All @@ -47,14 +47,14 @@ const (
INSERT INTO portal (
thread_id, receiver, thread_type, mxid,
name, avatar_id, avatar_url, name_set, avatar_set,
encrypted, relay_user_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
encrypted, relay_user_id, oldest_message_id, oldest_message_ts, more_to_backfill
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`
updatePortalQuery = `
UPDATE portal SET
thread_type=$3, mxid=$4,
name=$5, avatar_id=$6, avatar_url=$7, name_set=$8, avatar_set=$9,
encrypted=$10, relay_user_id=$11
encrypted=$10, relay_user_id=$11, oldest_message_id=$12, oldest_message_ts=$13, more_to_backfill=$14
WHERE thread_id=$1 AND receiver=$2
`
deletePortalQuery = `DELETE FROM portal WHERE thread_id=$1 AND receiver=$2`
Expand Down Expand Up @@ -82,6 +82,10 @@ type Portal struct {
AvatarSet bool
Encrypted bool
RelayUserID id.UserID

OldestMessageID string
OldestMessageTS int64
MoreToBackfill bool
}

func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
Expand Down Expand Up @@ -138,6 +142,9 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
&p.AvatarSet,
&p.Encrypted,
&p.RelayUserID,
&p.OldestMessageID,
&p.OldestMessageTS,
&p.MoreToBackfill,
)
if err != nil {
return nil, err
Expand All @@ -159,6 +166,9 @@ func (p *Portal) sqlVariables() []any {
p.AvatarSet,
p.Encrypted,
p.RelayUserID,
p.OldestMessageID,
p.OldestMessageTS,
p.MoreToBackfill,
}
}

Expand Down
42 changes: 41 additions & 1 deletion database/reaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package database

import (
"context"
"fmt"
"strings"

"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
Expand All @@ -35,7 +37,14 @@ const (
INSERT INTO reaction (message_id, thread_id, thread_receiver, reaction_sender, emoji, mxid, mx_room)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
updateReactionQuery = `
bulkInsertReactionQuery = `
INSERT INTO reaction (message_id, thread_id, thread_receiver, reaction_sender, emoji, mxid, mx_room)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (message_id, thread_receiver, reaction_sender) DO UPDATE SET mxid=excluded.mxid, emoji=excluded.emoji
`
bulkInsertReactionQueryValuePlaceholder = `($1, $2, $3, $4, $5, $6, $7)`
bulkInsertReactionPlaceholderTemplate = `($%d, $1, $2, $%d, $%d, $%d, $3)`
updateReactionQuery = `
UPDATE reaction
SET mxid=$1, emoji=$2
WHERE message_id=$3 AND thread_receiver=$4 AND reaction_sender=$5
Expand All @@ -45,6 +54,12 @@ const (
`
)

func init() {
if strings.ReplaceAll(bulkInsertReactionQuery, bulkInsertReactionQueryValuePlaceholder, "meow") == bulkInsertReactionQuery {
panic("Bulk insert query placeholder not found")
}
}

type ReactionQuery struct {
*dbutil.QueryHelper[*Reaction]
}
Expand Down Expand Up @@ -75,6 +90,31 @@ func (rq *ReactionQuery) GetByID(ctx context.Context, msgID string, threadReceiv
return rq.QueryOne(ctx, getReactionByIDQuery, msgID, threadReceiver, reactionSender)
}

func (rq *ReactionQuery) BulkInsert(ctx context.Context, thread PortalKey, roomID id.RoomID, reactions []*Reaction) error {
return doBulkInsert[*Reaction](rq, ctx, thread, roomID, reactions)
}

func (rq *ReactionQuery) BulkInsertChunk(ctx context.Context, thread PortalKey, roomID id.RoomID, reactions []*Reaction) error {
if len(reactions) == 0 {
return nil
}
placeholders := make([]string, len(reactions))
values := make([]any, 3+len(reactions)*4)
values[0] = thread.ThreadID
values[1] = thread.Receiver
values[2] = roomID
for i, react := range reactions {
baseIndex := 3 + i*4
placeholders[i] = fmt.Sprintf(bulkInsertReactionPlaceholderTemplate, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4)
values[baseIndex] = react.MessageID
values[baseIndex+1] = react.Sender
values[baseIndex+2] = react.Emoji
values[baseIndex+3] = react.MXID
}
query := strings.ReplaceAll(bulkInsertReactionQuery, bulkInsertReactionQueryValuePlaceholder, strings.Join(placeholders, ","))
return rq.Exec(ctx, query, values...)
}

func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
return dbutil.ValueOrErr(r, row.Scan(
&r.MessageID, &r.ThreadID, &r.ThreadReceiver, &r.Sender, &r.Emoji, &r.MXID, &r.RoomID,
Expand Down
19 changes: 14 additions & 5 deletions database/upgrades/00-latest.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-- v0 -> v2: Latest revision
-- v0 -> v3: Latest revision

CREATE TABLE portal (
thread_id BIGINT NOT NULL,
Expand All @@ -15,6 +15,10 @@ CREATE TABLE portal (
encrypted BOOLEAN NOT NULL DEFAULT false,
relay_user_id TEXT NOT NULL,

oldest_message_id TEXT NOT NULL,
oldest_message_ts BIGINT NOT NULL,
more_to_backfill BOOLEAN NOT NULL,

PRIMARY KEY (thread_id, receiver),
CONSTRAINT portal_mxid_unique UNIQUE(mxid)
);
Expand Down Expand Up @@ -48,10 +52,15 @@ CREATE TABLE "user" (
);

CREATE TABLE user_portal (
user_mxid TEXT,
portal_thread_id BIGINT,
portal_receiver BIGINT,
in_space BOOLEAN NOT NULL DEFAULT false,
user_mxid TEXT NOT NULL,
portal_thread_id BIGINT NOT NULL,
portal_receiver BIGINT NOT NULL,

in_space BOOLEAN NOT NULL DEFAULT false,

backfill_priority INTEGER NOT NULL DEFAULT 0,
backfill_max_pages INTEGER NOT NULL DEFAULT 0,
backfill_dispatched_at BIGINT NOT NULL DEFAULT 0,

PRIMARY KEY (user_mxid, portal_thread_id, portal_receiver),
CONSTRAINT user_portal_user_fkey FOREIGN KEY (user_mxid)
Expand Down
20 changes: 20 additions & 0 deletions database/upgrades/03-backfill-queue.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- v3: Add backfill queue
ALTER TABLE portal ADD COLUMN oldest_message_id TEXT NOT NULL DEFAULT '';
ALTER TABLE portal ADD COLUMN oldest_message_ts BIGINT NOT NULL DEFAULT 0;
ALTER TABLE portal ADD COLUMN more_to_backfill BOOL NOT NULL DEFAULT true;
UPDATE portal SET (oldest_message_id, oldest_message_ts) = (
SELECT id, timestamp
FROM message
WHERE thread_id = portal.thread_id
AND thread_receiver = portal.receiver
ORDER BY timestamp ASC
LIMIT 1
);
-- only: postgres for next 3 lines
ALTER TABLE portal ALTER COLUMN oldest_message_id DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN oldest_message_ts DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN more_to_backfill DROP DEFAULT;

ALTER TABLE user_portal ADD COLUMN backfill_priority INTEGER NOT NULL DEFAULT 0;
ALTER TABLE user_portal ADD COLUMN backfill_max_pages INTEGER NOT NULL DEFAULT 0;
ALTER TABLE user_portal ADD COLUMN backfill_dispatched_at BIGINT NOT NULL DEFAULT 0;
42 changes: 42 additions & 0 deletions database/userportal.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"database/sql"
"errors"
"time"

"github.com/rs/zerolog"
)
Expand All @@ -30,6 +31,17 @@ const (
INSERT INTO user_portal (user_mxid, portal_thread_id, portal_receiver, in_space) VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_thread_id, portal_receiver) DO UPDATE SET in_space=true
`
putBackfillTask = `
INSERT INTO user_portal (user_mxid, portal_thread_id, portal_receiver, backfill_priority, backfill_max_pages, backfill_dispatched_at) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, portal_thread_id, portal_receiver) DO UPDATE
SET backfill_priority=excluded.backfill_priority, backfill_max_pages=excluded.backfill_max_pages, backfill_dispatched_at=excluded.backfill_dispatched_at
`
getNextBackfillTask = `
SELECT portal_thread_id, portal_receiver, backfill_priority, backfill_max_pages, backfill_dispatched_at
FROM user_portal
WHERE backfill_max_pages=-1 OR backfill_max_pages>0
ORDER BY backfill_priority DESC, backfill_dispatched_at LIMIT 1
`
)

func (u *User) IsInSpace(ctx context.Context, portal PortalKey) bool {
Expand Down Expand Up @@ -65,6 +77,36 @@ func (u *User) MarkInSpace(ctx context.Context, portal PortalKey) {
}
}

type BackfillTask struct {
Key PortalKey
Priority int
MaxPages int
DispatchedAt time.Time
}

func (u *User) PutBackfillTask(ctx context.Context, task BackfillTask) {
err := u.qh.Exec(ctx, putBackfillTask, u.MXID, task.Key.ThreadID, task.Key.Receiver, task.Priority, task.MaxPages, task.DispatchedAt.UnixMilli())
if err != nil {
zerolog.Ctx(ctx).Err(err).
Str("user_id", u.MXID.String()).
Any("portal_key", task.Key).
Msg("Failed to save backfill task")
}
}

func (u *User) GetNextBackfillTask(ctx context.Context) (*BackfillTask, error) {
var task BackfillTask
var dispatchedAt int64
err := u.qh.GetDB().QueryRow(ctx, getNextBackfillTask).Scan(&task.Key.ThreadID, &task.Key.Receiver, &task.Priority, &task.MaxPages, &dispatchedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
task.DispatchedAt = time.UnixMilli(dispatchedAt)
return &task, nil
}

func (u *User) RemoveInSpaceCache(key PortalKey) {
u.inSpaceCacheLock.Lock()
defer u.inSpaceCacheLock.Unlock()
Expand Down
Loading

0 comments on commit 435fde2

Please sign in to comment.