From a949b588aa5722dddf87a9ed3d054cedbb146429 Mon Sep 17 00:00:00 2001 From: Nick Rosbrook Date: Thu, 22 Aug 2024 13:13:04 -0400 Subject: [PATCH] message: add safePutUint{8,16,32} helpers Add helpers to check for overflow before writing an integer in big endian form to a buffer. This helps address potential integer overflow bugs. --- vici/client_conn.go | 5 +-- vici/client_conn_test.go | 14 +++++--- vici/message.go | 72 +++++++++++++++++++++++++++++++--------- vici/packet.go | 2 +- vici/packet_test.go | 18 ++++++++++ 5 files changed, 86 insertions(+), 25 deletions(-) diff --git a/vici/client_conn.go b/vici/client_conn.go index b5d5c3a..361e7d3 100644 --- a/vici/client_conn.go +++ b/vici/client_conn.go @@ -93,10 +93,7 @@ func (cc *clientConn) awaitPacketWrite(p *packet) <-chan error { } // Write the packet length - pl := make([]byte, headerLength) - binary.BigEndian.PutUint32(pl, uint32(len(b))) - _, err = buf.Write(pl) - if err != nil { + if err := safePutUint32(buf, len(b)); err != nil { r <- err return } diff --git a/vici/client_conn_test.go b/vici/client_conn_test.go index 35b3734..f6f9d24 100644 --- a/vici/client_conn_test.go +++ b/vici/client_conn_test.go @@ -56,7 +56,7 @@ func TestPacketWrite(t *testing.T) { length := binary.BigEndian.Uint32(b) - if want := len(goldNamedPacketBytes); length != uint32(want) { + if want := len(goldNamedPacketBytes); length != uint32(want) { // #nosec G115 t.Errorf("Unexpected packet length: got %d, expected: %d", length, want) } @@ -108,13 +108,17 @@ func TestPacketRead(t *testing.T) { }() // Make a buffer big enough for the data and the header. - b := make([]byte, headerLength+len(goldNamedPacketBytes)) + buf := new(bytes.Buffer) - binary.BigEndian.PutUint32(b[:headerLength], uint32(len(goldNamedPacketBytes))) + if err := safePutUint32(buf, len(goldNamedPacketBytes)); err != nil { + t.Fatalf("Unexpected error writing header: %v", err) + } - copy(b[headerLength:], goldNamedPacketBytes) + if _, err := buf.Write(goldNamedPacketBytes); err != nil { + t.Fatalf("Unexpected error writing packet: %v", err) + } - _, err := srvr.Write(b) + _, err := srvr.Write(buf.Bytes()) if err != nil { t.Fatalf("Unexpected error sending bytes: %v", err) } diff --git a/vici/message.go b/vici/message.go index 6dad275..17e2b90 100644 --- a/vici/message.go +++ b/vici/message.go @@ -262,6 +262,57 @@ func (m *Message) elements() []messageElement { return ordered } +func safePutUint8(buf *bytes.Buffer, val int) error { + limit := ^uint8(0) + + if int64(val) > int64(limit) { + return fmt.Errorf("val too long (%d > %d)", val, limit) + } + + // We can safely convert now, because we just checked that it will not overflow. + if err := buf.WriteByte(uint8(val)); err != nil { // #nosec G115 + return err + } + + return nil +} + +func safePutUint16(buf *bytes.Buffer, val int) error { + limit := ^uint16(0) + b := make([]byte, 2) + + if int64(val) > int64(limit) { + return fmt.Errorf("val too long (%d > %d)", val, limit) + } + + // We can safely convert now, because we just checked that it will not overflow. + binary.BigEndian.PutUint16(b, uint16(val)) // #nosec G115 + + if _, err := buf.Write(b); err != nil { + return err + } + + return nil +} + +func safePutUint32(buf *bytes.Buffer, val int) error { + limit := ^uint32(0) + b := make([]byte, 4) + + if int64(val) > int64(limit) { + return fmt.Errorf("val too long (%d > %d)", val, limit) + } + + // We can safely convert now, because we just checked that it will not overflow. + binary.BigEndian.PutUint32(b, uint32(val)) // #nosec G115 + + if _, err := buf.Write(b); err != nil { + return err + } + + return nil +} + func (m *Message) encode() ([]byte, error) { buf := bytes.NewBuffer([]byte{}) @@ -372,22 +423,17 @@ func (m *Message) encodeKeyValue(key, value string) ([]byte, error) { buf := bytes.NewBuffer([]byte{msgKeyValue}) // Write the key length and key - err := buf.WriteByte(uint8(len(key))) - if err != nil { + if err := safePutUint8(buf, len(key)); err != nil { return nil, fmt.Errorf("%v: %v", errEncoding, err) } - _, err = buf.WriteString(key) + _, err := buf.WriteString(key) if err != nil { return nil, fmt.Errorf("%v: %v", errEncoding, err) } // Write the value's length to the buffer as two bytes - vl := make([]byte, 2) - binary.BigEndian.PutUint16(vl, uint16(len(value))) - - _, err = buf.Write(vl) - if err != nil { + if err := safePutUint16(buf, len(value)); err != nil { return nil, fmt.Errorf("%v: %v", errEncoding, err) } @@ -412,7 +458,7 @@ func (m *Message) encodeList(key string, list []string) ([]byte, error) { buf := bytes.NewBuffer([]byte{msgListStart}) // Write the key length and key - err := buf.WriteByte(uint8(len(key))) + err := safePutUint8(buf, len(key)) if err != nil { return nil, fmt.Errorf("%v: %v", errEncoding, err) } @@ -430,11 +476,7 @@ func (m *Message) encodeList(key string, list []string) ([]byte, error) { } // Write the item's length to the buffer as two bytes - il := make([]byte, 2) - binary.BigEndian.PutUint16(il, uint16(len(item))) - - _, err = buf.Write(il) - if err != nil { + if err := safePutUint16(buf, len(item)); err != nil { return nil, fmt.Errorf("%v: %v", errEncoding, err) } @@ -461,7 +503,7 @@ func (m *Message) encodeSection(key string, section *Message) ([]byte, error) { buf := bytes.NewBuffer([]byte{msgSectionStart}) // Write the key length and key - err := buf.WriteByte(uint8(len(key))) + err := safePutUint8(buf, len(key)) if err != nil { return nil, fmt.Errorf("%v: %v", errEncoding, err) } diff --git a/vici/packet.go b/vici/packet.go index ff3dedd..c6649f1 100644 --- a/vici/packet.go +++ b/vici/packet.go @@ -110,7 +110,7 @@ func (p *packet) bytes() ([]byte, error) { // Write the name, preceded by its length if p.isNamed() { - err := buf.WriteByte(uint8(len(p.name))) + err := safePutUint8(buf, len(p.name)) if err != nil { return nil, fmt.Errorf("%v: %v", errPacketWrite, err) } diff --git a/vici/packet_test.go b/vici/packet_test.go index 9a99fe9..c16301d 100644 --- a/vici/packet_test.go +++ b/vici/packet_test.go @@ -117,3 +117,21 @@ func TestPacketBytes(t *testing.T) { t.Fatalf("Encoded packet does not equal gold bytes.\nExpected: %v\nReceived: %v", goldUnnamedPacketBytes, b) } } + +func TestPacketTooLong(t *testing.T) { + tooLong := make([]byte, 256) + + for i := range tooLong { + tooLong[i] = 'a' + } + + p := &packet{ + ptype: pktCmdRequest, + name: string(tooLong), + } + + _, err := p.bytes() + if err == nil { + t.Fatalf("Expected packet-too-long error due to %s", p.name) + } +}