diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bde6cc721..da3c2e9f9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,9 +2,9 @@ name: Go on: push: - branches: [master, v9, v9.7, v9.8] + branches: [master, v9, v9.7, '*'] pull_request: - branches: [master, v9, v9.7, v9.8] + branches: [master, v9, v9.7, '*'] permissions: contents: read diff --git a/command.go b/command.go index 3253af6cc..2484210af 100644 --- a/command.go +++ b/command.go @@ -14,9 +14,94 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/routing" "github.com/redis/go-redis/v9/internal/util" ) +// CmdTyper interface for getting command type +type CmdTyper interface { + GetCmdType() CmdType +} + +// CmdTypeGetter interface for getting command type without circular imports +type CmdTypeGetter interface { + GetCmdType() CmdType +} + +type CmdType uint8 + +const ( + CmdTypeGeneric CmdType = iota + CmdTypeString + CmdTypeInt + CmdTypeBool + CmdTypeFloat + CmdTypeStringSlice + CmdTypeIntSlice + CmdTypeFloatSlice + CmdTypeBoolSlice + CmdTypeMapStringString + CmdTypeMapStringInt + CmdTypeMapStringInterface + CmdTypeMapStringInterfaceSlice + CmdTypeSlice + CmdTypeStatus + CmdTypeDuration + CmdTypeTime + CmdTypeKeyValueSlice + CmdTypeStringStructMap + CmdTypeXMessageSlice + CmdTypeXStreamSlice + CmdTypeXPending + CmdTypeXPendingExt + CmdTypeXAutoClaim + CmdTypeXAutoClaimJustID + CmdTypeXInfoConsumers + CmdTypeXInfoGroups + CmdTypeXInfoStream + CmdTypeXInfoStreamFull + CmdTypeZSlice + CmdTypeZWithKey + CmdTypeScan + CmdTypeClusterSlots + CmdTypeGeoLocation + CmdTypeGeoSearchLocation + CmdTypeGeoPos + CmdTypeCommandsInfo + CmdTypeSlowLog + CmdTypeMapStringStringSlice + CmdTypeMapMapStringInterface + CmdTypeKeyValues + CmdTypeZSliceWithKey + CmdTypeFunctionList + CmdTypeFunctionStats + CmdTypeLCS + CmdTypeKeyFlags + CmdTypeClusterLinks + CmdTypeClusterShards + CmdTypeRankWithScore + CmdTypeClientInfo + CmdTypeACLLog + CmdTypeInfo + CmdTypeMonitor + CmdTypeJSON + CmdTypeJSONSlice + CmdTypeIntPointerSlice + CmdTypeScanDump + CmdTypeBFInfo + CmdTypeCFInfo + CmdTypeCMSInfo + CmdTypeTopKInfo + CmdTypeTDigestInfo + CmdTypeFTSynDump + CmdTypeAggregate + CmdTypeFTInfo + CmdTypeFTSpellCheck + CmdTypeFTSearch + CmdTypeTSTimestampValue + CmdTypeTSTimestampValueSlice +) + type Cmder interface { // command name. // e.g. "set k v ex 10" -> "set", "cluster info" -> "cluster". @@ -34,6 +119,9 @@ type Cmder interface { // e.g. "set k v ex 10" -> "set k v ex 10: OK", "get k" -> "get k: v". String() string + // Clone creates a copy of the command. 
+ Clone() Cmder + stringArg(int) string firstKeyPos() int8 SetFirstKeyPos(int8) @@ -43,6 +131,9 @@ type Cmder interface { readRawReply(rd *proto.Reader) error SetErr(error) Err() error + + // GetCmdType returns the command type for fast value extraction + GetCmdType() CmdType } func setCmdsErr(cmds []Cmder, e error) { @@ -128,6 +219,7 @@ type baseCmd struct { keyPos int8 rawVal interface{} _readTimeout *time.Duration + cmdType CmdType } var _ Cmder = (*Cmd)(nil) @@ -204,6 +296,32 @@ func (cmd *baseCmd) readRawReply(rd *proto.Reader) (err error) { return err } +func (cmd *baseCmd) GetCmdType() CmdType { + return cmd.cmdType +} + +func (cmd *baseCmd) cloneBaseCmd() baseCmd { + var readTimeout *time.Duration + if cmd._readTimeout != nil { + timeout := *cmd._readTimeout + readTimeout = &timeout + } + + // Create a copy of args slice + args := make([]interface{}, len(cmd.args)) + copy(args, cmd.args) + + return baseCmd{ + ctx: cmd.ctx, + args: args, + err: cmd.err, + keyPos: cmd.keyPos, + rawVal: cmd.rawVal, + _readTimeout: readTimeout, + cmdType: cmd.cmdType, + } +} + //------------------------------------------------------------------------------ type Cmd struct { @@ -215,8 +333,9 @@ type Cmd struct { func NewCmd(ctx context.Context, args ...interface{}) *Cmd { return &Cmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeGeneric, }, } } @@ -489,6 +608,13 @@ func (cmd *Cmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *Cmd) Clone() Cmder { + return &Cmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type SliceCmd struct { @@ -502,8 +628,9 @@ var _ Cmder = (*SliceCmd)(nil) func NewSliceCmd(ctx context.Context, args ...interface{}) *SliceCmd { return &SliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeSlice, }, } } @@ -549,6 +676,18 @@ func (cmd *SliceCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *SliceCmd) Clone() Cmder { + var val []interface{} + if cmd.val != nil { + val = make([]interface{}, len(cmd.val)) + copy(val, cmd.val) + } + return &SliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StatusCmd struct { @@ -562,8 +701,9 @@ var _ Cmder = (*StatusCmd)(nil) func NewStatusCmd(ctx context.Context, args ...interface{}) *StatusCmd { return &StatusCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStatus, }, } } @@ -593,6 +733,13 @@ func (cmd *StatusCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *StatusCmd) Clone() Cmder { + return &StatusCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type IntCmd struct { @@ -606,8 +753,9 @@ var _ Cmder = (*IntCmd)(nil) func NewIntCmd(ctx context.Context, args ...interface{}) *IntCmd { return &IntCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeInt, }, } } @@ -637,6 +785,13 @@ func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *IntCmd) Clone() Cmder { + return &IntCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type IntSliceCmd struct { @@ -650,8 +805,9 @@ var _ Cmder = (*IntSliceCmd)(nil) func NewIntSliceCmd(ctx 
context.Context, args ...interface{}) *IntSliceCmd { return &IntSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeIntSlice, }, } } @@ -686,6 +842,18 @@ func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *IntSliceCmd) Clone() Cmder { + var val []int64 + if cmd.val != nil { + val = make([]int64, len(cmd.val)) + copy(val, cmd.val) + } + return &IntSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type DurationCmd struct { @@ -700,8 +868,9 @@ var _ Cmder = (*DurationCmd)(nil) func NewDurationCmd(ctx context.Context, precision time.Duration, args ...interface{}) *DurationCmd { return &DurationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeDuration, }, precision: precision, } @@ -739,6 +908,14 @@ func (cmd *DurationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *DurationCmd) Clone() Cmder { + return &DurationCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + precision: cmd.precision, + } +} + //------------------------------------------------------------------------------ type TimeCmd struct { @@ -752,8 +929,9 @@ var _ Cmder = (*TimeCmd)(nil) func NewTimeCmd(ctx context.Context, args ...interface{}) *TimeCmd { return &TimeCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTime, }, } } @@ -790,6 +968,13 @@ func (cmd *TimeCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *TimeCmd) Clone() Cmder { + return &TimeCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type BoolCmd struct { @@ -803,8 +988,9 @@ var _ Cmder = (*BoolCmd)(nil) func NewBoolCmd(ctx context.Context, args ...interface{}) *BoolCmd { return &BoolCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBool, }, } } @@ -837,6 +1023,13 @@ func (cmd *BoolCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *BoolCmd) Clone() Cmder { + return &BoolCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type StringCmd struct { @@ -850,8 +1043,9 @@ var _ Cmder = (*StringCmd)(nil) func NewStringCmd(ctx context.Context, args ...interface{}) *StringCmd { return &StringCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeString, }, } } @@ -876,7 +1070,7 @@ func (cmd *StringCmd) Bool() (bool, error) { if cmd.err != nil { return false, cmd.err } - return strconv.ParseBool(cmd.val) + return strconv.ParseBool(cmd.Val()) } func (cmd *StringCmd) Int() (int, error) { @@ -941,6 +1135,13 @@ func (cmd *StringCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *StringCmd) Clone() Cmder { + return &StringCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type FloatCmd struct { @@ -954,8 +1155,9 @@ var _ Cmder = (*FloatCmd)(nil) func NewFloatCmd(ctx context.Context, args ...interface{}) *FloatCmd { return &FloatCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFloat, }, } } @@ -981,6 +1183,13 @@ func (cmd *FloatCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *FloatCmd) Clone() Cmder { + return &FloatCmd{ + baseCmd: 
cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type FloatSliceCmd struct { @@ -994,8 +1203,9 @@ var _ Cmder = (*FloatSliceCmd)(nil) func NewFloatSliceCmd(ctx context.Context, args ...interface{}) *FloatSliceCmd { return &FloatSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFloatSlice, }, } } @@ -1036,6 +1246,18 @@ func (cmd *FloatSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *FloatSliceCmd) Clone() Cmder { + var val []float64 + if cmd.val != nil { + val = make([]float64, len(cmd.val)) + copy(val, cmd.val) + } + return &FloatSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StringSliceCmd struct { @@ -1049,8 +1271,9 @@ var _ Cmder = (*StringSliceCmd)(nil) func NewStringSliceCmd(ctx context.Context, args ...interface{}) *StringSliceCmd { return &StringSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStringSlice, }, } } @@ -1094,6 +1317,18 @@ func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *StringSliceCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &StringSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type KeyValue struct { @@ -1112,8 +1347,9 @@ var _ Cmder = (*KeyValueSliceCmd)(nil) func NewKeyValueSliceCmd(ctx context.Context, args ...interface{}) *KeyValueSliceCmd { return &KeyValueSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyValueSlice, }, } } @@ -1188,6 +1424,18 @@ func (cmd *KeyValueSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl return nil } +func (cmd *KeyValueSliceCmd) Clone() Cmder { + var val []KeyValue + if cmd.val != nil { + val = make([]KeyValue, len(cmd.val)) + copy(val, cmd.val) + } + return &KeyValueSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type BoolSliceCmd struct { @@ -1201,8 +1449,9 @@ var _ Cmder = (*BoolSliceCmd)(nil) func NewBoolSliceCmd(ctx context.Context, args ...interface{}) *BoolSliceCmd { return &BoolSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBoolSlice, }, } } @@ -1237,6 +1486,18 @@ func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *BoolSliceCmd) Clone() Cmder { + var val []bool + if cmd.val != nil { + val = make([]bool, len(cmd.val)) + copy(val, cmd.val) + } + return &BoolSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type MapStringStringCmd struct { @@ -1250,8 +1511,9 @@ var _ Cmder = (*MapStringStringCmd)(nil) func NewMapStringStringCmd(ctx context.Context, args ...interface{}) *MapStringStringCmd { return &MapStringStringCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringString, }, } } @@ -1316,6 +1578,20 @@ func (cmd *MapStringStringCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringStringCmd) Clone() Cmder { + var val map[string]string + if cmd.val != nil { + val = make(map[string]string, len(cmd.val)) + for k, v := range 
cmd.val { + val[k] = v + } + } + return &MapStringStringCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type MapStringIntCmd struct { @@ -1329,8 +1605,9 @@ var _ Cmder = (*MapStringIntCmd)(nil) func NewMapStringIntCmd(ctx context.Context, args ...interface{}) *MapStringIntCmd { return &MapStringIntCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInt, }, } } @@ -1373,6 +1650,20 @@ func (cmd *MapStringIntCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringIntCmd) Clone() Cmder { + var val map[string]int64 + if cmd.val != nil { + val = make(map[string]int64, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringIntCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------------------------------------------ type MapStringSliceInterfaceCmd struct { baseCmd @@ -1382,8 +1673,9 @@ type MapStringSliceInterfaceCmd struct { func NewMapStringSliceInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringSliceInterfaceCmd { return &MapStringSliceInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterfaceSlice, }, } } @@ -1469,6 +1761,24 @@ func (cmd *MapStringSliceInterfaceCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *MapStringSliceInterfaceCmd) Clone() Cmder { + var val map[string][]interface{} + if cmd.val != nil { + val = make(map[string][]interface{}, len(cmd.val)) + for k, v := range cmd.val { + if v != nil { + newSlice := make([]interface{}, len(v)) + copy(newSlice, v) + val[k] = newSlice + } + } + } + return &MapStringSliceInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StringStructMapCmd struct { @@ -1482,8 +1792,9 @@ var _ Cmder = (*StringStructMapCmd)(nil) func NewStringStructMapCmd(ctx context.Context, args ...interface{}) *StringStructMapCmd { return &StringStructMapCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStringStructMap, }, } } @@ -1521,6 +1832,20 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *StringStructMapCmd) Clone() Cmder { + var val map[string]struct{} + if cmd.val != nil { + val = make(map[string]struct{}, len(cmd.val)) + for k := range cmd.val { + val[k] = struct{}{} + } + } + return &StringStructMapCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XMessage struct { @@ -1539,8 +1864,9 @@ var _ Cmder = (*XMessageSliceCmd)(nil) func NewXMessageSliceCmd(ctx context.Context, args ...interface{}) *XMessageSliceCmd { return &XMessageSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXMessageSlice, }, } } @@ -1566,6 +1892,28 @@ func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *XMessageSliceCmd) Clone() Cmder { + var val []XMessage + if cmd.val != nil { + val = make([]XMessage, len(cmd.val)) + for i, msg := range cmd.val { + val[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Values[k] = v + } + } + } + } + return &XMessageSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), 
+ val: val, + } +} + func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { n, err := rd.ReadArrayLen() if err != nil { @@ -1645,8 +1993,9 @@ var _ Cmder = (*XStreamSliceCmd)(nil) func NewXStreamSliceCmd(ctx context.Context, args ...interface{}) *XStreamSliceCmd { return &XStreamSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXStreamSlice, }, } } @@ -1699,6 +2048,36 @@ func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XStreamSliceCmd) Clone() Cmder { + var val []XStream + if cmd.val != nil { + val = make([]XStream, len(cmd.val)) + for i, stream := range cmd.val { + val[i] = XStream{ + Stream: stream.Stream, + } + if stream.Messages != nil { + val[i].Messages = make([]XMessage, len(stream.Messages)) + for j, msg := range stream.Messages { + val[i].Messages[j] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Messages[j].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Messages[j].Values[k] = v + } + } + } + } + } + } + return &XStreamSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XPending struct { @@ -1718,8 +2097,9 @@ var _ Cmder = (*XPendingCmd)(nil) func NewXPendingCmd(ctx context.Context, args ...interface{}) *XPendingCmd { return &XPendingCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXPending, }, } } @@ -1782,6 +2162,27 @@ func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XPendingCmd) Clone() Cmder { + var val *XPending + if cmd.val != nil { + val = &XPending{ + Count: cmd.val.Count, + Lower: cmd.val.Lower, + Higher: cmd.val.Higher, + } + if cmd.val.Consumers != nil { + val.Consumers = make(map[string]int64, len(cmd.val.Consumers)) + for k, v := range cmd.val.Consumers { + val.Consumers[k] = v + } + } + } + return &XPendingCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XPendingExt struct { @@ -1801,8 +2202,9 @@ var _ Cmder = (*XPendingExtCmd)(nil) func NewXPendingExtCmd(ctx context.Context, args ...interface{}) *XPendingExtCmd { return &XPendingExtCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXPendingExt, }, } } @@ -1857,6 +2259,18 @@ func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XPendingExtCmd) Clone() Cmder { + var val []XPendingExt + if cmd.val != nil { + val = make([]XPendingExt, len(cmd.val)) + copy(val, cmd.val) + } + return &XPendingExtCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XAutoClaimCmd struct { @@ -1871,8 +2285,9 @@ var _ Cmder = (*XAutoClaimCmd)(nil) func NewXAutoClaimCmd(ctx context.Context, args ...interface{}) *XAutoClaimCmd { return &XAutoClaimCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXAutoClaim, }, } } @@ -1927,6 +2342,29 @@ func (cmd *XAutoClaimCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XAutoClaimCmd) Clone() Cmder { + var val []XMessage + if cmd.val != nil { + val = make([]XMessage, len(cmd.val)) + for i, msg := range cmd.val { + val[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range 
msg.Values { + val[i].Values[k] = v + } + } + } + } + return &XAutoClaimCmd{ + baseCmd: cmd.cloneBaseCmd(), + start: cmd.start, + val: val, + } +} + //------------------------------------------------------------------------------ type XAutoClaimJustIDCmd struct { @@ -1941,8 +2379,9 @@ var _ Cmder = (*XAutoClaimJustIDCmd)(nil) func NewXAutoClaimJustIDCmd(ctx context.Context, args ...interface{}) *XAutoClaimJustIDCmd { return &XAutoClaimJustIDCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXAutoClaimJustID, }, } } @@ -2005,6 +2444,19 @@ func (cmd *XAutoClaimJustIDCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XAutoClaimJustIDCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &XAutoClaimJustIDCmd{ + baseCmd: cmd.cloneBaseCmd(), + start: cmd.start, + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoConsumersCmd struct { @@ -2024,8 +2476,9 @@ var _ Cmder = (*XInfoConsumersCmd)(nil) func NewXInfoConsumersCmd(ctx context.Context, stream string, group string) *XInfoConsumersCmd { return &XInfoConsumersCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "consumers", stream, group}, + ctx: ctx, + args: []interface{}{"xinfo", "consumers", stream, group}, + cmdType: CmdTypeXInfoConsumers, }, } } @@ -2091,6 +2544,18 @@ func (cmd *XInfoConsumersCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoConsumersCmd) Clone() Cmder { + var val []XInfoConsumer + if cmd.val != nil { + val = make([]XInfoConsumer, len(cmd.val)) + copy(val, cmd.val) + } + return &XInfoConsumersCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoGroupsCmd struct { @@ -2112,8 +2577,9 @@ var _ Cmder = (*XInfoGroupsCmd)(nil) func NewXInfoGroupsCmd(ctx context.Context, stream string) *XInfoGroupsCmd { return &XInfoGroupsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "groups", stream}, + ctx: ctx, + args: []interface{}{"xinfo", "groups", stream}, + cmdType: CmdTypeXInfoGroups, }, } } @@ -2199,6 +2665,18 @@ func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoGroupsCmd) Clone() Cmder { + var val []XInfoGroup + if cmd.val != nil { + val = make([]XInfoGroup, len(cmd.val)) + copy(val, cmd.val) + } + return &XInfoGroupsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoStreamCmd struct { @@ -2224,8 +2702,9 @@ var _ Cmder = (*XInfoStreamCmd)(nil) func NewXInfoStreamCmd(ctx context.Context, stream string) *XInfoStreamCmd { return &XInfoStreamCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "stream", stream}, + ctx: ctx, + args: []interface{}{"xinfo", "stream", stream}, + cmdType: CmdTypeXInfoStream, }, } } @@ -2316,6 +2795,45 @@ func (cmd *XInfoStreamCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoStreamCmd) Clone() Cmder { + var val *XInfoStream + if cmd.val != nil { + val = &XInfoStream{ + Length: cmd.val.Length, + RadixTreeKeys: cmd.val.RadixTreeKeys, + RadixTreeNodes: cmd.val.RadixTreeNodes, + Groups: cmd.val.Groups, + LastGeneratedID: cmd.val.LastGeneratedID, + MaxDeletedEntryID: cmd.val.MaxDeletedEntryID, + EntriesAdded: cmd.val.EntriesAdded, + RecordedFirstEntryID: cmd.val.RecordedFirstEntryID, + } + // Clone XMessage 
fields + val.FirstEntry = XMessage{ + ID: cmd.val.FirstEntry.ID, + } + if cmd.val.FirstEntry.Values != nil { + val.FirstEntry.Values = make(map[string]interface{}, len(cmd.val.FirstEntry.Values)) + for k, v := range cmd.val.FirstEntry.Values { + val.FirstEntry.Values[k] = v + } + } + val.LastEntry = XMessage{ + ID: cmd.val.LastEntry.ID, + } + if cmd.val.LastEntry.Values != nil { + val.LastEntry.Values = make(map[string]interface{}, len(cmd.val.LastEntry.Values)) + for k, v := range cmd.val.LastEntry.Values { + val.LastEntry.Values[k] = v + } + } + } + return &XInfoStreamCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoStreamFullCmd struct { @@ -2371,8 +2889,9 @@ var _ Cmder = (*XInfoStreamFullCmd)(nil) func NewXInfoStreamFullCmd(ctx context.Context, args ...interface{}) *XInfoStreamFullCmd { return &XInfoStreamFullCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXInfoStreamFull, }, } } @@ -2657,6 +3176,45 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { return consumers, nil } +func (cmd *XInfoStreamFullCmd) Clone() Cmder { + var val *XInfoStreamFull + if cmd.val != nil { + val = &XInfoStreamFull{ + Length: cmd.val.Length, + RadixTreeKeys: cmd.val.RadixTreeKeys, + RadixTreeNodes: cmd.val.RadixTreeNodes, + LastGeneratedID: cmd.val.LastGeneratedID, + MaxDeletedEntryID: cmd.val.MaxDeletedEntryID, + EntriesAdded: cmd.val.EntriesAdded, + RecordedFirstEntryID: cmd.val.RecordedFirstEntryID, + } + // Clone Entries + if cmd.val.Entries != nil { + val.Entries = make([]XMessage, len(cmd.val.Entries)) + for i, msg := range cmd.val.Entries { + val.Entries[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val.Entries[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val.Entries[i].Values[k] = v + } + } + } + } + // Clone Groups - simplified copy for now due to complexity + if cmd.val.Groups != nil { + val.Groups = make([]XInfoStreamGroup, len(cmd.val.Groups)) + copy(val.Groups, cmd.val.Groups) + } + } + return &XInfoStreamFullCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ZSliceCmd struct { @@ -2670,8 +3228,9 @@ var _ Cmder = (*ZSliceCmd)(nil) func NewZSliceCmd(ctx context.Context, args ...interface{}) *ZSliceCmd { return &ZSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZSlice, }, } } @@ -2735,6 +3294,18 @@ func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl return nil } +func (cmd *ZSliceCmd) Clone() Cmder { + var val []Z + if cmd.val != nil { + val = make([]Z, len(cmd.val)) + copy(val, cmd.val) + } + return &ZSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ZWithKeyCmd struct { @@ -2748,8 +3319,9 @@ var _ Cmder = (*ZWithKeyCmd)(nil) func NewZWithKeyCmd(ctx context.Context, args ...interface{}) *ZWithKeyCmd { return &ZWithKeyCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZWithKey, }, } } @@ -2789,6 +3361,23 @@ func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ZWithKeyCmd) Clone() Cmder { + var val *ZWithKey + if cmd.val != nil { + val = &ZWithKey{ + Z: Z{ + Score: cmd.val.Score, + Member: cmd.val.Member, + }, + Key: cmd.val.Key, 
+ } + } + return &ZWithKeyCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ScanCmd struct { @@ -2805,8 +3394,9 @@ var _ Cmder = (*ScanCmd)(nil) func NewScanCmd(ctx context.Context, process cmdable, args ...interface{}) *ScanCmd { return &ScanCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeScan, }, process: process, } @@ -2854,6 +3444,20 @@ func (cmd *ScanCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ScanCmd) Clone() Cmder { + var page []string + if cmd.page != nil { + page = make([]string, len(cmd.page)) + copy(page, cmd.page) + } + return &ScanCmd{ + baseCmd: cmd.cloneBaseCmd(), + page: page, + cursor: cmd.cursor, + process: cmd.process, + } +} + // Iterator creates a new ScanIterator. func (cmd *ScanCmd) Iterator() *ScanIterator { return &ScanIterator{ @@ -2886,8 +3490,9 @@ var _ Cmder = (*ClusterSlotsCmd)(nil) func NewClusterSlotsCmd(ctx context.Context, args ...interface{}) *ClusterSlotsCmd { return &ClusterSlotsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterSlots, }, } } @@ -3000,6 +3605,38 @@ func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterSlotsCmd) Clone() Cmder { + var val []ClusterSlot + if cmd.val != nil { + val = make([]ClusterSlot, len(cmd.val)) + for i, slot := range cmd.val { + val[i] = ClusterSlot{ + Start: slot.Start, + End: slot.End, + } + if slot.Nodes != nil { + val[i].Nodes = make([]ClusterNode, len(slot.Nodes)) + for j, node := range slot.Nodes { + val[i].Nodes[j] = ClusterNode{ + ID: node.ID, + Addr: node.Addr, + } + if node.NetworkingMetadata != nil { + val[i].Nodes[j].NetworkingMetadata = make(map[string]string, len(node.NetworkingMetadata)) + for k, v := range node.NetworkingMetadata { + val[i].Nodes[j].NetworkingMetadata[k] = v + } + } + } + } + } + } + return &ClusterSlotsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // GeoLocation is used with GeoAdd to add geospatial location. @@ -3039,8 +3676,9 @@ var _ Cmder = (*GeoLocationCmd)(nil) func NewGeoLocationCmd(ctx context.Context, q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd { return &GeoLocationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: geoLocationArgs(q, args...), + ctx: ctx, + args: geoLocationArgs(q, args...), + cmdType: CmdTypeGeoLocation, }, q: q, } @@ -3148,6 +3786,34 @@ func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoLocationCmd) Clone() Cmder { + var q *GeoRadiusQuery + if cmd.q != nil { + q = &GeoRadiusQuery{ + Radius: cmd.q.Radius, + Unit: cmd.q.Unit, + WithCoord: cmd.q.WithCoord, + WithDist: cmd.q.WithDist, + WithGeoHash: cmd.q.WithGeoHash, + Count: cmd.q.Count, + Sort: cmd.q.Sort, + Store: cmd.q.Store, + StoreDist: cmd.q.StoreDist, + withLen: cmd.q.withLen, + } + } + var locations []GeoLocation + if cmd.locations != nil { + locations = make([]GeoLocation, len(cmd.locations)) + copy(locations, cmd.locations) + } + return &GeoLocationCmd{ + baseCmd: cmd.cloneBaseCmd(), + q: q, + locations: locations, + } +} + //------------------------------------------------------------------------------ // GeoSearchQuery is used for GEOSearch/GEOSearchStore command query. 
@@ -3255,8 +3921,9 @@ func NewGeoSearchLocationCmd( ) *GeoSearchLocationCmd { return &GeoSearchLocationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: geoSearchLocationArgs(opt, args), + cmdType: CmdTypeGeoSearchLocation, }, opt: opt, } @@ -3329,6 +3996,40 @@ func (cmd *GeoSearchLocationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoSearchLocationCmd) Clone() Cmder { + var opt *GeoSearchLocationQuery + if cmd.opt != nil { + opt = &GeoSearchLocationQuery{ + GeoSearchQuery: GeoSearchQuery{ + Member: cmd.opt.Member, + Longitude: cmd.opt.Longitude, + Latitude: cmd.opt.Latitude, + Radius: cmd.opt.Radius, + RadiusUnit: cmd.opt.RadiusUnit, + BoxWidth: cmd.opt.BoxWidth, + BoxHeight: cmd.opt.BoxHeight, + BoxUnit: cmd.opt.BoxUnit, + Sort: cmd.opt.Sort, + Count: cmd.opt.Count, + CountAny: cmd.opt.CountAny, + }, + WithCoord: cmd.opt.WithCoord, + WithDist: cmd.opt.WithDist, + WithHash: cmd.opt.WithHash, + } + } + var val []GeoLocation + if cmd.val != nil { + val = make([]GeoLocation, len(cmd.val)) + copy(val, cmd.val) + } + return &GeoSearchLocationCmd{ + baseCmd: cmd.cloneBaseCmd(), + opt: opt, + val: val, + } +} + //------------------------------------------------------------------------------ type GeoPos struct { @@ -3346,8 +4047,9 @@ var _ Cmder = (*GeoPosCmd)(nil) func NewGeoPosCmd(ctx context.Context, args ...interface{}) *GeoPosCmd { return &GeoPosCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeGeoPos, }, } } @@ -3403,6 +4105,25 @@ func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoPosCmd) Clone() Cmder { + var val []*GeoPos + if cmd.val != nil { + val = make([]*GeoPos, len(cmd.val)) + for i, pos := range cmd.val { + if pos != nil { + val[i] = &GeoPos{ + Longitude: pos.Longitude, + Latitude: pos.Latitude, + } + } + } + } + return &GeoPosCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type CommandInfo struct { @@ -3414,6 +4135,7 @@ type CommandInfo struct { LastKeyPos int8 StepCount int8 ReadOnly bool + Tips *routing.CommandPolicy } type CommandsInfoCmd struct { @@ -3427,8 +4149,9 @@ var _ Cmder = (*CommandsInfoCmd)(nil) func NewCommandsInfoCmd(ctx context.Context, args ...interface{}) *CommandsInfoCmd { return &CommandsInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCommandsInfo, }, } } @@ -3452,7 +4175,7 @@ func (cmd *CommandsInfoCmd) String() string { func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { const numArgRedis5 = 6 const numArgRedis6 = 7 - const numArgRedis7 = 10 + const numArgRedis7 = 10 // Also matches redis 8 n, err := rd.ReadArrayLen() if err != nil { @@ -3540,9 +4263,34 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { } if nn >= numArgRedis7 { - if err := rd.DiscardNext(); err != nil { + // The 8th argument is an array of tips. 
+ tipsLen, err := rd.ReadArrayLen() + if err != nil { return err } + + rawTips := make(map[string]string, tipsLen) + for f := 0; f < tipsLen; f++ { + tip, err := rd.ReadString() + if err != nil { + return err + } + + // Handle tips that don't have a colon (like "nondeterministic_output") + if !strings.Contains(tip, ":") { + rawTips[tip] = "" + continue + } + + // Handle normal key:value tips + k, v, ok := strings.Cut(tip, ":") + if !ok { + return fmt.Errorf("redis: unexpected tip %q in COMMAND reply", tip) + } + rawTips[k] = v + } + cmdInfo.Tips = parseCommandPolicies(rawTips) + if err := rd.DiscardNext(); err != nil { return err } @@ -3557,6 +4305,39 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *CommandsInfoCmd) Clone() Cmder { + var val map[string]*CommandInfo + if cmd.val != nil { + val = make(map[string]*CommandInfo, len(cmd.val)) + for k, v := range cmd.val { + if v != nil { + newInfo := &CommandInfo{ + Name: v.Name, + Arity: v.Arity, + FirstKeyPos: v.FirstKeyPos, + LastKeyPos: v.LastKeyPos, + StepCount: v.StepCount, + ReadOnly: v.ReadOnly, + Tips: v.Tips, // CommandPolicy can be shared as it's immutable + } + if v.Flags != nil { + newInfo.Flags = make([]string, len(v.Flags)) + copy(newInfo.Flags, v.Flags) + } + if v.ACLFlags != nil { + newInfo.ACLFlags = make([]string, len(v.ACLFlags)) + copy(newInfo.ACLFlags, v.ACLFlags) + } + val[k] = newInfo + } + } + } + return &CommandsInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type cmdsInfoCache struct { @@ -3593,6 +4374,36 @@ func (c *cmdsInfoCache) Get(ctx context.Context) (map[string]*CommandInfo, error return c.cmds, err } +// ------------------------------------------------------------------------------ +const requestPolicy = "request_policy" +const responsePolicy = "response_policy" + +func parseCommandPolicies(commandInfoTips map[string]string) *routing.CommandPolicy { + req := routing.ReqDefault + resp := routing.RespAllSucceeded + + if commandInfoTips != nil { + if v, ok := commandInfoTips[requestPolicy]; ok { + if p, err := routing.ParseRequestPolicy(v); err == nil { + req = p + } + } + if v, ok := commandInfoTips[responsePolicy]; ok { + if p, err := routing.ParseResponsePolicy(v); err == nil { + resp = p + } + } + } + tips := make(map[string]string, len(commandInfoTips)) + for k, v := range commandInfoTips { + if k == requestPolicy || k == responsePolicy { + continue + } + tips[k] = v + } + return &routing.CommandPolicy{Request: req, Response: resp, Tips: tips} +} + //------------------------------------------------------------------------------ type SlowLog struct { @@ -3617,8 +4428,9 @@ var _ Cmder = (*SlowLogCmd)(nil) func NewSlowLogCmd(ctx context.Context, args ...interface{}) *SlowLogCmd { return &SlowLogCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeSlowLog, }, } } @@ -3703,6 +4515,30 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *SlowLogCmd) Clone() Cmder { + var val []SlowLog + if cmd.val != nil { + val = make([]SlowLog, len(cmd.val)) + for i, log := range cmd.val { + val[i] = SlowLog{ + ID: log.ID, + Time: log.Time, + Duration: log.Duration, + ClientAddr: log.ClientAddr, + ClientName: log.ClientName, + } + if log.Args != nil { + val[i].Args = make([]string, len(log.Args)) + copy(val[i].Args, log.Args) + } + } + } + return &SlowLogCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + 
//----------------------------------------------------------------------- type MapStringInterfaceCmd struct { @@ -3716,8 +4552,9 @@ var _ Cmder = (*MapStringInterfaceCmd)(nil) func NewMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceCmd { return &MapStringInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterface, }, } } @@ -3767,6 +4604,20 @@ func (cmd *MapStringInterfaceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringInterfaceCmd) Clone() Cmder { + var val map[string]interface{} + if cmd.val != nil { + val = make(map[string]interface{}, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringStringSliceCmd struct { @@ -3780,8 +4631,9 @@ var _ Cmder = (*MapStringStringSliceCmd)(nil) func NewMapStringStringSliceCmd(ctx context.Context, args ...interface{}) *MapStringStringSliceCmd { return &MapStringStringSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringStringSlice, }, } } @@ -3831,6 +4683,25 @@ func (cmd *MapStringStringSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringStringSliceCmd) Clone() Cmder { + var val []map[string]string + if cmd.val != nil { + val = make([]map[string]string, len(cmd.val)) + for i, m := range cmd.val { + if m != nil { + val[i] = make(map[string]string, len(m)) + for k, v := range m { + val[i][k] = v + } + } + } + } + return &MapStringStringSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ----------------------------------------------------------------------- // MapMapStringInterfaceCmd represents a command that returns a map of strings to interface{}. 
@@ -3842,8 +4713,9 @@ type MapMapStringInterfaceCmd struct { func NewMapMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapMapStringInterfaceCmd { return &MapMapStringInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapMapStringInterface, }, } } @@ -3909,6 +4781,20 @@ func (cmd *MapMapStringInterfaceCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *MapMapStringInterfaceCmd) Clone() Cmder { + var val map[string]interface{} + if cmd.val != nil { + val = make(map[string]interface{}, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapMapStringInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringInterfaceSliceCmd struct { @@ -3922,8 +4808,9 @@ var _ Cmder = (*MapStringInterfaceSliceCmd)(nil) func NewMapStringInterfaceSliceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceSliceCmd { return &MapStringInterfaceSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterfaceSlice, }, } } @@ -3974,6 +4861,25 @@ func (cmd *MapStringInterfaceSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringInterfaceSliceCmd) Clone() Cmder { + var val []map[string]interface{} + if cmd.val != nil { + val = make([]map[string]interface{}, len(cmd.val)) + for i, m := range cmd.val { + if m != nil { + val[i] = make(map[string]interface{}, len(m)) + for k, v := range m { + val[i][k] = v + } + } + } + } + return &MapStringInterfaceSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type KeyValuesCmd struct { @@ -3988,8 +4894,9 @@ var _ Cmder = (*KeyValuesCmd)(nil) func NewKeyValuesCmd(ctx context.Context, args ...interface{}) *KeyValuesCmd { return &KeyValuesCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyValues, }, } } @@ -4036,6 +4943,19 @@ func (cmd *KeyValuesCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *KeyValuesCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &KeyValuesCmd{ + baseCmd: cmd.cloneBaseCmd(), + key: cmd.key, + val: val, + } +} + //------------------------------------------------------------------------------ type ZSliceWithKeyCmd struct { @@ -4050,8 +4970,9 @@ var _ Cmder = (*ZSliceWithKeyCmd)(nil) func NewZSliceWithKeyCmd(ctx context.Context, args ...interface{}) *ZSliceWithKeyCmd { return &ZSliceWithKeyCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZSliceWithKey, }, } } @@ -4119,6 +5040,19 @@ func (cmd *ZSliceWithKeyCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ZSliceWithKeyCmd) Clone() Cmder { + var val []Z + if cmd.val != nil { + val = make([]Z, len(cmd.val)) + copy(val, cmd.val) + } + return &ZSliceWithKeyCmd{ + baseCmd: cmd.cloneBaseCmd(), + key: cmd.key, + val: val, + } +} + type Function struct { Name string Description string @@ -4143,8 +5077,9 @@ var _ Cmder = (*FunctionListCmd)(nil) func NewFunctionListCmd(ctx context.Context, args ...interface{}) *FunctionListCmd { return &FunctionListCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFunctionList, }, } } @@ -4271,6 +5206,37 @@ func (cmd *FunctionListCmd) readFunctions(rd 
*proto.Reader) ([]Function, error) return functions, nil } +func (cmd *FunctionListCmd) Clone() Cmder { + var val []Library + if cmd.val != nil { + val = make([]Library, len(cmd.val)) + for i, lib := range cmd.val { + val[i] = Library{ + Name: lib.Name, + Engine: lib.Engine, + Code: lib.Code, + } + if lib.Functions != nil { + val[i].Functions = make([]Function, len(lib.Functions)) + for j, fn := range lib.Functions { + val[i].Functions[j] = Function{ + Name: fn.Name, + Description: fn.Description, + } + if fn.Flags != nil { + val[i].Functions[j].Flags = make([]string, len(fn.Flags)) + copy(val[i].Functions[j].Flags, fn.Flags) + } + } + } + } + } + return &FunctionListCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // FunctionStats contains information about the scripts currently executing on the server, and the available engines // - Engines: // Statistics about the engine like number of functions and number of libraries @@ -4324,8 +5290,9 @@ var _ Cmder = (*FunctionStatsCmd)(nil) func NewFunctionStatsCmd(ctx context.Context, args ...interface{}) *FunctionStatsCmd { return &FunctionStatsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFunctionStats, }, } } @@ -4496,6 +5463,34 @@ func (cmd *FunctionStatsCmd) readRunningScripts(rd *proto.Reader) ([]RunningScri return runningScripts, len(runningScripts) > 0, nil } +func (cmd *FunctionStatsCmd) Clone() Cmder { + val := FunctionStats{ + isRunning: cmd.val.isRunning, + rs: cmd.val.rs, // RunningScript is a simple struct, can be copied directly + } + if cmd.val.Engines != nil { + val.Engines = make([]Engine, len(cmd.val.Engines)) + copy(val.Engines, cmd.val.Engines) + } + if cmd.val.allrs != nil { + val.allrs = make([]RunningScript, len(cmd.val.allrs)) + for i, rs := range cmd.val.allrs { + val.allrs[i] = RunningScript{ + Name: rs.Name, + Duration: rs.Duration, + } + if rs.Command != nil { + val.allrs[i].Command = make([]string, len(rs.Command)) + copy(val.allrs[i].Command, rs.Command) + } + } + } + return &FunctionStatsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // LCSQuery is a parameter used for the LCS command @@ -4559,8 +5554,9 @@ func NewLCSCmd(ctx context.Context, q *LCSQuery) *LCSCmd { } } cmd.baseCmd = baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeLCS, } return cmd @@ -4672,6 +5668,25 @@ func (cmd *LCSCmd) readPosition(rd *proto.Reader) (pos LCSPosition, err error) { return pos, nil } +func (cmd *LCSCmd) Clone() Cmder { + var val *LCSMatch + if cmd.val != nil { + val = &LCSMatch{ + MatchString: cmd.val.MatchString, + Len: cmd.val.Len, + } + if cmd.val.Matches != nil { + val.Matches = make([]LCSMatchedPosition, len(cmd.val.Matches)) + copy(val.Matches, cmd.val.Matches) + } + } + return &LCSCmd{ + baseCmd: cmd.cloneBaseCmd(), + readType: cmd.readType, + val: val, + } +} + // ------------------------------------------------------------------------ type KeyFlags struct { @@ -4690,8 +5705,9 @@ var _ Cmder = (*KeyFlagsCmd)(nil) func NewKeyFlagsCmd(ctx context.Context, args ...interface{}) *KeyFlagsCmd { return &KeyFlagsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyFlags, }, } } @@ -4750,6 +5766,26 @@ func (cmd *KeyFlagsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *KeyFlagsCmd) Clone() Cmder { + var val []KeyFlags + if cmd.val != nil { + val = make([]KeyFlags, len(cmd.val)) + for i, kf := range 
cmd.val { + val[i] = KeyFlags{ + Key: kf.Key, + } + if kf.Flags != nil { + val[i].Flags = make([]string, len(kf.Flags)) + copy(val[i].Flags, kf.Flags) + } + } + } + return &KeyFlagsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // --------------------------------------------------------------------------------------------------- type ClusterLink struct { @@ -4772,8 +5808,9 @@ var _ Cmder = (*ClusterLinksCmd)(nil) func NewClusterLinksCmd(ctx context.Context, args ...interface{}) *ClusterLinksCmd { return &ClusterLinksCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterLinks, }, } } @@ -4839,6 +5876,18 @@ func (cmd *ClusterLinksCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterLinksCmd) Clone() Cmder { + var val []ClusterLink + if cmd.val != nil { + val = make([]ClusterLink, len(cmd.val)) + copy(val, cmd.val) + } + return &ClusterLinksCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------------------------------------------------------------------------------ type SlotRange struct { @@ -4874,8 +5923,9 @@ var _ Cmder = (*ClusterShardsCmd)(nil) func NewClusterShardsCmd(ctx context.Context, args ...interface{}) *ClusterShardsCmd { return &ClusterShardsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterShards, }, } } @@ -4989,6 +6039,28 @@ func (cmd *ClusterShardsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterShardsCmd) Clone() Cmder { + var val []ClusterShard + if cmd.val != nil { + val = make([]ClusterShard, len(cmd.val)) + for i, shard := range cmd.val { + val[i] = ClusterShard{} + if shard.Slots != nil { + val[i].Slots = make([]SlotRange, len(shard.Slots)) + copy(val[i].Slots, shard.Slots) + } + if shard.Nodes != nil { + val[i].Nodes = make([]Node, len(shard.Nodes)) + copy(val[i].Nodes, shard.Nodes) + } + } + } + return &ClusterShardsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ----------------------------------------- type RankScore struct { @@ -5007,8 +6079,9 @@ var _ Cmder = (*RankWithScoreCmd)(nil) func NewRankWithScoreCmd(ctx context.Context, args ...interface{}) *RankWithScoreCmd { return &RankWithScoreCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeRankWithScore, }, } } @@ -5049,6 +6122,13 @@ func (cmd *RankWithScoreCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *RankWithScoreCmd) Clone() Cmder { + return &RankWithScoreCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // RankScore is a simple struct, can be copied directly + } +} + // -------------------------------------------------------------------------------------------------- // ClientFlags is redis-server client flags, copy from redis/src/server.h (redis 7.0) @@ -5155,8 +6235,9 @@ var _ Cmder = (*ClientInfoCmd)(nil) func NewClientInfoCmd(ctx context.Context, args ...interface{}) *ClientInfoCmd { return &ClientInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClientInfo, }, } } @@ -5327,6 +6408,50 @@ func parseClientInfo(txt string) (info *ClientInfo, err error) { return info, nil } +func (cmd *ClientInfoCmd) Clone() Cmder { + var val *ClientInfo + if cmd.val != nil { + val = &ClientInfo{ + ID: cmd.val.ID, + Addr: cmd.val.Addr, + LAddr: cmd.val.LAddr, + FD: cmd.val.FD, + Name: cmd.val.Name, + Age: cmd.val.Age, + Idle: cmd.val.Idle, + Flags: cmd.val.Flags, + DB: cmd.val.DB, + Sub: cmd.val.Sub, + PSub: 
cmd.val.PSub, + SSub: cmd.val.SSub, + Multi: cmd.val.Multi, + Watch: cmd.val.Watch, + QueryBuf: cmd.val.QueryBuf, + QueryBufFree: cmd.val.QueryBufFree, + ArgvMem: cmd.val.ArgvMem, + MultiMem: cmd.val.MultiMem, + BufferSize: cmd.val.BufferSize, + BufferPeak: cmd.val.BufferPeak, + OutputBufferLength: cmd.val.OutputBufferLength, + OutputListLength: cmd.val.OutputListLength, + OutputMemory: cmd.val.OutputMemory, + TotalMemory: cmd.val.TotalMemory, + IoThread: cmd.val.IoThread, + Events: cmd.val.Events, + LastCmd: cmd.val.LastCmd, + User: cmd.val.User, + Redir: cmd.val.Redir, + Resp: cmd.val.Resp, + LibName: cmd.val.LibName, + LibVer: cmd.val.LibVer, + } + } + return &ClientInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------- type ACLLogEntry struct { @@ -5353,8 +6478,9 @@ var _ Cmder = (*ACLLogCmd)(nil) func NewACLLogCmd(ctx context.Context, args ...interface{}) *ACLLogCmd { return &ACLLogCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeACLLog, }, } } @@ -5436,6 +6562,69 @@ func (cmd *ACLLogCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ACLLogCmd) Clone() Cmder { + var val []*ACLLogEntry + if cmd.val != nil { + val = make([]*ACLLogEntry, len(cmd.val)) + for i, entry := range cmd.val { + if entry != nil { + val[i] = &ACLLogEntry{ + Count: entry.Count, + Reason: entry.Reason, + Context: entry.Context, + Object: entry.Object, + Username: entry.Username, + AgeSeconds: entry.AgeSeconds, + EntryID: entry.EntryID, + TimestampCreated: entry.TimestampCreated, + TimestampLastUpdated: entry.TimestampLastUpdated, + } + // Clone ClientInfo if present + if entry.ClientInfo != nil { + val[i].ClientInfo = &ClientInfo{ + ID: entry.ClientInfo.ID, + Addr: entry.ClientInfo.Addr, + LAddr: entry.ClientInfo.LAddr, + FD: entry.ClientInfo.FD, + Name: entry.ClientInfo.Name, + Age: entry.ClientInfo.Age, + Idle: entry.ClientInfo.Idle, + Flags: entry.ClientInfo.Flags, + DB: entry.ClientInfo.DB, + Sub: entry.ClientInfo.Sub, + PSub: entry.ClientInfo.PSub, + SSub: entry.ClientInfo.SSub, + Multi: entry.ClientInfo.Multi, + Watch: entry.ClientInfo.Watch, + QueryBuf: entry.ClientInfo.QueryBuf, + QueryBufFree: entry.ClientInfo.QueryBufFree, + ArgvMem: entry.ClientInfo.ArgvMem, + MultiMem: entry.ClientInfo.MultiMem, + BufferSize: entry.ClientInfo.BufferSize, + BufferPeak: entry.ClientInfo.BufferPeak, + OutputBufferLength: entry.ClientInfo.OutputBufferLength, + OutputListLength: entry.ClientInfo.OutputListLength, + OutputMemory: entry.ClientInfo.OutputMemory, + TotalMemory: entry.ClientInfo.TotalMemory, + IoThread: entry.ClientInfo.IoThread, + Events: entry.ClientInfo.Events, + LastCmd: entry.ClientInfo.LastCmd, + User: entry.ClientInfo.User, + Redir: entry.ClientInfo.Redir, + Resp: entry.ClientInfo.Resp, + LibName: entry.ClientInfo.LibName, + LibVer: entry.ClientInfo.LibVer, + } + } + } + } + } + return &ACLLogCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // LibraryInfo holds the library info. 
type LibraryInfo struct { LibName *string @@ -5464,8 +6653,9 @@ var _ Cmder = (*InfoCmd)(nil) func NewInfoCmd(ctx context.Context, args ...interface{}) *InfoCmd { return &InfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeInfo, }, } } @@ -5531,6 +6721,25 @@ func (cmd *InfoCmd) Item(section, key string) string { } } +func (cmd *InfoCmd) Clone() Cmder { + var val map[string]map[string]string + if cmd.val != nil { + val = make(map[string]map[string]string, len(cmd.val)) + for section, sectionMap := range cmd.val { + if sectionMap != nil { + val[section] = make(map[string]string, len(sectionMap)) + for k, v := range sectionMap { + val[section][k] = v + } + } + } + } + return &InfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + type MonitorStatus int const ( @@ -5549,8 +6758,9 @@ type MonitorCmd struct { func newMonitorCmd(ctx context.Context, ch chan string) *MonitorCmd { return &MonitorCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"monitor"}, + ctx: ctx, + args: []interface{}{"monitor"}, + cmdType: CmdTypeMonitor, }, ch: ch, status: monitorStatusIdle, @@ -5615,3 +6825,292 @@ func (cmd *MonitorCmd) Stop() { defer cmd.mu.Unlock() cmd.status = monitorStatusStop } + +// ExtractCommandValue extracts the value from a command result using the fast enum-based approach +func ExtractCommandValue(cmd interface{}) interface{} { + // First try to get the command type using the interface + if cmdTypeGetter, ok := cmd.(CmdTypeGetter); ok { + cmdType := cmdTypeGetter.GetCmdType() + + // Use fast type-based extraction + switch cmdType { + case CmdTypeString: + if stringCmd, ok := cmd.(interface{ Val() string }); ok { + return stringCmd.Val() + } + case CmdTypeInt: + if intCmd, ok := cmd.(interface{ Val() int64 }); ok { + return intCmd.Val() + } + case CmdTypeBool: + if boolCmd, ok := cmd.(interface{ Val() bool }); ok { + return boolCmd.Val() + } + case CmdTypeFloat: + if floatCmd, ok := cmd.(interface{ Val() float64 }); ok { + return floatCmd.Val() + } + case CmdTypeStatus: + if statusCmd, ok := cmd.(interface{ Val() string }); ok { + return statusCmd.Val() + } + case CmdTypeDuration: + if durationCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return durationCmd.Val() + } + case CmdTypeTime: + if timeCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return timeCmd.Val() + } + case CmdTypeStringSlice: + if stringSliceCmd, ok := cmd.(interface{ Val() []string }); ok { + return stringSliceCmd.Val() + } + case CmdTypeIntSlice: + if intSliceCmd, ok := cmd.(interface{ Val() []int64 }); ok { + return intSliceCmd.Val() + } + case CmdTypeBoolSlice: + if boolSliceCmd, ok := cmd.(interface{ Val() []bool }); ok { + return boolSliceCmd.Val() + } + case CmdTypeFloatSlice: + if floatSliceCmd, ok := cmd.(interface{ Val() []float64 }); ok { + return floatSliceCmd.Val() + } + case CmdTypeMapStringString: + if mapCmd, ok := cmd.(interface{ Val() map[string]string }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInt: + if mapCmd, ok := cmd.(interface{ Val() map[string]int64 }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInterfaceSlice: + if mapCmd, ok := cmd.(interface { + Val() map[string][]interface{} + }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringInterface: + if mapCmd, ok := cmd.(interface{ Val() map[string]interface{} }); ok { + return mapCmd.Val() + } + case CmdTypeMapStringStringSlice: + if mapCmd, ok := cmd.(interface{ Val() map[string][]string }); ok { + return mapCmd.Val() + } + case 
CmdTypeMapMapStringInterface: + if mapCmd, ok := cmd.(interface { + Val() map[string][]interface{} + }); ok { + return mapCmd.Val() + } + case CmdTypeStringStructMap: + if mapCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return mapCmd.Val() + } + case CmdTypeXMessageSlice: + if xMsgCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xMsgCmd.Val() + } + case CmdTypeXStreamSlice: + if xStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xStreamCmd.Val() + } + case CmdTypeXPending: + if xPendingCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xPendingCmd.Val() + } + case CmdTypeXPendingExt: + if xPendingExtCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xPendingExtCmd.Val() + } + case CmdTypeXAutoClaim: + if xAutoClaimCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xAutoClaimCmd.Val() + } + case CmdTypeXAutoClaimJustID: + if xAutoClaimJustIDCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xAutoClaimJustIDCmd.Val() + } + case CmdTypeXInfoConsumers: + if xInfoConsumersCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoConsumersCmd.Val() + } + case CmdTypeXInfoGroups: + if xInfoGroupsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoGroupsCmd.Val() + } + case CmdTypeXInfoStream: + if xInfoStreamCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoStreamCmd.Val() + } + case CmdTypeXInfoStreamFull: + if xInfoStreamFullCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return xInfoStreamFullCmd.Val() + } + case CmdTypeZSlice: + if zSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zSliceCmd.Val() + } + case CmdTypeZWithKey: + if zWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zWithKeyCmd.Val() + } + case CmdTypeScan: + if scanCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return scanCmd.Val() + } + case CmdTypeClusterSlots: + if clusterSlotsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterSlotsCmd.Val() + } + case CmdTypeGeoSearchLocation: + if geoSearchLocationCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return geoSearchLocationCmd.Val() + } + case CmdTypeGeoPos: + if geoPosCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return geoPosCmd.Val() + } + case CmdTypeCommandsInfo: + if commandsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return commandsInfoCmd.Val() + } + case CmdTypeSlowLog: + if slowLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return slowLogCmd.Val() + } + + case CmdTypeKeyValues: + if keyValuesCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return keyValuesCmd.Val() + } + case CmdTypeZSliceWithKey: + if zSliceWithKeyCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return zSliceWithKeyCmd.Val() + } + case CmdTypeFunctionList: + if functionListCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return functionListCmd.Val() + } + case CmdTypeFunctionStats: + if functionStatsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return functionStatsCmd.Val() + } + case CmdTypeLCS: + if lcsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return lcsCmd.Val() + } + case CmdTypeKeyFlags: + if keyFlagsCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return keyFlagsCmd.Val() + } + case CmdTypeClusterLinks: + if clusterLinksCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clusterLinksCmd.Val() + } + case CmdTypeClusterShards: + if clusterShardsCmd, ok := cmd.(interface{ Val() 
interface{} }); ok { + return clusterShardsCmd.Val() + } + case CmdTypeRankWithScore: + if rankWithScoreCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return rankWithScoreCmd.Val() + } + case CmdTypeClientInfo: + if clientInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return clientInfoCmd.Val() + } + case CmdTypeACLLog: + if aclLogCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return aclLogCmd.Val() + } + case CmdTypeInfo: + if infoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return infoCmd.Val() + } + case CmdTypeMonitor: + if monitorCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return monitorCmd.Val() + } + case CmdTypeJSON: + if jsonCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return jsonCmd.Val() + } + case CmdTypeJSONSlice: + if jsonSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return jsonSliceCmd.Val() + } + case CmdTypeIntPointerSlice: + if intPointerSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return intPointerSliceCmd.Val() + } + case CmdTypeScanDump: + if scanDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return scanDumpCmd.Val() + } + case CmdTypeBFInfo: + if bfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return bfInfoCmd.Val() + } + case CmdTypeCFInfo: + if cfInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return cfInfoCmd.Val() + } + case CmdTypeCMSInfo: + if cmsInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return cmsInfoCmd.Val() + } + case CmdTypeTopKInfo: + if topKInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return topKInfoCmd.Val() + } + case CmdTypeTDigestInfo: + if tDigestInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tDigestInfoCmd.Val() + } + case CmdTypeFTSearch: + if ftSearchCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSearchCmd.Val() + } + case CmdTypeFTInfo: + if ftInfoCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftInfoCmd.Val() + } + case CmdTypeFTSpellCheck: + if ftSpellCheckCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSpellCheckCmd.Val() + } + case CmdTypeFTSynDump: + if ftSynDumpCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return ftSynDumpCmd.Val() + } + case CmdTypeAggregate: + if aggregateCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return aggregateCmd.Val() + } + case CmdTypeTSTimestampValue: + if tsTimestampValueCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tsTimestampValueCmd.Val() + } + case CmdTypeTSTimestampValueSlice: + if tsTimestampValueSliceCmd, ok := cmd.(interface{ Val() interface{} }); ok { + return tsTimestampValueSliceCmd.Val() + } + default: + // For unknown command types, return nil + return nil + } + } + + // If we can't get the command type, return nil + return nil +} + +func (cmd *MonitorCmd) Clone() Cmder { + // MonitorCmd cannot be safely cloned due to channels and goroutines + // Return a new MonitorCmd with the same channel + return newMonitorCmd(cmd.ctx, cmd.ch) +} diff --git a/commands_test.go b/commands_test.go index 8b2aa37d4..c516a277a 100644 --- a/commands_test.go +++ b/commands_test.go @@ -13,6 +13,7 @@ import ( "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/routing" ) type TimeValue struct { @@ -657,6 +658,22 @@ var _ = Describe("Commands", func() { Expect(cmd.StepCount).To(Equal(int8(0))) }) + It("should Command Tips", Label("NonRedisEnterprise"), func() { + 
SkipAfterRedisVersion(7.9, "Redis 8 changed the COMMAND reply format") + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + cmd := cmds["touch"] + Expect(cmd.Name).To(Equal("touch")) + Expect(cmd.Tips.Request).To(Equal(routing.ReqMultiShard)) + Expect(cmd.Tips.Response).To(Equal(routing.RespAggSum)) + + cmd = cmds["flushall"] + Expect(cmd.Name).To(Equal("flushall")) + Expect(cmd.Tips.Request).To(Equal(routing.ReqAllShards)) + Expect(cmd.Tips.Response).To(Equal(routing.RespAllSucceeded)) + }) + It("should return all command names", func() { cmdList := client.CommandList(ctx, nil) Expect(cmdList.Err()).NotTo(HaveOccurred()) diff --git a/internal/routing/aggregator.go b/internal/routing/aggregator.go new file mode 100644 index 000000000..962e59264 --- /dev/null +++ b/internal/routing/aggregator.go @@ -0,0 +1,589 @@ +package routing + +import ( + "fmt" + "math" + "sync" +) + +// ResponseAggregator defines the interface for aggregating responses from multiple shards. +type ResponseAggregator interface { + // Add processes a single shard response. + Add(result interface{}, err error) error + + // AddWithKey processes a single shard response for a specific key (used by keyed aggregators). + AddWithKey(key string, result interface{}, err error) error + + // Finish returns the final aggregated result and any error. + Finish() (interface{}, error) +} + +// NewResponseAggregator creates an aggregator based on the response policy. +func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggregator { + switch policy { + case RespDefaultKeyless: + return &DefaultKeylessAggregator{} + case RespDefaultHashSlot: + return &DefaultKeyedAggregator{} + case RespAllSucceeded: + return &AllSucceededAggregator{} + case RespOneSucceeded: + return &OneSucceededAggregator{} + case RespAggSum: + return &AggSumAggregator{} + case RespAggMin: + return &AggMinAggregator{} + case RespAggMax: + return &AggMaxAggregator{} + case RespAggLogicalAnd: + return &AggLogicalAndAggregator{} + case RespAggLogicalOr: + return &AggLogicalOrAggregator{} + case RespSpecial: + return NewSpecialAggregator(cmdName) + default: + return &AllSucceededAggregator{} + } +} + +func NewDefaultAggregator(isKeyed bool) ResponseAggregator { + if isKeyed { + return &DefaultKeyedAggregator{ + results: make(map[string]interface{}), + } + } + return &DefaultKeylessAggregator{} +} + +// AllSucceededAggregator returns one non-error reply if every shard succeeded, +// propagates the first error otherwise. +type AllSucceededAggregator struct { + mu sync.Mutex + result interface{} + firstErr error + hasResult bool +} + +func (a *AllSucceededAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil && !a.hasResult { + a.result = result + a.hasResult = true + } + return nil +} + +func (a *AllSucceededAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AllSucceededAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.result, nil +} + +// OneSucceededAggregator returns the first non-error reply, +// if all shards errored, returns any one of those errors. 
+type OneSucceededAggregator struct { + mu sync.Mutex + result interface{} + firstErr error + hasResult bool +} + +func (a *OneSucceededAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil && !a.hasResult { + a.result = result + a.hasResult = true + } + return nil +} + +func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *OneSucceededAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.hasResult { + return a.result, nil + } + return nil, a.firstErr +} + +// AggSumAggregator sums numeric replies from all shards. +type AggSumAggregator struct { + mu sync.Mutex + sum int64 + hasResult bool + firstErr error +} + +func (a *AggSumAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toInt64(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.sum += val + a.hasResult = true + } + } + return nil +} + +func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggSumAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.sum, nil +} + +// AggMinAggregator returns the minimum numeric value from all shards. +type AggMinAggregator struct { + mu sync.Mutex + min int64 + hasResult bool + firstErr error +} + +func (a *AggMinAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toInt64(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult || val < a.min { + a.min = val + a.hasResult = true + } + } + } + return nil +} + +func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggMinAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for min operation") + } + return a.min, nil +} + +// AggMaxAggregator returns the maximum numeric value from all shards. 
+type AggMaxAggregator struct { + mu sync.Mutex + max int64 + hasResult bool + firstErr error +} + +func (a *AggMaxAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toInt64(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult || val > a.max { + a.max = val + a.hasResult = true + } + } + } + return nil +} + +func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggMaxAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for max operation") + } + return a.max, nil +} + +// AggLogicalAndAggregator performs logical AND on boolean values. +type AggLogicalAndAggregator struct { + mu sync.Mutex + result bool + hasResult bool + firstErr error +} + +func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toBool(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult { + a.result = val + a.hasResult = true + } else { + a.result = a.result && val + } + } + } + return nil +} + +func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggLogicalAndAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for logical AND operation") + } + return a.result, nil +} + +// AggLogicalOrAggregator performs logical OR on boolean values. 
+type AggLogicalOrAggregator struct { + mu sync.Mutex + result bool + hasResult bool + firstErr error +} + +func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + val, err := toBool(result) + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + if !a.hasResult { + a.result = val + a.hasResult = true + } else { + a.result = a.result || val + } + } + } + return nil +} + +func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggLogicalOrAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + if !a.hasResult { + return nil, fmt.Errorf("redis: no valid results to aggregate for logical OR operation") + } + return a.result, nil +} + +func toInt64(val interface{}) (int64, error) { + if val == nil { + return 0, nil + } + switch v := val.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case int32: + return int64(v), nil + case float64: + if v != math.Trunc(v) { + return 0, fmt.Errorf("cannot convert float %f to int64", v) + } + return int64(v), nil + default: + return 0, fmt.Errorf("cannot convert %T to int64", val) + } +} + +func toBool(val interface{}) (bool, error) { + if val == nil { + return false, nil + } + switch v := val.(type) { + case bool: + return v, nil + case int64: + return v != 0, nil + case int: + return v != 0, nil + default: + return false, fmt.Errorf("cannot convert %T to bool", val) + } +} + +// DefaultKeylessAggregator collects all results in an array, order doesn't matter. +type DefaultKeylessAggregator struct { + mu sync.Mutex + results []interface{} + firstErr error +} + +func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.results = append(a.results, result) + } + return nil +} + +func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *DefaultKeylessAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.results, nil +} + +// DefaultKeyedAggregator reassembles replies in the exact key order of the original request. 
+type DefaultKeyedAggregator struct { + mu sync.Mutex + results map[string]interface{} + keyOrder []string + firstErr error +} + +func NewDefaultKeyedAggregator(keyOrder []string) *DefaultKeyedAggregator { + return &DefaultKeyedAggregator{ + results: make(map[string]interface{}), + keyOrder: keyOrder, + } +} + +func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + // For non-keyed Add, just collect the result without ordering + if err == nil { + a.results["__default__"] = result + } + return nil +} + +func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.results[key] = result + } + return nil +} + +func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { + a.mu.Lock() + defer a.mu.Unlock() + a.keyOrder = keyOrder +} + +func (a *DefaultKeyedAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + + // If no explicit key order is set, return results in any order + if len(a.keyOrder) == 0 { + orderedResults := make([]interface{}, 0, len(a.results)) + for _, result := range a.results { + orderedResults = append(orderedResults, result) + } + return orderedResults, nil + } + + // Return results in the exact key order + orderedResults := make([]interface{}, len(a.keyOrder)) + for i, key := range a.keyOrder { + if result, exists := a.results[key]; exists { + orderedResults[i] = result + } + } + return orderedResults, nil +} + +// SpecialAggregator provides a registry for command-specific aggregation logic. +type SpecialAggregator struct { + mu sync.Mutex + aggregatorFunc func([]interface{}, []error) (interface{}, error) + results []interface{} + errors []error +} + +func (a *SpecialAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + a.results = append(a.results, result) + a.errors = append(a.errors, err) + return nil +} + +func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *SpecialAggregator) Finish() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.aggregatorFunc != nil { + return a.aggregatorFunc(a.results, a.errors) + } + // Default behavior: return first non-error result or first error + for i, err := range a.errors { + if err == nil { + return a.results[i], nil + } + } + if len(a.errors) > 0 { + return nil, a.errors[0] + } + return nil, nil +} + +// SetAggregatorFunc allows setting custom aggregation logic for special commands. +func (a *SpecialAggregator) SetAggregatorFunc(fn func([]interface{}, []error) (interface{}, error)) { + a.mu.Lock() + defer a.mu.Unlock() + a.aggregatorFunc = fn +} + +// SpecialAggregatorRegistry holds custom aggregation functions for specific commands. +var SpecialAggregatorRegistry = make(map[string]func([]interface{}, []error) (interface{}, error)) + +// RegisterSpecialAggregator registers a custom aggregation function for a command. +func RegisterSpecialAggregator(cmdName string, fn func([]interface{}, []error) (interface{}, error)) { + SpecialAggregatorRegistry[cmdName] = fn +} + +// NewSpecialAggregator creates a special aggregator with command-specific logic if available. 
+func NewSpecialAggregator(cmdName string) *SpecialAggregator { + agg := &SpecialAggregator{} + if fn, exists := SpecialAggregatorRegistry[cmdName]; exists { + agg.SetAggregatorFunc(fn) + } + return agg +} diff --git a/internal/routing/policy.go b/internal/routing/policy.go new file mode 100644 index 000000000..a76dfaf19 --- /dev/null +++ b/internal/routing/policy.go @@ -0,0 +1,135 @@ +package routing + +import ( + "fmt" + "strings" +) + +type RequestPolicy uint8 + +const ( + ReqDefault RequestPolicy = iota + + ReqAllNodes + + ReqAllShards + + ReqMultiShard + + ReqSpecial +) + +func (p RequestPolicy) String() string { + switch p { + case ReqDefault: + return "default" + case ReqAllNodes: + return "all_nodes" + case ReqAllShards: + return "all_shards" + case ReqMultiShard: + return "multi_shard" + case ReqSpecial: + return "special" + default: + return fmt.Sprintf("unknown_request_policy(%d)", p) + } +} + +func ParseRequestPolicy(raw string) (RequestPolicy, error) { + switch strings.ToLower(raw) { + case "", "default", "none": + return ReqDefault, nil + case "all_nodes": + return ReqAllNodes, nil + case "all_shards": + return ReqAllShards, nil + case "multi_shard": + return ReqMultiShard, nil + case "special": + return ReqSpecial, nil + default: + return ReqDefault, fmt.Errorf("routing: unknown request_policy %q", raw) + } +} + +type ResponsePolicy uint8 + +const ( + RespDefaultKeyless ResponsePolicy = iota + RespDefaultHashSlot + RespAllSucceeded + RespOneSucceeded + RespAggSum + RespAggMin + RespAggMax + RespAggLogicalAnd + RespAggLogicalOr + RespSpecial +) + +func (p ResponsePolicy) String() string { + switch p { + case RespDefaultKeyless: + return "default(keyless)" + case RespDefaultHashSlot: + return "default(hashslot)" + case RespAllSucceeded: + return "all_succeeded" + case RespOneSucceeded: + return "one_succeeded" + case RespAggSum: + return "agg_sum" + case RespAggMin: + return "agg_min" + case RespAggMax: + return "agg_max" + case RespAggLogicalAnd: + return "agg_logical_and" + case RespAggLogicalOr: + return "agg_logical_or" + case RespSpecial: + return "special" + default: + return "all_succeeded" + } +} + +func ParseResponsePolicy(raw string) (ResponsePolicy, error) { + switch strings.ToLower(raw) { + case "default(keyless)": + return RespDefaultKeyless, nil + case "default(hashslot)": + return RespDefaultHashSlot, nil + case "all_succeeded": + return RespAllSucceeded, nil + case "one_succeeded": + return RespOneSucceeded, nil + case "agg_sum": + return RespAggSum, nil + case "agg_min": + return RespAggMin, nil + case "agg_max": + return RespAggMax, nil + case "agg_logical_and": + return RespAggLogicalAnd, nil + case "agg_logical_or": + return RespAggLogicalOr, nil + case "special": + return RespSpecial, nil + default: + return RespDefaultKeyless, fmt.Errorf("routing: unknown response_policy %q", raw) + } +} + +type CommandPolicy struct { + Request RequestPolicy + Response ResponsePolicy + // Tips that are not request_policy or response_policy + // e.g nondeterministic_output, nondeterministic_output_order. 
+ Tips map[string]string +} + +func (p *CommandPolicy) CanBeUsedInPipeline() bool { + return p.Request != ReqAllNodes && p.Request != ReqAllShards && p.Request != ReqMultiShard +} diff --git a/internal/routing/shard_picker.go b/internal/routing/shard_picker.go new file mode 100644 index 000000000..e29d526b0 --- /dev/null +++ b/internal/routing/shard_picker.go @@ -0,0 +1,41 @@ +package routing + +import ( + "math/rand" + "sync/atomic" +) + +// ShardPicker chooses “one arbitrary shard” when the request_policy is +// ReqDefault and the command has no keys. +type ShardPicker interface { + Next(total int) int // returns an index in [0,total) +} + +/*─────────────────────────────── + Round-robin (default) +────────────────────────────────*/ + +type RoundRobinPicker struct { + cnt atomic.Uint32 +} + +func (p *RoundRobinPicker) Next(total int) int { + if total == 0 { + return 0 + } + i := p.cnt.Add(1) + return int(i-1) % total +} + +/*─────────────────────────────── + Random +────────────────────────────────*/ + +type RandomPicker struct{} + +func (RandomPicker) Next(total int) int { + if total == 0 { + return 0 + } + return rand.Intn(total) +} diff --git a/json.go b/json.go index b3cadf4b7..d738e397d 100644 --- a/json.go +++ b/json.go @@ -68,8 +68,9 @@ var _ Cmder = (*JSONCmd)(nil) func newJSONCmd(ctx context.Context, args ...interface{}) *JSONCmd { return &JSONCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeJSON, }, } } @@ -149,6 +150,14 @@ func (cmd *JSONCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *JSONCmd) Clone() Cmder { + return &JSONCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + expanded: cmd.expanded, // interface{} can be shared as it should be immutable after parsing + } +} + // ------------------------------------------- type JSONSliceCmd struct { @@ -159,8 +168,9 @@ type JSONSliceCmd struct { func NewJSONSliceCmd(ctx context.Context, args ...interface{}) *JSONSliceCmd { return &JSONSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeJSONSlice, }, } } @@ -217,6 +227,18 @@ func (cmd *JSONSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *JSONSliceCmd) Clone() Cmder { + var val []interface{} + if cmd.val != nil { + val = make([]interface{}, len(cmd.val)) + copy(val, cmd.val) + } + return &JSONSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + /******************************************************************************* * * IntPointerSliceCmd @@ -233,8 +255,9 @@ type IntPointerSliceCmd struct { func NewIntPointerSliceCmd(ctx context.Context, args ...interface{}) *IntPointerSliceCmd { return &IntPointerSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeIntPointerSlice, }, } } @@ -274,6 +297,23 @@ func (cmd *IntPointerSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *IntPointerSliceCmd) Clone() Cmder { + var val []*int64 + if cmd.val != nil { + val = make([]*int64, len(cmd.val)) + for i, ptr := range cmd.val { + if ptr != nil { + newVal := *ptr + val[i] = &newVal + } + } + } + return &IntPointerSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // JSONArrAppend adds the provided JSON values to the end of the array at the given path. 
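
For context on how the aggregators and response policies introduced above are meant to be driven, here is a minimal sketch (assuming it lives in a test inside this repository, e.g. a routing_test package, since internal/routing is not importable from outside the module): per-shard replies are fed into an aggregator chosen from the command's response_policy, and the final value comes from Finish().

package routing_test

import (
	"fmt"

	"github.com/redis/go-redis/v9/internal/routing"
)

func ExampleResponseAggregator() {
	// DBSIZE-style command: request_policy=all_shards, response_policy=agg_sum.
	policy, err := routing.ParseResponsePolicy("agg_sum")
	if err != nil {
		panic(err)
	}
	agg := routing.NewResponseAggregator(policy, "dbsize")

	// Feed one reply per shard; errors would be recorded and surfaced by Finish.
	for _, shardReply := range []int64{10, 25, 7} {
		_ = agg.Add(shardReply, nil)
	}

	total, err := agg.Finish()
	fmt.Println(total, err)
	// Output: 42 <nil>
}
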
diff --git a/main_test.go b/main_test.go index 29e6014b9..a0c9dece8 100644 --- a/main_test.go +++ b/main_test.go @@ -105,6 +105,7 @@ var _ = BeforeSuite(func() { if RedisVersion < 7.0 || RedisVersion > 9 { panic("incorrect or not supported redis version") + } redisPort = redisStackPort diff --git a/osscluster.go b/osscluster.go index c0278ed05..fdbbc4d02 100644 --- a/osscluster.go +++ b/osscluster.go @@ -19,6 +19,7 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/internal/routing" ) const ( @@ -108,6 +109,10 @@ type ClusterOptions struct { // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. UnstableResp3 bool + + // ShardPicker is used to pick a shard when the request_policy is + // ReqDefault and the command has no keys. + ShardPicker routing.ShardPicker } func (opt *ClusterOptions) init() { @@ -158,6 +163,10 @@ func (opt *ClusterOptions) init() { if opt.NewClient == nil { opt.NewClient = NewClient } + + if opt.ShardPicker == nil { + opt.ShardPicker = &routing.RoundRobinPicker{} + } } // ParseClusterURL parses a URL into ClusterOptions that can be used to connect to Redis. @@ -924,9 +933,6 @@ type ClusterClient struct { // NewClusterClient returns a Redis Cluster client as described in // http://redis.io/topics/cluster-spec. func NewClusterClient(opt *ClusterOptions) *ClusterClient { - if opt == nil { - panic("redis: NewClusterClient nil options") - } opt.init() c := &ClusterClient{ @@ -937,7 +943,6 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) c.cmdable = c.Process - c.initHooks(hooks{ dial: nil, process: c.process, @@ -1005,13 +1010,13 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { if ask { ask = false - pipe := node.Client.Pipeline() _ = pipe.Process(ctx, NewCmd(ctx, "asking")) _ = pipe.Process(ctx, cmd) _, lastErr = pipe.Exec(ctx) } else { - lastErr = node.Client.Process(ctx, cmd) + // Execute the command on the selected node + lastErr = c.routeAndRun(ctx, cmd, node) } // If there is no error - we are done. 
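
As a usage note for the new ClusterOptions.ShardPicker wired in above: a minimal sketch of swapping the default round-robin picker for the random one. This assumes code inside this repository (external code cannot import internal/routing), and newRandomPickingCluster is a hypothetical helper name used only for illustration; with no override, init() installs a RoundRobinPicker.

package redis_test

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/internal/routing"
)

// newRandomPickingCluster builds a cluster client whose keyless ReqDefault
// commands are routed to a randomly chosen master instead of round-robin.
func newRandomPickingCluster(addrs []string) *redis.ClusterClient {
	return redis.NewClusterClient(&redis.ClusterOptions{
		Addrs:       addrs,
		ShardPicker: routing.RandomPicker{},
	})
}
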
@@ -1329,6 +1334,12 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { + policy := c.getCommandPolicy(ctx, cmd) + if policy != nil && !policy.CanBeUsedInPipeline() { + return fmt.Errorf( + "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), + ) + } slot := c.cmdSlot(ctx, cmd) node, err := c.slotReadOnlyNode(state, slot) if err != nil { @@ -1340,6 +1351,12 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { + policy := c.getCommandPolicy(ctx, cmd) + if policy != nil && !policy.CanBeUsedInPipeline() { + return fmt.Errorf( + "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), + ) + } slot := c.cmdSlot(ctx, cmd) node, err := state.slotMasterNode(slot) if err != nil { @@ -1820,7 +1837,6 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, for _, idx := range perm { addr := addrs[idx] - node, err := c.nodes.GetOrCreate(addr) if err != nil { if firstErr == nil { @@ -1833,6 +1849,7 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, if err == nil { return info, nil } + if firstErr == nil { firstErr = err } @@ -1845,7 +1862,17 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, } func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Get(ctx) + // Use a separate context that won't be canceled to ensure command info lookup + // doesn't fail due to original context cancellation + cmdInfoCtx := context.Background() + if c.opt.ContextTimeoutEnabled && ctx != nil { + // If context timeout is enabled, still use a reasonable timeout + var cancel context.CancelFunc + cmdInfoCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + + cmdsInfo, err := c.cmdsInfoCache.Get(cmdInfoCtx) if err != nil { internal.Logger.Printf(context.TODO(), "getting command info: %s", err) return nil diff --git a/osscluster_router.go b/osscluster_router.go new file mode 100644 index 000000000..a4ddbbf4a --- /dev/null +++ b/osscluster_router.go @@ -0,0 +1,966 @@ +package redis + +import ( + "context" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal/hashtag" + "github.com/redis/go-redis/v9/internal/routing" +) + +// slotResult represents the result of executing a command on a specific slot +type slotResult struct { + cmd Cmder + keys []string + err error +} + +// routeAndRun routes a command to the appropriate cluster nodes and executes it +func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { + policy := c.getCommandPolicy(ctx, cmd) + + switch { + case policy != nil && policy.Request == routing.ReqAllNodes: + return c.executeOnAllNodes(ctx, cmd, policy) + case policy != nil && policy.Request == routing.ReqAllShards: + return c.executeOnAllShards(ctx, cmd, policy) + case policy != nil && policy.Request == routing.ReqMultiShard: + return c.executeMultiShard(ctx, cmd, policy) + case policy != nil && policy.Request == routing.ReqSpecial: + return c.executeSpecialCommand(ctx, cmd, policy, node) + default: + return c.executeDefault(ctx, cmd, node) + } +} + +// getCommandPolicy 
retrieves the routing policy for a command +func (c *ClusterClient) getCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.Tips != nil { + return cmdInfo.Tips + } + return nil +} + +// executeDefault handles standard command routing based on keys +func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, node *clusterNode) error { + if c.hasKeys(cmd) { + // execute on key based shard + return node.Client.Process(ctx, cmd) + } + return c.executeOnArbitraryShard(ctx, cmd) +} + +// executeOnArbitraryShard routes command to an arbitrary shard +func (c *ClusterClient) executeOnArbitraryShard(ctx context.Context, cmd Cmder) error { + node := c.pickArbitraryShard(ctx) + if node == nil { + return errClusterNoNodes + } + return node.Client.Process(ctx, cmd) +} + +// executeOnAllNodes executes command on all nodes (masters and replicas) +func (c *ClusterClient) executeOnAllNodes(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } + + nodes := append(state.Masters, state.Slaves...) + if len(nodes) == 0 { + return errClusterNoNodes + } + + return c.executeParallel(ctx, cmd, nodes, policy) +} + +// executeOnAllShards executes command on all master shards +func (c *ClusterClient) executeOnAllShards(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } + + if len(state.Masters) == 0 { + return errClusterNoNodes + } + + return c.executeParallel(ctx, cmd, state.Masters, policy) +} + +// executeMultiShard handles commands that operate on multiple keys across shards +func (c *ClusterClient) executeMultiShard(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + args := cmd.Args() + firstKeyPos := int(cmdFirstKeyPos(cmd)) + + if firstKeyPos == 0 || firstKeyPos >= len(args) { + return fmt.Errorf("redis: multi-shard command %s has no key arguments", cmd.Name()) + } + + // Group keys by slot + slotMap := make(map[int][]string) + keyOrder := make([]string, 0) + + for i := firstKeyPos; i < len(args); i++ { + key, ok := args[i].(string) + if !ok { + return fmt.Errorf("redis: non-string key at position %d: %v", i, args[i]) + } + + slot := hashtag.Slot(key) + slotMap[slot] = append(slotMap[slot], key) + keyOrder = append(keyOrder, key) + } + + return c.executeMultiSlot(ctx, cmd, slotMap, keyOrder, policy) +} + +// executeMultiSlot executes commands across multiple slots concurrently +func (c *ClusterClient) executeMultiSlot(ctx context.Context, cmd Cmder, slotMap map[int][]string, keyOrder []string, policy *routing.CommandPolicy) error { + results := make(chan slotResult, len(slotMap)) + var wg sync.WaitGroup + + // Execute on each slot concurrently + for slot, keys := range slotMap { + wg.Add(1) + go func(slot int, keys []string) { + defer wg.Done() + + node, err := c.cmdNode(ctx, cmd.Name(), slot) + if err != nil { + results <- slotResult{nil, keys, err} + return + } + + // Create a command for this specific slot's keys + subCmd := c.createSlotSpecificCommand(ctx, cmd, keys) + err = node.Client.Process(ctx, subCmd) + results <- slotResult{subCmd, keys, err} + }(slot, keys) + } + + go func() { + wg.Wait() + close(results) + }() + + return c.aggregateMultiSlotResults(ctx, cmd, results, keyOrder, policy) +} + +// createSlotSpecificCommand creates a new command for a specific slot's keys +func (c *ClusterClient) 
createSlotSpecificCommand(ctx context.Context, originalCmd Cmder, keys []string) Cmder { + originalArgs := originalCmd.Args() + firstKeyPos := int(cmdFirstKeyPos(originalCmd)) + + // Build new args with only the specified keys + newArgs := make([]interface{}, 0, firstKeyPos+len(keys)) + + // Copy command name and arguments before the keys + newArgs = append(newArgs, originalArgs[:firstKeyPos]...) + + // Add the slot-specific keys + for _, key := range keys { + newArgs = append(newArgs, key) + } + + // Create a new command of the same type using the helper function + return createCommandByType(ctx, originalCmd.GetCmdType(), newArgs...) +} + +// createCommandByType creates a new command of the specified type with the given arguments +func createCommandByType(ctx context.Context, cmdType CmdType, args ...interface{}) Cmder { + switch cmdType { + case CmdTypeString: + return NewStringCmd(ctx, args...) + case CmdTypeInt: + return NewIntCmd(ctx, args...) + case CmdTypeBool: + return NewBoolCmd(ctx, args...) + case CmdTypeFloat: + return NewFloatCmd(ctx, args...) + case CmdTypeStringSlice: + return NewStringSliceCmd(ctx, args...) + case CmdTypeIntSlice: + return NewIntSliceCmd(ctx, args...) + case CmdTypeFloatSlice: + return NewFloatSliceCmd(ctx, args...) + case CmdTypeBoolSlice: + return NewBoolSliceCmd(ctx, args...) + case CmdTypeStatus: + return NewStatusCmd(ctx, args...) + case CmdTypeTime: + return NewTimeCmd(ctx, args...) + case CmdTypeMapStringString: + return NewMapStringStringCmd(ctx, args...) + case CmdTypeMapStringInt: + return NewMapStringIntCmd(ctx, args...) + case CmdTypeMapStringInterface: + return NewMapStringInterfaceCmd(ctx, args...) + case CmdTypeMapStringInterfaceSlice: + return NewMapStringInterfaceSliceCmd(ctx, args...) + case CmdTypeSlice: + return NewSliceCmd(ctx, args...) + case CmdTypeStringStructMap: + return NewStringStructMapCmd(ctx, args...) + case CmdTypeXMessageSlice: + return NewXMessageSliceCmd(ctx, args...) + case CmdTypeXStreamSlice: + return NewXStreamSliceCmd(ctx, args...) + case CmdTypeXPending: + return NewXPendingCmd(ctx, args...) + case CmdTypeXPendingExt: + return NewXPendingExtCmd(ctx, args...) + case CmdTypeXAutoClaim: + return NewXAutoClaimCmd(ctx, args...) + case CmdTypeXAutoClaimJustID: + return NewXAutoClaimJustIDCmd(ctx, args...) + case CmdTypeXInfoStreamFull: + return NewXInfoStreamFullCmd(ctx, args...) + case CmdTypeZSlice: + return NewZSliceCmd(ctx, args...) + case CmdTypeZWithKey: + return NewZWithKeyCmd(ctx, args...) + case CmdTypeClusterSlots: + return NewClusterSlotsCmd(ctx, args...) + case CmdTypeGeoPos: + return NewGeoPosCmd(ctx, args...) + case CmdTypeCommandsInfo: + return NewCommandsInfoCmd(ctx, args...) + case CmdTypeSlowLog: + return NewSlowLogCmd(ctx, args...) + case CmdTypeKeyValues: + return NewKeyValuesCmd(ctx, args...) + case CmdTypeZSliceWithKey: + return NewZSliceWithKeyCmd(ctx, args...) + case CmdTypeFunctionList: + return NewFunctionListCmd(ctx, args...) + case CmdTypeFunctionStats: + return NewFunctionStatsCmd(ctx, args...) + case CmdTypeKeyFlags: + return NewKeyFlagsCmd(ctx, args...) + case CmdTypeDuration: + return NewDurationCmd(ctx, time.Second, args...) + } + return NewCmd(ctx, args...) 
+} + +// executeSpecialCommand handles commands with special routing requirements +func (c *ClusterClient) executeSpecialCommand(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy, node *clusterNode) error { + switch cmd.Name() { + case "ft.cursor": + return c.executeCursorCommand(ctx, cmd) + default: + return c.executeDefault(ctx, cmd, node) + } +} + +// executeCursorCommand handles FT.CURSOR commands with sticky routing +func (c *ClusterClient) executeCursorCommand(ctx context.Context, cmd Cmder) error { + args := cmd.Args() + if len(args) < 4 { + return fmt.Errorf("redis: FT.CURSOR command requires at least 3 arguments") + } + + cursorID, ok := args[3].(string) + if !ok { + return fmt.Errorf("redis: invalid cursor ID type") + } + + // Route based on cursor ID to maintain stickiness + slot := hashtag.Slot(cursorID) + node, err := c.cmdNode(ctx, cmd.Name(), slot) + if err != nil { + return err + } + + return node.Client.Process(ctx, cmd) +} + +// executeParallel executes a command on multiple nodes concurrently +func (c *ClusterClient) executeParallel(ctx context.Context, cmd Cmder, nodes []*clusterNode, policy *routing.CommandPolicy) error { + if len(nodes) == 0 { + return errClusterNoNodes + } + + if len(nodes) == 1 { + return nodes[0].Client.Process(ctx, cmd) + } + + type nodeResult struct { + cmd Cmder + err error + } + + results := make(chan nodeResult, len(nodes)) + var wg sync.WaitGroup + + for _, node := range nodes { + wg.Add(1) + go func(n *clusterNode) { + defer wg.Done() + cmdCopy := cmd.Clone() + err := n.Client.Process(ctx, cmdCopy) + results <- nodeResult{cmdCopy, err} + }(node) + } + + go func() { + wg.Wait() + close(results) + }() + + // Collect results and check for errors + cmds := make([]Cmder, 0, len(nodes)) + var firstErr error + + for result := range results { + if result.err != nil && firstErr == nil { + firstErr = result.err + } + cmds = append(cmds, result.cmd) + } + + // If there was an error and no policy specified, fail fast + if firstErr != nil && (policy == nil || policy.Response == routing.RespDefaultKeyless) { + cmd.SetErr(firstErr) + return firstErr + } + + return c.aggregateResponses(cmd, cmds, policy) +} + +// aggregateMultiSlotResults aggregates results from multi-slot execution +func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder, results <-chan slotResult, keyOrder []string, policy *routing.CommandPolicy) error { + keyedResults := make(map[string]interface{}) + var firstErr error + + for result := range results { + if result.err != nil && firstErr == nil { + firstErr = result.err + } + if result.cmd != nil && result.err == nil { + // For MGET, extract individual values from the array result + if strings.ToLower(cmd.Name()) == "mget" { + if sliceCmd, ok := result.cmd.(*SliceCmd); ok { + values := sliceCmd.Val() + if len(values) == len(result.keys) { + for i, key := range result.keys { + keyedResults[key] = values[i] + } + } else { + // Fallback: map all keys to the entire result + for _, key := range result.keys { + keyedResults[key] = values + } + } + } else { + // Fallback for non-SliceCmd results + value := ExtractCommandValue(result.cmd) + for _, key := range result.keys { + keyedResults[key] = value + } + } + } else { + // For other commands, map each key to the entire result + value := ExtractCommandValue(result.cmd) + for _, key := range result.keys { + keyedResults[key] = value + } + } + } + } + + if firstErr != nil { + cmd.SetErr(firstErr) + return firstErr + } + + return c.aggregateKeyedValues(cmd, 
keyedResults, keyOrder, policy) +} + +// aggregateKeyedValues aggregates individual key-value pairs while preserving key order +func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]interface{}, keyOrder []string, policy *routing.CommandPolicy) error { + if len(keyedResults) == 0 { + return fmt.Errorf("redis: no results to aggregate") + } + + aggregator := c.createAggregator(policy, cmd, true) + + // Set key order for keyed aggregators + if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok { + keyedAgg.SetKeyOrder(keyOrder) + } + + // Add results with keys + for key, value := range keyedResults { + if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok { + if err := keyedAgg.AddWithKey(key, value, nil); err != nil { + return err + } + } else { + if err := aggregator.Add(value, nil); err != nil { + return err + } + } + } + + return c.finishAggregation(cmd, aggregator) +} + +// aggregateResponses aggregates multiple shard responses +func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *routing.CommandPolicy) error { + if len(cmds) == 0 { + return fmt.Errorf("redis: no commands to aggregate") + } + + if len(cmds) == 1 { + shardCmd := cmds[0] + if err := shardCmd.Err(); err != nil { + cmd.SetErr(err) + return err + } + value := ExtractCommandValue(shardCmd) + return c.setCommandValue(cmd, value) + } + + aggregator := c.createAggregator(policy, cmd, false) + + // Add all results to aggregator + for _, shardCmd := range cmds { + value := ExtractCommandValue(shardCmd) + if err := aggregator.Add(value, shardCmd.Err()); err != nil { + return err + } + } + + return c.finishAggregation(cmd, aggregator) +} + +// createAggregator creates the appropriate response aggregator +func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator { + cmdName := strings.ToLower(cmd.Name()) + // For MGET without policy, use keyed aggregator + if cmdName == "mget" { + return routing.NewDefaultAggregator(true) + } + + if policy != nil { + return routing.NewResponseAggregator(policy.Response, cmd.Name()) + } + + if !isKeyed { + firstKeyPos := cmdFirstKeyPos(cmd) + isKeyed = firstKeyPos > 0 + } + + return routing.NewDefaultAggregator(isKeyed) +} + +// finishAggregation completes the aggregation process and sets the result +func (c *ClusterClient) finishAggregation(cmd Cmder, aggregator routing.ResponseAggregator) error { + finalValue, finalErr := aggregator.Finish() + if finalErr != nil { + cmd.SetErr(finalErr) + return finalErr + } + + return c.setCommandValue(cmd, finalValue) +} + +// pickArbitraryShard selects a master shard using the configured ShardPicker +func (c *ClusterClient) pickArbitraryShard(ctx context.Context) *clusterNode { + state, err := c.state.Get(ctx) + if err != nil || len(state.Masters) == 0 { + return nil + } + + idx := c.opt.ShardPicker.Next(len(state.Masters)) + return state.Masters[idx] +} + +// hasKeys checks if a command operates on keys +func (c *ClusterClient) hasKeys(cmd Cmder) bool { + firstKeyPos := cmdFirstKeyPos(cmd) + return firstKeyPos > 0 +} + +// setCommandValue sets the aggregated value on a command using the enum-based approach +func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { + // If value is nil, it might mean ExtractCommandValue couldn't extract the value + // but the command might have executed successfully. In this case, don't set an error. 
+ if value == nil { + // Check if the original command has an error - if not, the nil value is not an error + if cmd.Err() == nil { + // Command executed successfully but value extraction failed + // This is common for complex commands like CLUSTER SLOTS + // The command already has its result set correctly, so just return + return nil + } + // If the command does have an error, set Nil error + cmd.SetErr(Nil) + return Nil + } + + switch cmd.GetCmdType() { + case CmdTypeGeneric: + if c, ok := cmd.(*Cmd); ok { + c.SetVal(value) + } + case CmdTypeString: + if c, ok := cmd.(*StringCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeInt: + if c, ok := cmd.(*IntCmd); ok { + if v, ok := value.(int64); ok { + c.SetVal(v) + } + } + case CmdTypeBool: + if c, ok := cmd.(*BoolCmd); ok { + if v, ok := value.(bool); ok { + c.SetVal(v) + } + } + case CmdTypeFloat: + if c, ok := cmd.(*FloatCmd); ok { + if v, ok := value.(float64); ok { + c.SetVal(v) + } + } + case CmdTypeStringSlice: + if c, ok := cmd.(*StringSliceCmd); ok { + if v, ok := value.([]string); ok { + c.SetVal(v) + } + } + case CmdTypeIntSlice: + if c, ok := cmd.(*IntSliceCmd); ok { + if v, ok := value.([]int64); ok { + c.SetVal(v) + } + } + case CmdTypeFloatSlice: + if c, ok := cmd.(*FloatSliceCmd); ok { + if v, ok := value.([]float64); ok { + c.SetVal(v) + } + } + case CmdTypeBoolSlice: + if c, ok := cmd.(*BoolSliceCmd); ok { + if v, ok := value.([]bool); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringString: + if c, ok := cmd.(*MapStringStringCmd); ok { + if v, ok := value.(map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInt: + if c, ok := cmd.(*MapStringIntCmd); ok { + if v, ok := value.(map[string]int64); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInterface: + if c, ok := cmd.(*MapStringInterfaceCmd); ok { + if v, ok := value.(map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeSlice: + if c, ok := cmd.(*SliceCmd); ok { + if v, ok := value.([]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeStatus: + if c, ok := cmd.(*StatusCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeDuration: + if c, ok := cmd.(*DurationCmd); ok { + if v, ok := value.(time.Duration); ok { + c.SetVal(v) + } + } + case CmdTypeTime: + if c, ok := cmd.(*TimeCmd); ok { + if v, ok := value.(time.Time); ok { + c.SetVal(v) + } + } + case CmdTypeKeyValueSlice: + if c, ok := cmd.(*KeyValueSliceCmd); ok { + if v, ok := value.([]KeyValue); ok { + c.SetVal(v) + } + } + case CmdTypeStringStructMap: + if c, ok := cmd.(*StringStructMapCmd); ok { + if v, ok := value.(map[string]struct{}); ok { + c.SetVal(v) + } + } + case CmdTypeXMessageSlice: + if c, ok := cmd.(*XMessageSliceCmd); ok { + if v, ok := value.([]XMessage); ok { + c.SetVal(v) + } + } + case CmdTypeXStreamSlice: + if c, ok := cmd.(*XStreamSliceCmd); ok { + if v, ok := value.([]XStream); ok { + c.SetVal(v) + } + } + case CmdTypeXPending: + if c, ok := cmd.(*XPendingCmd); ok { + if v, ok := value.(*XPending); ok { + c.SetVal(v) + } + } + case CmdTypeXPendingExt: + if c, ok := cmd.(*XPendingExtCmd); ok { + if v, ok := value.([]XPendingExt); ok { + c.SetVal(v) + } + } + case CmdTypeXAutoClaim: + if c, ok := cmd.(*XAutoClaimCmd); ok { + if v, ok := value.([]XMessage); ok { + c.SetVal(v, "") // Default start value + } + } + case CmdTypeXAutoClaimJustID: + if c, ok := cmd.(*XAutoClaimJustIDCmd); ok { + if v, ok := value.([]string); ok { + c.SetVal(v, "") // Default start value + } + } + case CmdTypeXInfoConsumers: 
+ if c, ok := cmd.(*XInfoConsumersCmd); ok { + if v, ok := value.([]XInfoConsumer); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoGroups: + if c, ok := cmd.(*XInfoGroupsCmd); ok { + if v, ok := value.([]XInfoGroup); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoStream: + if c, ok := cmd.(*XInfoStreamCmd); ok { + if v, ok := value.(*XInfoStream); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoStreamFull: + if c, ok := cmd.(*XInfoStreamFullCmd); ok { + if v, ok := value.(*XInfoStreamFull); ok { + c.SetVal(v) + } + } + case CmdTypeZSlice: + if c, ok := cmd.(*ZSliceCmd); ok { + if v, ok := value.([]Z); ok { + c.SetVal(v) + } + } + case CmdTypeZWithKey: + if c, ok := cmd.(*ZWithKeyCmd); ok { + if v, ok := value.(*ZWithKey); ok { + c.SetVal(v) + } + } + case CmdTypeScan: + if c, ok := cmd.(*ScanCmd); ok { + if v, ok := value.([]string); ok { + c.SetVal(v, uint64(0)) // Default cursor + } + } + case CmdTypeClusterSlots: + if c, ok := cmd.(*ClusterSlotsCmd); ok { + if v, ok := value.([]ClusterSlot); ok { + c.SetVal(v) + } + } + case CmdTypeGeoLocation: + if c, ok := cmd.(*GeoLocationCmd); ok { + if v, ok := value.([]GeoLocation); ok { + c.SetVal(v) + } + } + case CmdTypeGeoSearchLocation: + if c, ok := cmd.(*GeoSearchLocationCmd); ok { + if v, ok := value.([]GeoLocation); ok { + c.SetVal(v) + } + } + case CmdTypeGeoPos: + if c, ok := cmd.(*GeoPosCmd); ok { + if v, ok := value.([]*GeoPos); ok { + c.SetVal(v) + } + } + case CmdTypeCommandsInfo: + if c, ok := cmd.(*CommandsInfoCmd); ok { + if v, ok := value.(map[string]*CommandInfo); ok { + c.SetVal(v) + } + } + case CmdTypeSlowLog: + if c, ok := cmd.(*SlowLogCmd); ok { + if v, ok := value.([]SlowLog); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringStringSlice: + if c, ok := cmd.(*MapStringStringSliceCmd); ok { + if v, ok := value.([]map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMapMapStringInterface: + if c, ok := cmd.(*MapMapStringInterfaceCmd); ok { + if v, ok := value.(map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInterfaceSlice: + if c, ok := cmd.(*MapStringInterfaceSliceCmd); ok { + if v, ok := value.([]map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeKeyValues: + if c, ok := cmd.(*KeyValuesCmd); ok { + // KeyValuesCmd needs a key string and values slice + if key, ok := value.(string); ok { + c.SetVal(key, []string{}) // Default empty values + } + } + case CmdTypeZSliceWithKey: + if c, ok := cmd.(*ZSliceWithKeyCmd); ok { + // ZSliceWithKeyCmd needs a key string and Z slice + if key, ok := value.(string); ok { + c.SetVal(key, []Z{}) // Default empty Z slice + } + } + case CmdTypeFunctionList: + if c, ok := cmd.(*FunctionListCmd); ok { + if v, ok := value.([]Library); ok { + c.SetVal(v) + } + } + case CmdTypeFunctionStats: + if c, ok := cmd.(*FunctionStatsCmd); ok { + if v, ok := value.(FunctionStats); ok { + c.SetVal(v) + } + } + case CmdTypeLCS: + if c, ok := cmd.(*LCSCmd); ok { + if v, ok := value.(*LCSMatch); ok { + c.SetVal(v) + } + } + case CmdTypeKeyFlags: + if c, ok := cmd.(*KeyFlagsCmd); ok { + if v, ok := value.([]KeyFlags); ok { + c.SetVal(v) + } + } + case CmdTypeClusterLinks: + if c, ok := cmd.(*ClusterLinksCmd); ok { + if v, ok := value.([]ClusterLink); ok { + c.SetVal(v) + } + } + case CmdTypeClusterShards: + if c, ok := cmd.(*ClusterShardsCmd); ok { + if v, ok := value.([]ClusterShard); ok { + c.SetVal(v) + } + } + case CmdTypeRankWithScore: + if c, ok := cmd.(*RankWithScoreCmd); ok { + if v, ok := value.(RankScore); ok { + c.SetVal(v) + } + } + case CmdTypeClientInfo: 
+ if c, ok := cmd.(*ClientInfoCmd); ok { + if v, ok := value.(*ClientInfo); ok { + c.SetVal(v) + } + } + case CmdTypeACLLog: + if c, ok := cmd.(*ACLLogCmd); ok { + if v, ok := value.([]*ACLLogEntry); ok { + c.SetVal(v) + } + } + case CmdTypeInfo: + if c, ok := cmd.(*InfoCmd); ok { + if v, ok := value.(map[string]map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMonitor: + // MonitorCmd doesn't have SetVal method + // Skip setting value for MonitorCmd + case CmdTypeJSON: + if c, ok := cmd.(*JSONCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeJSONSlice: + if c, ok := cmd.(*JSONSliceCmd); ok { + if v, ok := value.([]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeIntPointerSlice: + if c, ok := cmd.(*IntPointerSliceCmd); ok { + if v, ok := value.([]*int64); ok { + c.SetVal(v) + } + } + case CmdTypeScanDump: + if c, ok := cmd.(*ScanDumpCmd); ok { + if v, ok := value.(ScanDump); ok { + c.SetVal(v) + } + } + case CmdTypeBFInfo: + if c, ok := cmd.(*BFInfoCmd); ok { + if v, ok := value.(BFInfo); ok { + c.SetVal(v) + } + } + case CmdTypeCFInfo: + if c, ok := cmd.(*CFInfoCmd); ok { + if v, ok := value.(CFInfo); ok { + c.SetVal(v) + } + } + case CmdTypeCMSInfo: + if c, ok := cmd.(*CMSInfoCmd); ok { + if v, ok := value.(CMSInfo); ok { + c.SetVal(v) + } + } + case CmdTypeTopKInfo: + if c, ok := cmd.(*TopKInfoCmd); ok { + if v, ok := value.(TopKInfo); ok { + c.SetVal(v) + } + } + case CmdTypeTDigestInfo: + if c, ok := cmd.(*TDigestInfoCmd); ok { + if v, ok := value.(TDigestInfo); ok { + c.SetVal(v) + } + } + case CmdTypeFTSynDump: + if c, ok := cmd.(*FTSynDumpCmd); ok { + if v, ok := value.([]FTSynDumpResult); ok { + c.SetVal(v) + } + } + case CmdTypeAggregate: + if c, ok := cmd.(*AggregateCmd); ok { + if v, ok := value.(*FTAggregateResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTInfo: + if c, ok := cmd.(*FTInfoCmd); ok { + if v, ok := value.(FTInfoResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTSpellCheck: + if c, ok := cmd.(*FTSpellCheckCmd); ok { + if v, ok := value.([]SpellCheckResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTSearch: + if c, ok := cmd.(*FTSearchCmd); ok { + if v, ok := value.(FTSearchResult); ok { + c.SetVal(v) + } + } + case CmdTypeTSTimestampValue: + if c, ok := cmd.(*TSTimestampValueCmd); ok { + if v, ok := value.(TSTimestampValue); ok { + c.SetVal(v) + } + } + case CmdTypeTSTimestampValueSlice: + if c, ok := cmd.(*TSTimestampValueSliceCmd); ok { + if v, ok := value.([]TSTimestampValue); ok { + c.SetVal(v) + } + } + default: + // Fallback to reflection for unknown types + return c.setCommandValueReflection(cmd, value) + } + + return nil +} + +// setCommandValueReflection is a fallback function that uses reflection +func (c *ClusterClient) setCommandValueReflection(cmd Cmder, value interface{}) error { + cmdValue := reflect.ValueOf(cmd) + if cmdValue.Kind() != reflect.Ptr || cmdValue.IsNil() { + return fmt.Errorf("redis: invalid command pointer") + } + + setValMethod := cmdValue.MethodByName("SetVal") + if !setValMethod.IsValid() { + return fmt.Errorf("redis: command %T does not have SetVal method", cmd) + } + + args := []reflect.Value{reflect.ValueOf(value)} + + switch cmd.(type) { + case *XAutoClaimCmd, *XAutoClaimJustIDCmd: + args = append(args, reflect.ValueOf("")) + case *ScanCmd: + args = append(args, reflect.ValueOf(uint64(0))) + case *KeyValuesCmd, *ZSliceWithKeyCmd: + if key, ok := value.(string); ok { + args = []reflect.Value{reflect.ValueOf(key)} + if _, ok := cmd.(*ZSliceWithKeyCmd); ok { + args = 
append(args, reflect.ValueOf([]Z{})) + } else { + args = append(args, reflect.ValueOf([]string{})) + } + } + } + + defer func() { + if r := recover(); r != nil { + cmd.SetErr(fmt.Errorf("redis: failed to set command value: %v", r)) + } + }() + + setValMethod.Call(args) + return nil +} diff --git a/osscluster_test.go b/osscluster_test.go index ccf6daad8..04dbcf194 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -14,9 +14,9 @@ import ( . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" - "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/internal/hashtag" + "github.com/redis/go-redis/v9/internal/routing" ) type clusterScenario struct { @@ -253,7 +253,7 @@ func slotEqual(s1, s2 redis.ClusterSlot) bool { return true } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ var _ = Describe("ClusterClient", func() { var failover bool @@ -277,7 +277,7 @@ var _ = Describe("ClusterClient", func() { Expect(cnt).To(Equal(int64(1))) }) - It("GET follows redirects", func() { + It("should follow redirects for GET", func() { err := client.Set(ctx, "A", "VALUE", 0).Err() Expect(err).NotTo(HaveOccurred()) @@ -300,7 +300,7 @@ var _ = Describe("ClusterClient", func() { Expect(v).To(Equal("VALUE")) }) - It("SET follows redirects", func() { + It("should follow redirects for SET", func() { if !failover { Eventually(func() error { return client.SwapNodes(ctx, "A") @@ -315,7 +315,7 @@ var _ = Describe("ClusterClient", func() { Expect(v).To(Equal("VALUE")) }) - It("distributes keys", func() { + It("should distribute keys", func() { for i := 0; i < 100; i++ { err := client.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err() Expect(err).NotTo(HaveOccurred()) @@ -336,7 +336,7 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) }) - It("distributes keys when using EVAL", func() { + It("should distribute keys when using EVAL", func() { script := redis.NewScript(` local r = redis.call('SET', KEYS[1], ARGV[1]) return r @@ -364,7 +364,7 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) }) - It("distributes scripts when using Script Load", func() { + It("should distribute scripts when using Script Load", func() { client.ScriptFlush(ctx) script := redis.NewScript(`return 'Unique script'`) @@ -381,7 +381,7 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) }) - It("checks all shards when using Script Exists", func() { + It("should check all shards when using Script Exists", func() { client.ScriptFlush(ctx) script := redis.NewScript(`return 'First script'`) @@ -396,7 +396,7 @@ var _ = Describe("ClusterClient", func() { Expect(val).To(Equal([]bool{true, false})) }) - It("flushes scripts from all shards when using ScriptFlush", func() { + It("should flush scripts from all shards when using ScriptFlush", func() { script := redis.NewScript(`return 'Unnecessary script'`) script.Load(ctx, client) @@ -409,7 +409,7 @@ var _ = Describe("ClusterClient", func() { Expect(val).To(Equal([]bool{false})) }) - It("supports Watch", func() { + It("should support Watch", func() { var incr func(string) error // Transactionally increments key using GET and SET commands. 
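
To make the multi-shard flow implemented in osscluster_router.go above concrete: a minimal sketch of how a ReqMultiShard command such as MGET is reassembled in the caller's key order by the keyed aggregator (the per-slot fan-out and slot math are elided; keyedReplies stands in for the results collected by executeMultiSlot). Again this assumes a test inside the repository, since the routing package is internal.

package routing_test

import (
	"fmt"

	"github.com/redis/go-redis/v9/internal/routing"
)

func ExampleDefaultKeyedAggregator() {
	// Key order of the original request, e.g. MGET k1 k2 k3.
	agg := routing.NewDefaultKeyedAggregator([]string{"k1", "k2", "k3"})

	// Per-slot replies arrive in arbitrary order; AddWithKey records them by key.
	for key, val := range map[string]interface{}{"k3": "c", "k1": "a", "k2": "b"} {
		_ = agg.AddWithKey(key, val, nil)
	}

	// Finish restores the original request order.
	ordered, err := agg.Finish()
	fmt.Println(ordered, err)
	// Output: [a b c] <nil>
}
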
@@ -456,7 +456,7 @@ var _ = Describe("ClusterClient", func() { assertPipeline := func() { keys := []string{"A", "B", "C", "D", "E", "F", "G"} - It("follows redirects", func() { + It("should follow redirects", func() { if !failover { for _, key := range keys { Eventually(func() error { @@ -507,7 +507,7 @@ var _ = Describe("ClusterClient", func() { } }) - It("works with missing keys", func() { + It("should work with missing keys", func() { pipe.Set(ctx, "A", "A_value", 0) pipe.Set(ctx, "C", "C_value", 0) _, err := pipe.Exec(ctx) @@ -540,7 +540,7 @@ var _ = Describe("ClusterClient", func() { assertPipeline() - It("doesn't fail node with context.Canceled error", func() { + It("should not fail node with context.Canceled error", func() { ctx, cancel := context.WithCancel(context.Background()) cancel() pipe.Set(ctx, "A", "A_value", 0) @@ -556,7 +556,7 @@ var _ = Describe("ClusterClient", func() { } }) - It("doesn't fail node with context.DeadlineExceeded error", func() { + It("should not fail node with context.DeadlineExceeded error", func() { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() @@ -585,7 +585,7 @@ var _ = Describe("ClusterClient", func() { }) }) - It("supports PubSub", func() { + It("should support PubSub", func() { pubsub := client.Subscribe(ctx, "mychannel") defer pubsub.Close() @@ -609,7 +609,7 @@ var _ = Describe("ClusterClient", func() { }, 30*time.Second).ShouldNot(HaveOccurred()) }) - It("supports sharded PubSub", func() { + It("should support sharded PubSub", func() { pubsub := client.SSubscribe(ctx, "mychannel") defer pubsub.Close() @@ -633,7 +633,7 @@ var _ = Describe("ClusterClient", func() { }, 30*time.Second).ShouldNot(HaveOccurred()) }) - It("supports PubSub.Ping without channels", func() { + It("should support PubSub.Ping without channels", func() { pubsub := client.Subscribe(ctx) defer pubsub.Close() @@ -690,12 +690,12 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("returns pool stats", func() { + It("should return pool stats", func() { stats := client.PoolStats() Expect(stats).To(BeAssignableToTypeOf(&redis.PoolStats{})) }) - It("returns an error when there are no attempts left", func() { + It("should return an error when there are no attempts left", func() { opt := redisClusterOptions() opt.MaxRedirects = -1 client := cluster.newClusterClient(ctx, opt) @@ -711,7 +711,7 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("determines hash slots correctly for generic commands", func() { + It("should determine hash slots correctly for generic commands", func() { opt := redisClusterOptions() opt.MaxRedirects = -1 client := cluster.newClusterClient(ctx, opt) @@ -737,7 +737,7 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("follows node redirection immediately", func() { + It("should follow node redirection immediately", func() { // Configure retry backoffs far in excess of the expected duration of redirection opt := redisClusterOptions() opt.MinRetryBackoff = 10 * time.Minute @@ -763,7 +763,7 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("calls fn for every master node", func() { + It("should call fn for every master node", func() { for i := 0; i < 10; i++ { Expect(client.Set(ctx, strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred()) } @@ -976,7 +976,7 @@ var _ = Describe("ClusterClient", func() { 
Expect(len(keys)).To(BeNumerically("~", nkeys, nkeys/10)) }) - It("supports Process hook", func() { + It("should support Process hook", func() { testCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -988,6 +988,7 @@ var _ = Describe("ClusterClient", func() { }) Expect(err).NotTo(HaveOccurred()) + var mu sync.Mutex var stack []string clusterHook := &hook{ @@ -1000,12 +1001,16 @@ var _ = Describe("ClusterClient", func() { } Expect(cmd.String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "cluster.BeforeProcess") + mu.Unlock() err := hook(ctx, cmd) Expect(cmd.String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "cluster.AfterProcess") + mu.Unlock() return err } @@ -1023,12 +1028,16 @@ var _ = Describe("ClusterClient", func() { } Expect(cmd.String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "shard.BeforeProcess") + mu.Unlock() err := hook(ctx, cmd) Expect(cmd.String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "shard.AfterProcess") + mu.Unlock() return err } @@ -1042,7 +1051,13 @@ var _ = Describe("ClusterClient", func() { err = client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) - Expect(stack).To(Equal([]string{ + + mu.Lock() + finalStack := make([]string, len(stack)) + copy(finalStack, stack) + mu.Unlock() + + Expect(finalStack).To(ContainElements([]string{ "cluster.BeforeProcess", "shard.BeforeProcess", "shard.AfterProcess", @@ -1050,7 +1065,7 @@ var _ = Describe("ClusterClient", func() { })) }) - It("supports Pipeline hook", func() { + It("should support Pipeline hook", func() { err := client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) @@ -1059,22 +1074,35 @@ var _ = Describe("ClusterClient", func() { }) Expect(err).NotTo(HaveOccurred()) + var mu sync.Mutex var stack []string client.AddHook(&hook{ processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "cluster.BeforeProcessPipeline") + cmdStr := cmds[0].String() - err := hook(ctx, cmds) + // Handle SET command (should succeed) + if cmdStr == "set pipeline_test_key pipeline_test_value: " { + mu.Lock() + stack = append(stack, "cluster.BeforeProcessPipeline") + mu.Unlock() - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = append(stack, "cluster.AfterProcessPipeline") + err := hook(ctx, cmds) - return err + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("set pipeline_test_key pipeline_test_value: OK")) + mu.Lock() + stack = append(stack, "cluster.AfterProcessPipeline") + mu.Unlock() + + return err + } + + // For other commands (like ping), just pass through without expectations + // since they might fail before reaching this point + return hook(ctx, cmds) } }, }) @@ -1084,16 +1112,27 @@ var _ = Describe("ClusterClient", func() { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) - stack = append(stack, "shard.BeforeProcessPipeline") + cmdStr := cmds[0].String() - err := hook(ctx, cmds) + // Handle SET command (should succeed) + if cmdStr == "set pipeline_test_key pipeline_test_value: " { + mu.Lock() + stack = append(stack, "shard.BeforeProcessPipeline") + mu.Unlock() - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) - stack = 
append(stack, "shard.AfterProcessPipeline") + err := hook(ctx, cmds) - return err + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("set pipeline_test_key pipeline_test_value: OK")) + mu.Lock() + stack = append(stack, "shard.AfterProcessPipeline") + mu.Unlock() + + return err + } + + // For other commands (like ping), just pass through without expectations + return hook(ctx, cmds) } }, }) @@ -1101,11 +1140,17 @@ var _ = Describe("ClusterClient", func() { }) _, err = client.Pipelined(ctx, func(pipe redis.Pipeliner) error { - pipe.Ping(ctx) + pipe.Set(ctx, "pipeline_test_key", "pipeline_test_value", 0) return nil }) Expect(err).NotTo(HaveOccurred()) - Expect(stack).To(Equal([]string{ + + mu.Lock() + finalStack := make([]string, len(stack)) + copy(finalStack, stack) + mu.Unlock() + + Expect(finalStack).To(Equal([]string{ "cluster.BeforeProcessPipeline", "shard.BeforeProcessPipeline", "shard.AfterProcessPipeline", @@ -1113,7 +1158,17 @@ var _ = Describe("ClusterClient", func() { })) }) - It("supports TxPipeline hook", func() { + It("should reject ping command in pipeline", func() { + // Test that ping command fails in pipeline as expected + _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Ping(ctx) + return nil + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("redis: cannot pipeline command \"ping\" with request policy ReqAllNodes/ReqAllShards/ReqMultiShard")) + }) + + It("should support TxPipeline hook", func() { err := client.Ping(ctx).Err() Expect(err).NotTo(HaveOccurred()) @@ -1122,6 +1177,7 @@ var _ = Describe("ClusterClient", func() { }) Expect(err).NotTo(HaveOccurred()) + var mu sync.Mutex var stack []string client.AddHook(&hook{ @@ -1129,13 +1185,17 @@ var _ = Describe("ClusterClient", func() { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "cluster.BeforeProcessPipeline") + mu.Unlock() err := hook(ctx, cmds) Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "cluster.AfterProcessPipeline") + mu.Unlock() return err } @@ -1148,13 +1208,17 @@ var _ = Describe("ClusterClient", func() { return func(ctx context.Context, cmds []redis.Cmder) error { Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: ")) + mu.Lock() stack = append(stack, "shard.BeforeProcessPipeline") + mu.Unlock() err := hook(ctx, cmds) Expect(cmds).To(HaveLen(3)) Expect(cmds[1].String()).To(Equal("ping: PONG")) + mu.Lock() stack = append(stack, "shard.AfterProcessPipeline") + mu.Unlock() return err } @@ -1168,7 +1232,13 @@ var _ = Describe("ClusterClient", func() { return nil }) Expect(err).NotTo(HaveOccurred()) - Expect(stack).To(Equal([]string{ + + mu.Lock() + finalStack := make([]string, len(stack)) + copy(finalStack, stack) + mu.Unlock() + + Expect(finalStack).To(Equal([]string{ "cluster.BeforeProcessPipeline", "shard.BeforeProcessPipeline", "shard.AfterProcessPipeline", @@ -1405,12 +1475,12 @@ var _ = Describe("ClusterClient without nodes", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("Ping returns an error", func() { + It("should return an error for Ping", func() { err := client.Ping(ctx).Err() Expect(err).To(MatchError("redis: cluster has no nodes")) }) - It("pipeline returns an error", func() { + It("should return an error for pipeline", func() { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) return 
nil @@ -1432,12 +1502,12 @@ var _ = Describe("ClusterClient without valid nodes", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("returns an error", func() { + It("should return an error when cluster support is disabled", func() { err := client.Ping(ctx).Err() Expect(err).To(MatchError("ERR This instance has cluster support disabled")) }) - It("pipeline returns an error", func() { + It("should return an error for pipeline when cluster support is disabled", func() { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) return nil @@ -1467,7 +1537,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("recovers when Cluster recovers", func() { + It("should recover when Cluster recovers", func() { err := client.Ping(ctx).Err() Expect(err).To(HaveOccurred()) @@ -1485,13 +1555,13 @@ var _ = Describe("ClusterClient timeout", func() { }) testTimeout := func() { - It("Ping timeouts", func() { + It("should timeout Ping", func() { err := client.Ping(ctx).Err() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) }) - It("Pipeline timeouts", func() { + It("should timeout Pipeline", func() { _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) return nil @@ -1500,7 +1570,7 @@ var _ = Describe("ClusterClient timeout", func() { Expect(err.(net.Error).Timeout()).To(BeTrue()) }) - It("Tx timeouts", func() { + It("should timeout Tx", func() { err := client.Watch(ctx, func(tx *redis.Tx) error { return tx.Ping(ctx).Err() }, "foo") @@ -1508,7 +1578,7 @@ var _ = Describe("ClusterClient timeout", func() { Expect(err.(net.Error).Timeout()).To(BeTrue()) }) - It("Tx Pipeline timeouts", func() { + It("should timeout Tx Pipeline", func() { err := client.Watch(ctx, func(tx *redis.Tx) error { _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Ping(ctx) @@ -1567,134 +1637,1168 @@ var _ = Describe("ClusterClient timeout", func() { }) }) -var _ = Describe("ClusterClient ParseURL", func() { - cases := []struct { - test string - url string - o *redis.ClusterOptions // expected value - err error - }{ - { - test: "ParseRedisURL", - url: "redis://localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}}, - }, { - test: "ParseRedissURL", - url: "rediss://localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "MissingRedisPort", - url: "redis://localhost", - o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}}, - }, { - test: "MissingRedissPort", - url: "rediss://localhost", - o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "MultipleRedisURLs", - url: "redis://localhost:123?addr=localhost:1234&addr=localhost:12345", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}}, - }, { - test: "MultipleRedissURLs", - url: "rediss://localhost:123?addr=localhost:1234&addr=localhost:12345", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "OnlyPassword", - url: "redis://:bar@localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Password: "bar"}, - }, { - test: "OnlyUser", - url: "redis://foo@localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo"}, - }, { 
- test: "RedisUsernamePassword", - url: "redis://foo:bar@localhost:123", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo", Password: "bar"}, - }, { - test: "RedissUsernamePassword", - url: "rediss://foo:bar@localhost:123?addr=localhost:1234", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, Username: "foo", Password: "bar", TLSConfig: &tls.Config{ServerName: "localhost"}}, - }, { - test: "QueryParameters", - url: "redis://localhost:123?read_timeout=2&pool_fifo=true&addr=localhost:1234", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, ReadTimeout: 2 * time.Second, PoolFIFO: true}, - }, { - test: "DisabledTimeout", - url: "redis://localhost:123?conn_max_idle_time=0", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, - }, { - test: "DisabledTimeoutNeg", - url: "redis://localhost:123?conn_max_idle_time=-1", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, - }, { - test: "UseDefault", - url: "redis://localhost:123?conn_max_idle_time=", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, - }, { - test: "Protocol", - url: "redis://localhost:123?protocol=2", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Protocol: 2}, - }, { - test: "ClientName", - url: "redis://localhost:123?client_name=cluster_hi", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ClientName: "cluster_hi"}, - }, { - test: "UseDefaultMissing=", - url: "redis://localhost:123?conn_max_idle_time", - o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, - }, { - test: "InvalidQueryAddr", - url: "rediss://foo:bar@localhost:123?addr=rediss://foo:barr@localhost:1234", - err: errors.New(`redis: unable to parse addr param: rediss://foo:barr@localhost:1234`), - }, { - test: "InvalidInt", - url: "redis://localhost?pool_size=five", - err: errors.New(`redis: invalid pool_size number: strconv.Atoi: parsing "five": invalid syntax`), - }, { - test: "InvalidBool", - url: "redis://localhost?pool_fifo=yes", - err: errors.New(`redis: invalid pool_fifo boolean: expected true/false/1/0 or an empty string, got "yes"`), - }, { - test: "UnknownParam", - url: "redis://localhost?abc=123", - err: errors.New("redis: unexpected option: abc"), - }, { - test: "InvalidScheme", - url: "https://google.com", - err: errors.New("redis: invalid URL scheme: https"), - }, - } +var _ = Describe("Command Tips tests", func() { + var client *redis.ClusterClient + + BeforeEach(func() { + opt := redisClusterOptions() + client = cluster.newClusterClient(ctx, opt) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should verify COMMAND tips match router policy types", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + expectedPolicies := map[string]struct { + RequestPolicy string + ResponsePolicy string + }{ + "touch": { + RequestPolicy: "multi_shard", + ResponsePolicy: "agg_sum", + }, + "flushall": { + RequestPolicy: "all_shards", + ResponsePolicy: "all_succeeded", + }, + } + + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + for cmdName, expected := range expectedPolicies { + actualCmd := cmds[cmdName] + + Expect(actualCmd.Tips).NotTo(BeNil()) + + // Verify request_policy from COMMAND matches router policy + actualRequestPolicy := actualCmd.Tips.Request.String() + Expect(actualRequestPolicy).To(Equal(expected.RequestPolicy)) + + // Verify 
response_policy from COMMAND matches router policy + actualResponsePolicy := actualCmd.Tips.Response.String() + Expect(actualResponsePolicy).To(Equal(expected.ResponsePolicy)) + } + }) + + Describe("Explicit Routing Policy Tests", func() { + It("should test explicit routing policy for TOUCH", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify TOUCH command has multi_shard policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + touchCmd := cmds["touch"] + + Expect(touchCmd.Tips).NotTo(BeNil()) + Expect(touchCmd.Tips.Request.String()).To(Equal("multi_shard")) + Expect(touchCmd.Tips.Response.String()).To(Equal("agg_sum")) + + keys := []string{"key1", "key2", "key3", "key4", "key5"} + for _, key := range keys { + err := client.Set(ctx, key, "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) + } - It("match ParseClusterURL", func() { - for i := range cases { - tc := cases[i] - actual, err := redis.ParseClusterURL(tc.url) - if tc.err != nil { - Expect(err).Should(MatchError(tc.err)) - } else { + result := client.Touch(ctx, keys...) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal(int64(len(keys)))) + }) + + It("should test explicit routing policy for FLUSHALL", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify FLUSHALL command has all_shards policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + flushallCmd := cmds["flushall"] + + Expect(flushallCmd.Tips).NotTo(BeNil()) + Expect(flushallCmd.Tips.Request.String()).To(Equal("all_shards")) + Expect(flushallCmd.Tips.Response.String()).To(Equal("all_succeeded")) + + testKeys := []string{"test1", "test2", "test3"} + for _, key := range testKeys { + err := client.Set(ctx, key, "value", 0).Err() Expect(err).NotTo(HaveOccurred()) } - if err == nil { - Expect(tc.o).NotTo(BeNil()) - - Expect(tc.o.Addrs).To(Equal(actual.Addrs)) - Expect(tc.o.TLSConfig).To(Equal(actual.TLSConfig)) - Expect(tc.o.Username).To(Equal(actual.Username)) - Expect(tc.o.Password).To(Equal(actual.Password)) - Expect(tc.o.MaxRetries).To(Equal(actual.MaxRetries)) - Expect(tc.o.MinRetryBackoff).To(Equal(actual.MinRetryBackoff)) - Expect(tc.o.MaxRetryBackoff).To(Equal(actual.MaxRetryBackoff)) - Expect(tc.o.DialTimeout).To(Equal(actual.DialTimeout)) - Expect(tc.o.ReadTimeout).To(Equal(actual.ReadTimeout)) - Expect(tc.o.WriteTimeout).To(Equal(actual.WriteTimeout)) - Expect(tc.o.PoolFIFO).To(Equal(actual.PoolFIFO)) - Expect(tc.o.PoolSize).To(Equal(actual.PoolSize)) - Expect(tc.o.MinIdleConns).To(Equal(actual.MinIdleConns)) - Expect(tc.o.ConnMaxLifetime).To(Equal(actual.ConnMaxLifetime)) - Expect(tc.o.ConnMaxIdleTime).To(Equal(actual.ConnMaxIdleTime)) - Expect(tc.o.PoolTimeout).To(Equal(actual.PoolTimeout)) + err = client.FlushAll(ctx).Err() + Expect(err).NotTo(HaveOccurred()) + + for _, key := range testKeys { + exists := client.Exists(ctx, key) + Expect(exists.Val()).To(Equal(int64(0))) } + }) + + It("should test explicit routing policy for PING", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify PING command has all_shards policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + pingCmd := cmds["ping"] + Expect(pingCmd.Tips).NotTo(BeNil()) + Expect(pingCmd.Tips.Request.String()).To(Equal("all_shards")) + Expect(pingCmd.Tips.Response.String()).To(Equal("all_succeeded")) + + result := client.Ping(ctx) + 
Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("PONG")) + }) + + It("should test explicit routing policy for DBSIZE", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify DBSIZE command has all_shards policy with agg_sum response + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + dbsizeCmd := cmds["dbsize"] + Expect(dbsizeCmd.Tips).NotTo(BeNil()) + Expect(dbsizeCmd.Tips.Request.String()).To(Equal("all_shards")) + Expect(dbsizeCmd.Tips.Response.String()).To(Equal("agg_sum")) + + testKeys := []string{"dbsize_test1", "dbsize_test2", "dbsize_test3"} + for _, key := range testKeys { + err := client.Set(ctx, key, "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) + } + + size := client.DBSize(ctx) + Expect(size.Err()).NotTo(HaveOccurred()) + Expect(size.Val()).To(BeNumerically(">=", int64(len(testKeys)))) + }) + }) + + Describe("DDL Commands Routing Policy Tests", func() { + BeforeEach(func() { + info := client.Info(ctx, "modules") + if info.Err() != nil || !strings.Contains(info.Val(), "search") { + Skip("Search module not available") + } + }) + + It("should test DDL commands routing policy for FT.CREATE", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify FT.CREATE command routing policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + ftCreateCmd, exists := cmds["ft.create"] + if !exists || ftCreateCmd.Tips == nil { + Skip("FT.CREATE command or tips not available") + } + + // DDL commands should NOT be broadcasted - they should go to coordinator only + Expect(ftCreateCmd.Tips).NotTo(BeNil()) + requestPolicy := ftCreateCmd.Tips.Request.String() + Expect(requestPolicy).NotTo(Equal("all_shards")) + Expect(requestPolicy).NotTo(Equal("all_nodes")) + + indexName := "test_index_create" + client.FTDropIndex(ctx, indexName) + + result := client.FTCreate(ctx, indexName, + &redis.FTCreateOptions{ + OnHash: true, + Prefix: []interface{}{"doc:"}, + }, + &redis.FieldSchema{ + FieldName: "title", + FieldType: redis.SearchFieldTypeText, + }) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("OK")) + + infoResult := client.FTInfo(ctx, indexName) + Expect(infoResult.Err()).NotTo(HaveOccurred()) + Expect(infoResult.Val().IndexName).To(Equal(indexName)) + client.FTDropIndex(ctx, indexName) + }) + + It("should test DDL commands routing policy for FT.ALTER", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Verify FT.ALTER command routing policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + ftAlterCmd, exists := cmds["ft.alter"] + if !exists || ftAlterCmd.Tips == nil { + Skip("FT.ALTER command or tips not available") + } + + Expect(ftAlterCmd.Tips).NotTo(BeNil()) + requestPolicy := ftAlterCmd.Tips.Request.String() + Expect(requestPolicy).NotTo(Equal("all_shards")) + Expect(requestPolicy).NotTo(Equal("all_nodes")) + + indexName := "test_index_alter" + client.FTDropIndex(ctx, indexName) + + result := client.FTCreate(ctx, indexName, + &redis.FTCreateOptions{ + OnHash: true, + Prefix: []interface{}{"doc:"}, + }, + &redis.FieldSchema{ + FieldName: "title", + FieldType: redis.SearchFieldTypeText, + }) + Expect(result.Err()).NotTo(HaveOccurred()) + + alterResult := client.FTAlter(ctx, indexName, false, + []interface{}{"description", redis.SearchFieldTypeText.String()}) + Expect(alterResult.Err()).NotTo(HaveOccurred()) + 
Expect(alterResult.Val()).To(Equal("OK")) + client.FTDropIndex(ctx, indexName) + }) + + It("should route keyed commands to correct shard based on hash slot", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // Single keyed command should go to exactly one shard - determined by hash slot + testKey := "test_key_12345" + testValue := "test_value" + + result := client.Set(ctx, testKey, testValue, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("OK")) + + time.Sleep(200 * time.Millisecond) + + var targetNodeAddr string + foundNodes := 0 + + for _, node := range masterNodes { + getResult := node.client.Get(ctx, testKey) + if getResult.Err() == nil && getResult.Val() == testValue { + foundNodes++ + targetNodeAddr = node.addr + } else { + } + } + + Expect(foundNodes).To(Equal(1)) + Expect(targetNodeAddr).NotTo(BeEmpty()) + + // Multiple commands with same key should go to same shard + finalValue := "" + for i := 0; i < 5; i++ { + finalValue = fmt.Sprintf("value_%d", i) + result := client.Set(ctx, testKey, finalValue, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("OK")) + } + + time.Sleep(200 * time.Millisecond) + + var currentTargetNode string + foundNodesAfterUpdate := 0 + + for _, node := range masterNodes { + getResult := node.client.Get(ctx, testKey) + if getResult.Err() == nil && getResult.Val() == finalValue { + foundNodesAfterUpdate++ + currentTargetNode = node.addr + } else { + } + } + + // All commands with same key should go to same shard + Expect(foundNodesAfterUpdate).To(Equal(1)) + Expect(currentTargetNode).To(Equal(targetNodeAddr)) + }) + + It("should aggregate responses according to explicit aggregation policies", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // verify TOUCH command has agg_sum policy + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + touchCmd, exists := cmds["touch"] + if !exists || touchCmd.Tips == nil { + Skip("TOUCH command or tips not available") + } + + Expect(touchCmd.Tips.Response.String()).To(Equal("agg_sum")) + + testKeys := []string{ + "touch_test_key_1111", // These keys should map to different hash slots + "touch_test_key_2222", + "touch_test_key_3333", + "touch_test_key_4444", + "touch_test_key_5555", + } + + // Set keys on different shards + keysPerShard := make(map[string][]string) + for _, key := range testKeys { + result := client.Set(ctx, key, "test_value", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + + // Find which shard contains this key + for _, node := range 
masterNodes { + getResult := node.client.Get(ctx, key) + if getResult.Err() == nil { + keysPerShard[node.addr] = append(keysPerShard[node.addr], key) + break + } + } + } + + // Verify keys are distributed across multiple shards + shardsWithKeys := len(keysPerShard) + Expect(shardsWithKeys).To(BeNumerically(">", 1)) + + // Execute TOUCH command on all keys - this should aggregate results using agg_sum + touchResult := client.Touch(ctx, testKeys...) + Expect(touchResult.Err()).NotTo(HaveOccurred()) + + totalTouched := touchResult.Val() + Expect(totalTouched).To(Equal(int64(len(testKeys)))) + + totalKeysOnShards := 0 + for _, keys := range keysPerShard { + totalKeysOnShards += len(keys) + } + + Expect(totalKeysOnShards).To(Equal(len(testKeys))) + + // FLUSHALL command with all_succeeded aggregation policy + flushallCmd, exists := cmds["flushall"] + if !exists || flushallCmd.Tips == nil { + Skip("FLUSHALL command or tips not available") + } + + Expect(flushallCmd.Tips.Response.String()).To(Equal("all_succeeded")) + + for i := 0; i < len(masterNodes); i++ { + testKey := fmt.Sprintf("flush_test_key_%d_%d", i, time.Now().UnixNano()) + result := client.Set(ctx, testKey, "test_data", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + } + + flushResult := client.FlushAll(ctx) + Expect(flushResult.Err()).NotTo(HaveOccurred()) + Expect(flushResult.Val()).To(Equal("OK")) + + for _, node := range masterNodes { + dbSizeResult := node.client.DBSize(ctx) + Expect(dbSizeResult.Err()).NotTo(HaveOccurred()) + Expect(dbSizeResult.Val()).To(Equal(int64(0))) + } + + // WAIT command aggregation policy - verify agg_min policy + waitCmd, exists := cmds["wait"] + if !exists || waitCmd.Tips == nil { + Skip("WAIT command or tips not available") + } + + Expect(waitCmd.Tips.Response.String()).To(Equal("agg_min")) + + // Set up some data to replicate + testKey := "wait_test_key_1111" + result := client.Set(ctx, testKey, "test_value", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + + // Execute WAIT command - should aggregate using agg_min across all shards + // WAIT waits for a given number of replicas to acknowledge writes + // With agg_min policy, it returns the minimum number of replicas that acknowledged + waitResult := client.Wait(ctx, 0, 1000) // Wait for 0 replicas with 1 second timeout + Expect(waitResult.Err()).NotTo(HaveOccurred()) + + // The result should be the minimum number of replicas across all shards + // Since we're asking for 0 replicas, all shards should return 0, so min is 0 + minReplicas := waitResult.Val() + Expect(minReplicas).To(BeNumerically(">=", 0)) + + // SCRIPT EXISTS command aggregation policy - verify agg_logical_and policy + scriptExistsCmd, exists := cmds["script exists"] + if !exists || scriptExistsCmd.Tips == nil { + Skip("SCRIPT EXISTS command or tips not available") + } + + Expect(scriptExistsCmd.Tips.Response.String()).To(Equal("agg_logical_and")) + + // Load a script on all shards + testScript := "return 'hello'" + scriptLoadResult := client.ScriptLoad(ctx, testScript) + Expect(scriptLoadResult.Err()).NotTo(HaveOccurred()) + scriptSHA := scriptLoadResult.Val() + + // Verify script exists on all shards using SCRIPT EXISTS + // With agg_logical_and policy, it should return true only if script exists on ALL shards + scriptExistsResult := client.ScriptExists(ctx, scriptSHA) + Expect(scriptExistsResult.Err()).NotTo(HaveOccurred()) + + existsResults := scriptExistsResult.Val() + Expect(len(existsResults)).To(Equal(1)) + Expect(existsResults[0]).To(BeTrue()) // Script should exist on all 
shards + + // Test with a non-existent script SHA + nonExistentSHA := "0000000000000000000000000000000000000000" + scriptExistsResult2 := client.ScriptExists(ctx, nonExistentSHA) + Expect(scriptExistsResult2.Err()).NotTo(HaveOccurred()) + + existsResults2 := scriptExistsResult2.Val() + Expect(len(existsResults2)).To(Equal(1)) + Expect(existsResults2[0]).To(BeFalse()) // Script should not exist on any shard + + // Test with mixed scenario - flush scripts from one shard manually + // This is harder to test in practice since SCRIPT FLUSH affects all shards + // So we'll just verify the basic functionality works + }) + + It("should verify command aggregation policies", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + commandPolicies := map[string]string{ + "touch": "agg_sum", + "flushall": "all_succeeded", + "pfcount": "all_succeeded", + "exists": "agg_sum", + "script exists": "agg_logical_and", + "wait": "agg_min", + } + + for cmdName, expectedPolicy := range commandPolicies { + cmd, exists := cmds[cmdName] + if !exists { + continue + } + + if cmd.Tips == nil { + continue + } + + actualPolicy := cmd.Tips.Response.String() + Expect(actualPolicy).To(Equal(expectedPolicy)) + } + }) + + It("should properly aggregate responses from keyless commands executed on multiple shards", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // PING command with all_shards policy - should aggregate responses + cmds, err := client.Command(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + + pingCmd, exists := cmds["ping"] + if exists && pingCmd.Tips != nil { + } + + pingResult := client.Ping(ctx) + Expect(pingResult.Err()).NotTo(HaveOccurred()) + Expect(pingResult.Val()).To(Equal("PONG")) + + // Verify PING was executed on all shards by checking individual nodes + for _, node := range masterNodes { + nodePingResult := node.client.Ping(ctx) + Expect(nodePingResult.Err()).NotTo(HaveOccurred()) + Expect(nodePingResult.Val()).To(Equal("PONG")) + } + + // Test 2: DBSIZE command aggregation across shards - verify agg_sum policy + testKeys := []string{ + "dbsize_test_key_1111", + "dbsize_test_key_2222", + "dbsize_test_key_3333", + "dbsize_test_key_4444", + } + + for _, key := range testKeys { + result := client.Set(ctx, key, "test_value", 0) + Expect(result.Err()).NotTo(HaveOccurred()) + } + + dbSizeResult := client.DBSize(ctx) + Expect(dbSizeResult.Err()).NotTo(HaveOccurred()) + + totalSize := dbSizeResult.Val() + Expect(totalSize).To(BeNumerically(">=", int64(len(testKeys)))) + + // Verify aggregation by manually getting sizes from each shard + totalManualSize := int64(0) + + for _, node := range masterNodes { + nodeDbSizeResult := node.client.DBSize(ctx) + Expect(nodeDbSizeResult.Err()).NotTo(HaveOccurred()) + + nodeSize := nodeDbSizeResult.Val() + totalManualSize += nodeSize + } + + // Verify aggregation worked correctly + Expect(totalSize).To(Equal(totalManualSize)) + }) + + It("should properly aggregate responses 
from keyed commands executed on multiple shards", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 1)) + + // MGET command aggregation across multiple keys on different shards - verify all_succeeded policy with keyed aggregation + testData := map[string]string{ + "mget_test_key_1111": "value1", + "mget_test_key_2222": "value2", + "mget_test_key_3333": "value3", + "mget_test_key_4444": "value4", + "mget_test_key_5555": "value5", + } + + keyLocations := make(map[string]string) + for key, value := range testData { + + result := client.Set(ctx, key, value, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + + for _, node := range masterNodes { + getResult := node.client.Get(ctx, key) + if getResult.Err() == nil && getResult.Val() == value { + keyLocations[key] = node.addr + break + } + } + } + + shardsUsed := make(map[string]bool) + for _, shardAddr := range keyLocations { + shardsUsed[shardAddr] = true + } + Expect(len(shardsUsed)).To(BeNumerically(">", 1)) + + keys := make([]string, 0, len(testData)) + expectedValues := make([]interface{}, 0, len(testData)) + + for key, value := range testData { + keys = append(keys, key) + expectedValues = append(expectedValues, value) + } + + mgetResult := client.MGet(ctx, keys...) + Expect(mgetResult.Err()).NotTo(HaveOccurred()) + + actualValues := mgetResult.Val() + Expect(len(actualValues)).To(Equal(len(keys))) + Expect(actualValues).To(ConsistOf(expectedValues)) + + // Verify all values are correctly aggregated + for i, key := range keys { + expectedValue := testData[key] + actualValue := actualValues[i] + Expect(actualValue).To(Equal(expectedValue)) + } + + // DEL command aggregation across multiple keys on different shards + delResult := client.Del(ctx, keys...) + Expect(delResult.Err()).NotTo(HaveOccurred()) + + deletedCount := delResult.Val() + Expect(deletedCount).To(Equal(int64(len(keys)))) + + // Verify keys are actually deleted from their respective shards + for key, shardAddr := range keyLocations { + var targetNode *masterNode + for i := range masterNodes { + if masterNodes[i].addr == shardAddr { + targetNode = &masterNodes[i] + break + } + } + Expect(targetNode).NotTo(BeNil()) + + getResult := targetNode.client.Get(ctx, key) + Expect(getResult.Err()).To(HaveOccurred()) + } + + // EXISTS command aggregation across multiple keys + existsTestData := map[string]string{ + "exists_agg_key_1111": "value1", + "exists_agg_key_2222": "value2", + "exists_agg_key_3333": "value3", + } + + existsKeys := make([]string, 0, len(existsTestData)) + for key, value := range existsTestData { + result := client.Set(ctx, key, value, 0) + Expect(result.Err()).NotTo(HaveOccurred()) + existsKeys = append(existsKeys, key) + } + + // Add a non-existent key to the list + nonExistentKey := "non_existent_key_9999" + existsKeys = append(existsKeys, nonExistentKey) + + existsResult := client.Exists(ctx, existsKeys...) 
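+ // EXISTS fans out to the shards owning each key and, per the agg_sum response
+ // policy this suite asserts for "exists" elsewhere, the per-shard counts are
+ // summed before being returned. An illustrative split (actual slot placement
+ // may differ) for the three stored keys plus the non-existent one:
+ //
+ //   shard A: EXISTS exists_agg_key_1111                     -> 1
+ //   shard B: EXISTS exists_agg_key_2222 exists_agg_key_3333 -> 2
+ //   shard C: EXISTS non_existent_key_9999                   -> 0
+ //
+ // so the aggregated value is 1 + 2 + 0 = 3, i.e. int64(len(existsTestData)).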
+ Expect(existsResult.Err()).NotTo(HaveOccurred()) + + existsCount := existsResult.Val() + Expect(existsCount).To(Equal(int64(len(existsTestData)))) + }) + + It("should propagate coordinator errors to client without modification", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + type masterNode struct { + client *redis.Client + addr string + } + var masterNodes []masterNode + var mu sync.Mutex + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + mu.Lock() + masterNodes = append(masterNodes, masterNode{ + client: master, + addr: addr, + }) + mu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(masterNodes)).To(BeNumerically(">", 0)) + + invalidSlotResult := client.ClusterAddSlotsRange(ctx, 99999, 100000) + coordinatorErr := invalidSlotResult.Err() + + if coordinatorErr != nil { + // Verify the error is a Redis error + var redisErr redis.Error + Expect(errors.As(coordinatorErr, &redisErr)).To(BeTrue()) + + // Verify error message is preserved exactly as returned by coordinator + errorMsg := coordinatorErr.Error() + Expect(errorMsg).To(SatisfyAny( + ContainSubstring("slot"), + ContainSubstring("ERR"), + ContainSubstring("Invalid"), + )) + + // Test that the same error occurs when calling coordinator directly + coordinatorNode := masterNodes[0] + directResult := coordinatorNode.client.ClusterAddSlotsRange(ctx, 99999, 100000) + directErr := directResult.Err() + + if directErr != nil { + Expect(coordinatorErr.Error()).To(Equal(directErr.Error())) + } + } + + // Try cluster forget with invalid node ID + invalidNodeID := "invalid_node_id_12345" + forgetResult := client.ClusterForget(ctx, invalidNodeID) + forgetErr := forgetResult.Err() + + if forgetErr != nil { + var redisErr redis.Error + Expect(errors.As(forgetErr, &redisErr)).To(BeTrue()) + + errorMsg := forgetErr.Error() + Expect(errorMsg).To(SatisfyAny( + ContainSubstring("Unknown node"), + ContainSubstring("Invalid node"), + ContainSubstring("ERR"), + )) + + coordinatorNode := masterNodes[0] + directForgetResult := coordinatorNode.client.ClusterForget(ctx, invalidNodeID) + directForgetErr := directForgetResult.Err() + + if directForgetErr != nil { + Expect(forgetErr.Error()).To(Equal(directForgetErr.Error())) + } + } + + // Test error type preservation and format + keySlotResult := client.ClusterKeySlot(ctx, "") + keySlotErr := keySlotResult.Err() + + if keySlotErr != nil { + var redisErr redis.Error + Expect(errors.As(keySlotErr, &redisErr)).To(BeTrue()) + + errorMsg := keySlotErr.Error() + Expect(len(errorMsg)).To(BeNumerically(">", 0)) + Expect(errorMsg).NotTo(ContainSubstring("wrapped")) + Expect(errorMsg).NotTo(ContainSubstring("context")) + } + + // Verify error propagation consistency + clusterInfoResult := client.ClusterInfo(ctx) + clusterInfoErr := clusterInfoResult.Err() + + if clusterInfoErr != nil { + var redisErr redis.Error + Expect(errors.As(clusterInfoErr, &redisErr)).To(BeTrue()) + + coordinatorNode := masterNodes[0] + directInfoResult := coordinatorNode.client.ClusterInfo(ctx) + directInfoErr := directInfoResult.Err() + + if directInfoErr != nil { + Expect(clusterInfoErr.Error()).To(Equal(directInfoErr.Error())) + } + } + + // Verify no error modification in router + invalidReplicateResult := client.ClusterReplicate(ctx, "00000000000000000000000000000000invalid00") + invalidReplicateErr := invalidReplicateResult.Err() + + if invalidReplicateErr != nil { + var redisErr redis.Error + 
Expect(errors.As(invalidReplicateErr, &redisErr)).To(BeTrue()) + + errorMsg := invalidReplicateErr.Error() + Expect(errorMsg).NotTo(ContainSubstring("router")) + Expect(errorMsg).NotTo(ContainSubstring("cluster client")) + Expect(errorMsg).NotTo(ContainSubstring("failed to execute")) + + Expect(errorMsg).To(SatisfyAny( + HavePrefix("ERR"), + ContainSubstring("Invalid"), + ContainSubstring("Unknown"), + )) + } + }) + + It("should route keyless commands to arbitrary shards using round robin", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + var numMasters int + var numMastersMu sync.Mutex + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + numMastersMu.Lock() + numMasters++ + numMastersMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(numMasters).To(BeNumerically(">", 1)) + + err = client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + return master.ConfigResetStat(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + // Helper function to get ECHO command counts from all nodes + getEchoCounts := func() map[string]int { + echoCounts := make(map[string]int) + var echoCountsMu sync.Mutex + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + info := master.Info(ctx, "server") + Expect(info.Err()).NotTo(HaveOccurred()) + + serverInfo := info.Val() + portStart := strings.Index(serverInfo, "tcp_port:") + portLine := serverInfo[portStart:] + portEnd := strings.Index(portLine, "\r\n") + if portEnd == -1 { + portEnd = len(portLine) + } + port := strings.TrimPrefix(portLine[:portEnd], "tcp_port:") + + commandStats := master.Info(ctx, "commandstats") + count := 0 + if commandStats.Err() == nil { + stats := commandStats.Val() + cmdStatKey := "cmdstat_echo:" + if strings.Contains(stats, cmdStatKey) { + statStart := strings.Index(stats, cmdStatKey) + statLine := stats[statStart:] + statEnd := strings.Index(statLine, "\r\n") + if statEnd == -1 { + statEnd = len(statLine) + } + statLine = statLine[:statEnd] + + callsStart := strings.Index(statLine, "calls=") + if callsStart != -1 { + callsStr := statLine[callsStart+6:] + callsEnd := strings.Index(callsStr, ",") + if callsEnd == -1 { + callsEnd = strings.Index(callsStr, "\r") + if callsEnd == -1 { + callsEnd = len(callsStr) + } + } + if callsCount, err := strconv.Atoi(callsStr[:callsEnd]); err == nil { + count = callsCount + } + } + } + } + + echoCountsMu.Lock() + echoCounts[port] = count + echoCountsMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + return echoCounts + } + + // Single ECHO command should go to exactly one shard + result := client.Echo(ctx, "single_test") + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal("single_test")) + + time.Sleep(200 * time.Millisecond) + + // Verify single command went to exactly one shard + echoCounts := getEchoCounts() + shardsWithEcho := 0 + for _, count := range echoCounts { + if count > 0 { + shardsWithEcho++ + Expect(count).To(Equal(1)) + } + } + Expect(shardsWithEcho).To(Equal(1)) + + // Test Multiple ECHO commands should distribute across all shards using round robin + numCommands := numMasters * 3 + + for i := 0; i < numCommands; i++ { + result := client.Echo(ctx, fmt.Sprintf("multi_test_%d", i)) + Expect(result.Err()).NotTo(HaveOccurred()) + Expect(result.Val()).To(Equal(fmt.Sprintf("multi_test_%d", i))) + } + + time.Sleep(200 * time.Millisecond) + + echoCounts = getEchoCounts() + totalEchos := 0 + 
shardsWithEchos := 0 + for _, count := range echoCounts { + if count > 0 { + shardsWithEchos++ + } + totalEchos += count + } + + // All shards should now have some ECHO commands + Expect(shardsWithEchos).To(Equal(numMasters)) + + expectedTotal := 1 + numCommands + Expect(totalEchos).To(Equal(expectedTotal)) + }) + }) + + var _ = Describe("ClusterClient ParseURL", func() { + cases := []struct { + test string + url string + o *redis.ClusterOptions // expected value + err error + }{ + { + test: "ParseRedisURL", + url: "redis://localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}}, + }, { + test: "ParseRedissURL", + url: "rediss://localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "MissingRedisPort", + url: "redis://localhost", + o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}}, + }, { + test: "MissingRedissPort", + url: "rediss://localhost", + o: &redis.ClusterOptions{Addrs: []string{"localhost:6379"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "MultipleRedisURLs", + url: "redis://localhost:123?addr=localhost:1234&addr=localhost:12345", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}}, + }, { + test: "MultipleRedissURLs", + url: "rediss://localhost:123?addr=localhost:1234&addr=localhost:12345", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234", "localhost:12345"}, TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "OnlyPassword", + url: "redis://:bar@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Password: "bar"}, + }, { + test: "OnlyUser", + url: "redis://foo@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo"}, + }, { + test: "RedisUsernamePassword", + url: "redis://foo:bar@localhost:123", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Username: "foo", Password: "bar"}, + }, { + test: "RedissUsernamePassword", + url: "rediss://foo:bar@localhost:123?addr=localhost:1234", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, Username: "foo", Password: "bar", TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { + test: "QueryParameters", + url: "redis://localhost:123?read_timeout=2&pool_fifo=true&addr=localhost:1234", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123", "localhost:1234"}, ReadTimeout: 2 * time.Second, PoolFIFO: true}, + }, { + test: "DisabledTimeout", + url: "redis://localhost:123?conn_max_idle_time=0", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, + }, { + test: "DisabledTimeoutNeg", + url: "redis://localhost:123?conn_max_idle_time=-1", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: -1}, + }, { + test: "UseDefault", + url: "redis://localhost:123?conn_max_idle_time=", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, + }, { + test: "Protocol", + url: "redis://localhost:123?protocol=2", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, Protocol: 2}, + }, { + test: "ClientName", + url: "redis://localhost:123?client_name=cluster_hi", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ClientName: "cluster_hi"}, + }, { + test: "UseDefaultMissing=", + url: "redis://localhost:123?conn_max_idle_time", + o: &redis.ClusterOptions{Addrs: []string{"localhost:123"}, ConnMaxIdleTime: 0}, + }, { + test: 
"InvalidQueryAddr", + url: "rediss://foo:bar@localhost:123?addr=rediss://foo:barr@localhost:1234", + err: errors.New(`redis: unable to parse addr param: rediss://foo:barr@localhost:1234`), + }, { + test: "InvalidInt", + url: "redis://localhost?pool_size=five", + err: errors.New(`redis: invalid pool_size number: strconv.Atoi: parsing "five": invalid syntax`), + }, { + test: "InvalidBool", + url: "redis://localhost?pool_fifo=yes", + err: errors.New(`redis: invalid pool_fifo boolean: expected true/false/1/0 or an empty string, got "yes"`), + }, { + test: "UnknownParam", + url: "redis://localhost?abc=123", + err: errors.New("redis: unexpected option: abc"), + }, { + test: "InvalidScheme", + url: "https://google.com", + err: errors.New("redis: invalid URL scheme: https"), + }, } + + It("should match ParseClusterURL", func() { + for i := range cases { + tc := cases[i] + actual, err := redis.ParseClusterURL(tc.url) + if tc.err != nil { + Expect(err).Should(MatchError(tc.err)) + } else { + Expect(err).NotTo(HaveOccurred()) + } + + if err == nil { + Expect(tc.o).NotTo(BeNil()) + + Expect(tc.o.Addrs).To(Equal(actual.Addrs)) + Expect(tc.o.TLSConfig).To(Equal(actual.TLSConfig)) + Expect(tc.o.Username).To(Equal(actual.Username)) + Expect(tc.o.Password).To(Equal(actual.Password)) + Expect(tc.o.MaxRetries).To(Equal(actual.MaxRetries)) + Expect(tc.o.MinRetryBackoff).To(Equal(actual.MinRetryBackoff)) + Expect(tc.o.MaxRetryBackoff).To(Equal(actual.MaxRetryBackoff)) + Expect(tc.o.DialTimeout).To(Equal(actual.DialTimeout)) + Expect(tc.o.ReadTimeout).To(Equal(actual.ReadTimeout)) + Expect(tc.o.WriteTimeout).To(Equal(actual.WriteTimeout)) + Expect(tc.o.PoolFIFO).To(Equal(actual.PoolFIFO)) + Expect(tc.o.PoolSize).To(Equal(actual.PoolSize)) + Expect(tc.o.MinIdleConns).To(Equal(actual.MinIdleConns)) + Expect(tc.o.ConnMaxLifetime).To(Equal(actual.ConnMaxLifetime)) + Expect(tc.o.ConnMaxIdleTime).To(Equal(actual.ConnMaxIdleTime)) + Expect(tc.o.PoolTimeout).To(Equal(actual.PoolTimeout)) + } + } + }) + + It("should distribute keyless commands randomly across shards using random shard picker", func() { + SkipBeforeRedisVersion(7.9, "The tips are included from Redis 8") + + // Create a cluster client with random shard picker + opt := redisClusterOptions() + opt.ShardPicker = &routing.RandomPicker{} + randomClient := cluster.newClusterClient(ctx, opt) + defer randomClient.Close() + + Eventually(func() error { + return randomClient.Ping(ctx).Err() + }, 30*time.Second).ShouldNot(HaveOccurred()) + + var numMasters int + var numMastersMu sync.Mutex + err := randomClient.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + numMastersMu.Lock() + numMasters++ + numMastersMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(numMasters).To(BeNumerically(">", 1)) + + // Reset command statistics on all masters + err = randomClient.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + return master.ConfigResetStat(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + // Helper function to get ECHO command counts from all nodes + getEchoCounts := func() map[string]int { + echoCounts := make(map[string]int) + var echoCountsMu sync.Mutex + err := randomClient.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + addr := master.Options().Addr + port := addr[strings.LastIndex(addr, ":")+1:] + + info, err := master.Info(ctx, "commandstats").Result() + if err != nil { + return err + } + + count := 0 + if strings.Contains(info, "cmdstat_echo:") { + 
lines := strings.Split(info, "\n") + for _, line := range lines { + if strings.HasPrefix(line, "cmdstat_echo:") { + parts := strings.Split(line, ",") + if len(parts) > 0 { + callsPart := strings.Split(parts[0], "=") + if len(callsPart) > 1 { + if parsedCount, parseErr := strconv.Atoi(callsPart[1]); parseErr == nil { + count = parsedCount + } + } + } + break + } + } + } + + echoCountsMu.Lock() + echoCounts[port] = count + echoCountsMu.Unlock() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + return echoCounts + } + + // Execute multiple ECHO commands and measure distribution + numCommands := 100 + for i := 0; i < numCommands; i++ { + result := randomClient.Echo(ctx, fmt.Sprintf("random_test_%d", i)) + Expect(result.Err()).NotTo(HaveOccurred()) + } + + echoCounts := getEchoCounts() + + totalEchos := 0 + shardsWithEchos := 0 + + for _, count := range echoCounts { + if count > 0 { + shardsWithEchos++ + } + totalEchos += count + } + + Expect(totalEchos).To(Equal(numCommands)) + Expect(shardsWithEchos).To(BeNumerically(">=", 2)) + }) }) }) diff --git a/probabilistic.go b/probabilistic.go index 02ca263cb..b70765807 100644 --- a/probabilistic.go +++ b/probabilistic.go @@ -225,8 +225,9 @@ type ScanDumpCmd struct { func newScanDumpCmd(ctx context.Context, args ...interface{}) *ScanDumpCmd { return &ScanDumpCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeScanDump, }, } } @@ -270,6 +271,13 @@ func (cmd *ScanDumpCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ScanDumpCmd) Clone() Cmder { + return &ScanDumpCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // ScanDump is a simple struct, can be copied directly + } +} + // Returns information about a Bloom filter. // For more information - https://redis.io/commands/bf.info/ func (c cmdable) BFInfo(ctx context.Context, key string) *BFInfoCmd { @@ -296,8 +304,9 @@ type BFInfoCmd struct { func NewBFInfoCmd(ctx context.Context, args ...interface{}) *BFInfoCmd { return &BFInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBFInfo, }, } } @@ -388,6 +397,13 @@ func (cmd *BFInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *BFInfoCmd) Clone() Cmder { + return &BFInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // BFInfo is a simple struct, can be copied directly + } +} + // BFInfoCapacity returns information about the capacity of a Bloom filter. // For more information - https://redis.io/commands/bf.info/ func (c cmdable) BFInfoCapacity(ctx context.Context, key string) *BFInfoCmd { @@ -625,8 +641,9 @@ type CFInfoCmd struct { func NewCFInfoCmd(ctx context.Context, args ...interface{}) *CFInfoCmd { return &CFInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCFInfo, }, } } @@ -692,6 +709,13 @@ func (cmd *CFInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *CFInfoCmd) Clone() Cmder { + return &CFInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // CFInfo is a simple struct, can be copied directly + } +} + // CFInfo returns information about a Cuckoo filter. 
// For more information - https://redis.io/commands/cf.info/ func (c cmdable) CFInfo(ctx context.Context, key string) *CFInfoCmd { @@ -787,8 +811,9 @@ type CMSInfoCmd struct { func NewCMSInfoCmd(ctx context.Context, args ...interface{}) *CMSInfoCmd { return &CMSInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCMSInfo, }, } } @@ -843,6 +868,13 @@ func (cmd *CMSInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *CMSInfoCmd) Clone() Cmder { + return &CMSInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // CMSInfo is a simple struct, can be copied directly + } +} + // CMSInfo returns information about a Count-Min Sketch filter. // For more information - https://redis.io/commands/cms.info/ func (c cmdable) CMSInfo(ctx context.Context, key string) *CMSInfoCmd { @@ -980,8 +1012,9 @@ type TopKInfoCmd struct { func NewTopKInfoCmd(ctx context.Context, args ...interface{}) *TopKInfoCmd { return &TopKInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTopKInfo, }, } } @@ -1038,6 +1071,13 @@ func (cmd *TopKInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TopKInfoCmd) Clone() Cmder { + return &TopKInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // TopKInfo is a simple struct, can be copied directly + } +} + // TopKInfo returns information about a Top-K filter. // For more information - https://redis.io/commands/topk.info/ func (c cmdable) TopKInfo(ctx context.Context, key string) *TopKInfoCmd { @@ -1243,8 +1283,9 @@ type TDigestInfoCmd struct { func NewTDigestInfoCmd(ctx context.Context, args ...interface{}) *TDigestInfoCmd { return &TDigestInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTDigestInfo, }, } } @@ -1311,6 +1352,13 @@ func (cmd *TDigestInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TDigestInfoCmd) Clone() Cmder { + return &TDigestInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // TDigestInfo is a simple struct, can be copied directly + } +} + // TDigestInfo returns information about a t-Digest data structure. // For more information - https://redis.io/commands/tdigest.info/ func (c cmdable) TDigestInfo(ctx context.Context, key string) *TDigestInfoCmd { diff --git a/search_commands.go b/search_commands.go index b31baaa76..c69853bf0 100644 --- a/search_commands.go +++ b/search_commands.go @@ -657,8 +657,9 @@ func ProcessAggregateResult(data []interface{}) (*FTAggregateResult, error) { func NewAggregateCmd(ctx context.Context, args ...interface{}) *AggregateCmd { return &AggregateCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeAggregate, }, } } @@ -699,6 +700,31 @@ func (cmd *AggregateCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *AggregateCmd) Clone() Cmder { + var val *FTAggregateResult + if cmd.val != nil { + val = &FTAggregateResult{ + Total: cmd.val.Total, + } + if cmd.val.Rows != nil { + val.Rows = make([]AggregateRow, len(cmd.val.Rows)) + for i, row := range cmd.val.Rows { + val.Rows[i] = AggregateRow{} + if row.Fields != nil { + val.Rows[i].Fields = make(map[string]interface{}, len(row.Fields)) + for k, v := range row.Fields { + val.Rows[i].Fields[k] = v + } + } + } + } + } + return &AggregateCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // FTAggregateWithArgs - Performs a search query on an index and applies a series of aggregate transformations to the result. 
diff --git a/search_commands.go b/search_commands.go
index b31baaa76..c69853bf0 100644
--- a/search_commands.go
+++ b/search_commands.go
@@ -657,8 +657,9 @@ func ProcessAggregateResult(data []interface{}) (*FTAggregateResult, error) {
 func NewAggregateCmd(ctx context.Context, args ...interface{}) *AggregateCmd {
 	return &AggregateCmd{
 		baseCmd: baseCmd{
-			ctx:  ctx,
-			args: args,
+			ctx:     ctx,
+			args:    args,
+			cmdType: CmdTypeAggregate,
 		},
 	}
 }
@@ -699,6 +700,31 @@ func (cmd *AggregateCmd) readReply(rd *proto.Reader) (err error) {
 	return nil
 }
 
+func (cmd *AggregateCmd) Clone() Cmder {
+	var val *FTAggregateResult
+	if cmd.val != nil {
+		val = &FTAggregateResult{
+			Total: cmd.val.Total,
+		}
+		if cmd.val.Rows != nil {
+			val.Rows = make([]AggregateRow, len(cmd.val.Rows))
+			for i, row := range cmd.val.Rows {
+				val.Rows[i] = AggregateRow{}
+				if row.Fields != nil {
+					val.Rows[i].Fields = make(map[string]interface{}, len(row.Fields))
+					for k, v := range row.Fields {
+						val.Rows[i].Fields[k] = v
+					}
+				}
+			}
+		}
+	}
+	return &AggregateCmd{
+		baseCmd: cmd.cloneBaseCmd(),
+		val:     val,
+	}
+}
+
 // FTAggregateWithArgs - Performs a search query on an index and applies a series of aggregate transformations to the result.
 // The 'index' parameter specifies the index to search, and the 'query' parameter specifies the search query.
 // This function also allows for specifying additional options such as: Verbatim, LoadAll, Load, Timeout, GroupBy, SortBy, SortByMax, Apply, LimitOffset, Limit, Filter, WithCursor, Params, and DialectVersion.
@@ -1382,8 +1408,9 @@ type FTInfoCmd struct {
 func newFTInfoCmd(ctx context.Context, args ...interface{}) *FTInfoCmd {
 	return &FTInfoCmd{
 		baseCmd: baseCmd{
-			ctx:  ctx,
-			args: args,
+			ctx:     ctx,
+			args:    args,
+			cmdType: CmdTypeFTInfo,
 		},
 	}
 }
@@ -1445,6 +1472,68 @@ func (cmd *FTInfoCmd) readReply(rd *proto.Reader) (err error) {
 	return nil
 }
 
+func (cmd *FTInfoCmd) Clone() Cmder {
+	val := FTInfoResult{
+		IndexErrors:              cmd.val.IndexErrors,
+		BytesPerRecordAvg:        cmd.val.BytesPerRecordAvg,
+		Cleaning:                 cmd.val.Cleaning,
+		CursorStats:              cmd.val.CursorStats,
+		DocTableSizeMB:           cmd.val.DocTableSizeMB,
+		GCStats:                  cmd.val.GCStats,
+		GeoshapesSzMB:            cmd.val.GeoshapesSzMB,
+		HashIndexingFailures:     cmd.val.HashIndexingFailures,
+		IndexDefinition:          cmd.val.IndexDefinition,
+		IndexName:                cmd.val.IndexName,
+		Indexing:                 cmd.val.Indexing,
+		InvertedSzMB:             cmd.val.InvertedSzMB,
+		KeyTableSizeMB:           cmd.val.KeyTableSizeMB,
+		MaxDocID:                 cmd.val.MaxDocID,
+		NumDocs:                  cmd.val.NumDocs,
+		NumRecords:               cmd.val.NumRecords,
+		NumTerms:                 cmd.val.NumTerms,
+		NumberOfUses:             cmd.val.NumberOfUses,
+		OffsetBitsPerRecordAvg:   cmd.val.OffsetBitsPerRecordAvg,
+		OffsetVectorsSzMB:        cmd.val.OffsetVectorsSzMB,
+		OffsetsPerTermAvg:        cmd.val.OffsetsPerTermAvg,
+		PercentIndexed:           cmd.val.PercentIndexed,
+		RecordsPerDocAvg:         cmd.val.RecordsPerDocAvg,
+		SortableValuesSizeMB:     cmd.val.SortableValuesSizeMB,
+		TagOverheadSzMB:          cmd.val.TagOverheadSzMB,
+		TextOverheadSzMB:         cmd.val.TextOverheadSzMB,
+		TotalIndexMemorySzMB:     cmd.val.TotalIndexMemorySzMB,
+		TotalIndexingTime:        cmd.val.TotalIndexingTime,
+		TotalInvertedIndexBlocks: cmd.val.TotalInvertedIndexBlocks,
+		VectorIndexSzMB:          cmd.val.VectorIndexSzMB,
+	}
+	// Clone slices and maps
+	if cmd.val.Attributes != nil {
+		val.Attributes = make([]FTAttribute, len(cmd.val.Attributes))
+		copy(val.Attributes, cmd.val.Attributes)
+	}
+	if cmd.val.DialectStats != nil {
+		val.DialectStats = make(map[string]int, len(cmd.val.DialectStats))
+		for k, v := range cmd.val.DialectStats {
+			val.DialectStats[k] = v
+		}
+	}
+	if cmd.val.FieldStatistics != nil {
+		val.FieldStatistics = make([]FieldStatistic, len(cmd.val.FieldStatistics))
+		copy(val.FieldStatistics, cmd.val.FieldStatistics)
+	}
+	if cmd.val.IndexOptions != nil {
+		val.IndexOptions = make([]string, len(cmd.val.IndexOptions))
+		copy(val.IndexOptions, cmd.val.IndexOptions)
+	}
+	if cmd.val.IndexDefinition.Prefixes != nil {
+		val.IndexDefinition.Prefixes = make([]string, len(cmd.val.IndexDefinition.Prefixes))
+		copy(val.IndexDefinition.Prefixes, cmd.val.IndexDefinition.Prefixes)
+	}
+	return &FTInfoCmd{
+		baseCmd: cmd.cloneBaseCmd(),
+		val:     val,
+	}
+}
+
 // FTInfo - Retrieves information about an index.
 // The 'index' parameter specifies the index to retrieve information about.
 // For more information, please refer to the Redis documentation:
@@ -1501,8 +1590,9 @@ type FTSpellCheckCmd struct {
 func newFTSpellCheckCmd(ctx context.Context, args ...interface{}) *FTSpellCheckCmd {
 	return &FTSpellCheckCmd{
 		baseCmd: baseCmd{
-			ctx:  ctx,
-			args: args,
+			ctx:     ctx,
+			args:    args,
+			cmdType: CmdTypeFTSpellCheck,
 		},
 	}
 }
@@ -1598,6 +1688,26 @@ func parseFTSpellCheck(data []interface{}) ([]SpellCheckResult, error) {
 	return results, nil
 }
 
+func (cmd *FTSpellCheckCmd) Clone() Cmder {
+	var val []SpellCheckResult
+	if cmd.val != nil {
+		val = make([]SpellCheckResult, len(cmd.val))
+		for i, result := range cmd.val {
+			val[i] = SpellCheckResult{
+				Term: result.Term,
+			}
+			if result.Suggestions != nil {
+				val[i].Suggestions = make([]SpellCheckSuggestion, len(result.Suggestions))
+				copy(val[i].Suggestions, result.Suggestions)
+			}
+		}
+	}
+	return &FTSpellCheckCmd{
+		baseCmd: cmd.cloneBaseCmd(),
+		val:     val,
+	}
+}
+
 func parseFTSearch(data []interface{}, noContent, withScores, withPayloads, withSortKeys bool) (FTSearchResult, error) {
 	if len(data) < 1 {
 		return FTSearchResult{}, fmt.Errorf("unexpected search result format")
@@ -1688,8 +1798,9 @@ type FTSearchCmd struct {
 func newFTSearchCmd(ctx context.Context, options *FTSearchOptions, args ...interface{}) *FTSearchCmd {
 	return &FTSearchCmd{
 		baseCmd: baseCmd{
-			ctx:  ctx,
-			args: args,
+			ctx:     ctx,
+			args:    args,
+			cmdType: CmdTypeFTSearch,
 		},
 		options: options,
 	}
@@ -1731,6 +1842,89 @@ func (cmd *FTSearchCmd) readReply(rd *proto.Reader) (err error) {
 	return nil
 }
 
+func (cmd *FTSearchCmd) Clone() Cmder {
+	val := FTSearchResult{
+		Total: cmd.val.Total,
+	}
+	if cmd.val.Docs != nil {
+		val.Docs = make([]Document, len(cmd.val.Docs))
+		for i, doc := range cmd.val.Docs {
+			val.Docs[i] = Document{
+				ID:      doc.ID,
+				Score:   doc.Score,
+				Payload: doc.Payload,
+				SortKey: doc.SortKey,
+			}
+			if doc.Fields != nil {
+				val.Docs[i].Fields = make(map[string]string, len(doc.Fields))
+				for k, v := range doc.Fields {
+					val.Docs[i].Fields[k] = v
+				}
+			}
+		}
+	}
+	var options *FTSearchOptions
+	if cmd.options != nil {
+		options = &FTSearchOptions{
+			NoContent:       cmd.options.NoContent,
+			Verbatim:        cmd.options.Verbatim,
+			NoStopWords:     cmd.options.NoStopWords,
+			WithScores:      cmd.options.WithScores,
+			WithPayloads:    cmd.options.WithPayloads,
+			WithSortKeys:    cmd.options.WithSortKeys,
+			Slop:            cmd.options.Slop,
+			Timeout:         cmd.options.Timeout,
+			InOrder:         cmd.options.InOrder,
+			Language:        cmd.options.Language,
+			Expander:        cmd.options.Expander,
+			Scorer:          cmd.options.Scorer,
+			ExplainScore:    cmd.options.ExplainScore,
+			Payload:         cmd.options.Payload,
+			SortByWithCount: cmd.options.SortByWithCount,
+			LimitOffset:     cmd.options.LimitOffset,
+			Limit:           cmd.options.Limit,
+			CountOnly:       cmd.options.CountOnly,
+			DialectVersion:  cmd.options.DialectVersion,
+		}
+		// Clone slices and maps
+		if cmd.options.Filters != nil {
+			options.Filters = make([]FTSearchFilter, len(cmd.options.Filters))
+			copy(options.Filters, cmd.options.Filters)
+		}
+		if cmd.options.GeoFilter != nil {
+			options.GeoFilter = make([]FTSearchGeoFilter, len(cmd.options.GeoFilter))
+			copy(options.GeoFilter, cmd.options.GeoFilter)
+		}
+		if cmd.options.InKeys != nil {
+			options.InKeys = make([]interface{}, len(cmd.options.InKeys))
+			copy(options.InKeys, cmd.options.InKeys)
+		}
+		if cmd.options.InFields != nil {
+			options.InFields = make([]interface{}, len(cmd.options.InFields))
+			copy(options.InFields, cmd.options.InFields)
+		}
+		if cmd.options.Return != nil {
+			options.Return = make([]FTSearchReturn, len(cmd.options.Return))
+			copy(options.Return, cmd.options.Return)
+		}
+		if cmd.options.SortBy != nil {
+			options.SortBy = make([]FTSearchSortBy, len(cmd.options.SortBy))
+			copy(options.SortBy, cmd.options.SortBy)
+		}
+		if cmd.options.Params != nil {
+			options.Params = make(map[string]interface{}, len(cmd.options.Params))
+			for k, v := range cmd.options.Params {
+				options.Params[k] = v
+			}
+		}
+	}
+	return &FTSearchCmd{
+		baseCmd: cmd.cloneBaseCmd(),
+		val:     val,
+		options: options,
+	}
+}
+
 // FTSearch - Executes a search query on an index.
 // The 'index' parameter specifies the index to search, and the 'query' parameter specifies the search query.
 // For more information, please refer to the Redis documentation about [FT.SEARCH].
@@ -1988,8 +2182,9 @@ func (c cmdable) FTSearchWithArgs(ctx context.Context, index string, query strin
 func NewFTSynDumpCmd(ctx context.Context, args ...interface{}) *FTSynDumpCmd {
 	return &FTSynDumpCmd{
 		baseCmd: baseCmd{
-			ctx:  ctx,
-			args: args,
+			ctx:     ctx,
+			args:    args,
+			cmdType: CmdTypeFTSynDump,
 		},
 	}
 }
@@ -2055,6 +2250,26 @@ func (cmd *FTSynDumpCmd) readReply(rd *proto.Reader) error {
 	return nil
 }
 
+func (cmd *FTSynDumpCmd) Clone() Cmder {
+	var val []FTSynDumpResult
+	if cmd.val != nil {
+		val = make([]FTSynDumpResult, len(cmd.val))
+		for i, result := range cmd.val {
+			val[i] = FTSynDumpResult{
+				Term: result.Term,
+			}
+			if result.Synonyms != nil {
+				val[i].Synonyms = make([]string, len(result.Synonyms))
+				copy(val[i].Synonyms, result.Synonyms)
+			}
+		}
+	}
+	return &FTSynDumpCmd{
+		baseCmd: cmd.cloneBaseCmd(),
+		val:     val,
+	}
+}
+
 // FTSynDump - Dumps the contents of a synonym group.
 // The 'index' parameter specifies the index to dump.
 // For more information, please refer to the Redis documentation:
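Each constructor touched in search_commands.go now records a cmdType tag alongside its arguments, matching the GetCmdType() accessor this patch adds to the Cmder interface. Below is a hedged sketch of the kind of reflection-free dispatch that enables, assuming the patched package is built from this diff; the describe helper is illustrative only and not part of the change.

package main

import (
	"context"
	"fmt"

	redis "github.com/redis/go-redis/v9"
)

// describe switches on the command-type tag instead of type-asserting
// each concrete *Cmd type (illustrative helper, not go-redis API).
func describe(cmd redis.Cmder) string {
	switch cmd.GetCmdType() {
	case redis.CmdTypeStatus:
		return "status reply"
	case redis.CmdTypeInt:
		return "integer reply"
	case redis.CmdTypeFTSearch:
		return "FT.SEARCH result"
	default:
		return "generic reply"
	}
}

func main() {
	cmd := redis.NewStatusCmd(context.Background(), "ping")
	fmt.Println(describe(cmd)) // "status reply"
}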
diff --git a/timeseries_commands.go b/timeseries_commands.go
index 82d8cdfcf..71ed6af23 100644
--- a/timeseries_commands.go
+++ b/timeseries_commands.go
@@ -486,8 +486,9 @@ type TSTimestampValueCmd struct {
 func newTSTimestampValueCmd(ctx context.Context, args ...interface{}) *TSTimestampValueCmd {
 	return &TSTimestampValueCmd{
 		baseCmd: baseCmd{
-			ctx:  ctx,
-			args: args,
+			ctx:     ctx,
+			args:    args,
+			cmdType: CmdTypeTSTimestampValue,
 		},
 	}
 }
@@ -533,6 +534,13 @@ func (cmd *TSTimestampValueCmd) readReply(rd *proto.Reader) (err error) {
 	return nil
 }
 
+func (cmd *TSTimestampValueCmd) Clone() Cmder {
+	return &TSTimestampValueCmd{
+		baseCmd: cmd.cloneBaseCmd(),
+		val:     cmd.val, // TSTimestampValue is a simple struct, can be copied directly
+	}
+}
+
 // TSInfo - Returns information about a time-series key.
 // For more information - https://redis.io/commands/ts.info/
 func (c cmdable) TSInfo(ctx context.Context, key string) *MapStringInterfaceCmd {
@@ -704,8 +712,9 @@ type TSTimestampValueSliceCmd struct {
 func newTSTimestampValueSliceCmd(ctx context.Context, args ...interface{}) *TSTimestampValueSliceCmd {
 	return &TSTimestampValueSliceCmd{
 		baseCmd: baseCmd{
-			ctx:  ctx,
-			args: args,
+			ctx:     ctx,
+			args:    args,
+			cmdType: CmdTypeTSTimestampValueSlice,
 		},
 	}
 }
@@ -752,6 +761,18 @@ func (cmd *TSTimestampValueSliceCmd) readReply(rd *proto.Reader) (err error) {
 	return nil
 }
 
+func (cmd *TSTimestampValueSliceCmd) Clone() Cmder {
+	var val []TSTimestampValue
+	if cmd.val != nil {
+		val = make([]TSTimestampValue, len(cmd.val))
+		copy(val, cmd.val)
+	}
+	return &TSTimestampValueSliceCmd{
+		baseCmd: cmd.cloneBaseCmd(),
+		val:     val,
+	}
+}
+
 // TSMRange - Returns a range of samples from multiple time-series keys.
 // For more information - https://redis.io/commands/ts.mrange/
 func (c cmdable) TSMRange(ctx context.Context, fromTimestamp int, toTimestamp int, filterExpr []string) *MapStringSliceInterfaceCmd {
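Taken together, the Clone implementations make it safe to run the same logical command more than once without the copies sharing argument slices or reply state. A minimal usage sketch follows, assuming the Clone method added in this diff and the existing Client.Process entry point; the address is a placeholder.

package main

import (
	"context"
	"fmt"

	redis "github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

	original := redis.NewStatusCmd(ctx, "ping")
	// Deep copy: args and reply state are independent of the original.
	copyCmd := original.Clone().(*redis.StatusCmd)

	_ = rdb.Process(ctx, original)
	_ = rdb.Process(ctx, copyCmd)

	fmt.Println(original.Val(), copyCmd.Val())
}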