From 853e75b3388b87d62147f69bbc9cc0c47ca15839 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sat, 29 Jun 2024 19:01:39 -0700 Subject: [PATCH 01/16] Switch to net/netip The performance of this is approximately the same as the net.IP version, except for the methods that return a network. For those, there is a slight improvement. --- deserializer_test.go | 4 +- example_test.go | 16 ++--- go.mod | 2 +- reader.go | 94 ++++++++++++++--------------- reader_test.go | 137 ++++++++++++++++++++++--------------------- traverse.go | 115 ++++++++++++++++++------------------ traverse_test.go | 45 +++++++------- 7 files changed, 207 insertions(+), 206 deletions(-) diff --git a/deserializer_test.go b/deserializer_test.go index c68a9d5..1a483d8 100644 --- a/deserializer_test.go +++ b/deserializer_test.go @@ -2,7 +2,7 @@ package maxminddb import ( "math/big" - "net" + "net/netip" "testing" "github.com/stretchr/testify/require" @@ -13,7 +13,7 @@ func TestDecodingToDeserializer(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) dser := testDeserializer{} - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &dser) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &dser) require.NoError(t, err, "unexpected error while doing lookup: %v", err) checkDecodingToInterface(t, dser.rv) diff --git a/example_test.go b/example_test.go index 52e878e..fc0917f 100644 --- a/example_test.go +++ b/example_test.go @@ -3,9 +3,9 @@ package maxminddb_test import ( "fmt" "log" - "net" + "net/netip" - "github.com/oschwald/maxminddb-golang" + "github.com/oschwald/maxminddb-golang/v2" ) // This example shows how to decode to a struct. 
@@ -16,7 +16,7 @@ func ExampleReader_Lookup_struct() { } defer db.Close() - ip := net.ParseIP("81.2.69.142") + addr := netip.MustParseAddr("81.2.69.142") var record struct { Country struct { @@ -24,7 +24,7 @@ func ExampleReader_Lookup_struct() { } `maxminddb:"country"` } // Or any appropriate struct - err = db.Lookup(ip, &record) + err = db.Lookup(addr, &record) if err != nil { log.Panic(err) } @@ -41,10 +41,10 @@ func ExampleReader_Lookup_interface() { } defer db.Close() - ip := net.ParseIP("81.2.69.142") + addr := netip.MustParseAddr("81.2.69.142") var record any - err = db.Lookup(ip, &record) + err = db.Lookup(addr, &record) if err != nil { log.Panic(err) } @@ -114,12 +114,12 @@ func ExampleReader_NetworksWithin() { } defer db.Close() - _, network, err := net.ParseCIDR("1.0.0.0/8") + prefix, err := netip.ParsePrefix("1.0.0.0/8") if err != nil { log.Panic(err) } - networks := db.NetworksWithin(network, maxminddb.SkipAliasedNetworks) + networks := db.NetworksWithin(prefix, maxminddb.SkipAliasedNetworks) for networks.Next() { record := struct { Domain string `maxminddb:"connection_type"` diff --git a/go.mod b/go.mod index d56901d..11b1b10 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/oschwald/maxminddb-golang +module github.com/oschwald/maxminddb-golang/v2 go 1.21 diff --git a/reader.go b/reader.go index 2dc712f..e41558f 100644 --- a/reader.go +++ b/reader.go @@ -5,7 +5,7 @@ import ( "bytes" "errors" "fmt" - "net" + "net/netip" "reflect" ) @@ -110,6 +110,7 @@ func FromBytes(buffer []byte) (*Reader, error) { func (r *Reader) setIPv4Start() { if r.Metadata.IPVersion != 6 { + r.ipv4StartBitDepth = 96 return } @@ -130,7 +131,7 @@ func (r *Reader) setIPv4Start() { // because of type differences, an UnmarshalTypeError is returned. If the // database is invalid or otherwise cannot be read, an InvalidDatabaseError // is returned. 
-func (r *Reader) Lookup(ip net.IP, result any) error { +func (r *Reader) Lookup(ip netip.Addr, result any) error { if r.buffer == nil { return errors.New("cannot call Lookup on a closed database") } @@ -142,7 +143,7 @@ func (r *Reader) Lookup(ip net.IP, result any) error { } // LookupNetwork retrieves the database record for ip and stores it in the -// value pointed to by result. The network returned is the network associated +// value pointed to by result. The prefix returned is the network associated // with the data record in the database. The ok return value indicates whether // the database contained a record for the ip. // @@ -151,20 +152,21 @@ func (r *Reader) Lookup(ip net.IP, result any) error { // UnmarshalTypeError is returned. If the database is invalid or otherwise // cannot be read, an InvalidDatabaseError is returned. func (r *Reader) LookupNetwork( - ip net.IP, + ip netip.Addr, result any, -) (network *net.IPNet, ok bool, err error) { +) (prefix netip.Prefix, ok bool, err error) { if r.buffer == nil { - return nil, false, errors.New("cannot call Lookup on a closed database") + return netip.Prefix{}, false, errors.New("cannot call Lookup on a closed database") } pointer, prefixLength, ip, err := r.lookupPointer(ip) + // We return this error below as we want to return the prefix it is for - network = r.cidr(ip, prefixLength) - if pointer == 0 || err != nil { - return network, false, err + prefix, errP := r.cidr(ip, prefixLength) + if pointer == 0 || err != nil || errP != nil { + return prefix, false, errors.Join(err, errP) } - return network, true, r.retrieveData(pointer, result) + return prefix, true, r.retrieveData(pointer, result) } // LookupOffset maps an argument net.IP to a corresponding record offset in the @@ -172,7 +174,7 @@ func (r *Reader) LookupNetwork( // otherwise be extracted by passing the returned offset to Decode. LookupOffset // is an advanced API, which exists to provide clients with a means to cache // previously-decoded records. 
-func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) { +func (r *Reader) LookupOffset(ip netip.Addr) (uintptr, error) { if r.buffer == nil { return 0, errors.New("cannot call LookupOffset on a closed database") } @@ -183,22 +185,28 @@ func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) { return r.resolveDataPointer(pointer) } -func (r *Reader) cidr(ip net.IP, prefixLength int) *net.IPNet { - // This is necessary as the node that the IPv4 start is at may - // be at a bit depth that is less that 96, i.e., ipv4Start points - // to a leaf node. For instance, if a record was inserted at ::/8, - // the ipv4Start would point directly at the leaf node for the - // record and would have a bit depth of 8. This would not happen - // with databases currently distributed by MaxMind as all of them - // have an IPv4 subtree that is greater than a single node. - if r.Metadata.IPVersion == 6 && - len(ip) == net.IPv4len && - r.ipv4StartBitDepth != 96 { - return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(r.ipv4StartBitDepth, 128)} +var zeroIP = netip.MustParseAddr("::") + +func (r *Reader) cidr(ip netip.Addr, prefixLength int) (netip.Prefix, error) { + if ip.Is4() { + // This is necessary as the node that the IPv4 start is at may + // be at a bit depth that is less than 96, i.e., ipv4Start points + // to a leaf node. For instance, if a record was inserted at ::/8, + // the ipv4Start would point directly at the leaf node for the + // record and would have a bit depth of 8. This would not happen + // with databases currently distributed by MaxMind as all of them + // have an IPv4 subtree that is greater than a single node. 
+ if r.Metadata.IPVersion == 6 && r.ipv4StartBitDepth != 96 { + return netip.PrefixFrom(zeroIP, r.ipv4StartBitDepth), nil + } + prefixLength -= 96 } - mask := net.CIDRMask(prefixLength, len(ip)*8) - return &net.IPNet{IP: ip.Mask(mask), Mask: mask} + prefix, err := ip.Prefix(prefixLength) + if err != nil { + return netip.Prefix{}, fmt.Errorf("creating prefix from %s/%d: %w", ip, prefixLength, err) + } + return prefix, nil } // Decode the record at |offset| into |result|. The result value pointed to @@ -239,29 +247,15 @@ func (r *Reader) decode(offset uintptr, result any) error { return err } -func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) { - if ip == nil { - return 0, 0, nil, errors.New("IP passed to Lookup cannot be nil") - } - - ipV4Address := ip.To4() - if ipV4Address != nil { - ip = ipV4Address - } - if len(ip) == 16 && r.Metadata.IPVersion == 4 { +func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, netip.Addr, error) { + if r.Metadata.IPVersion == 4 && ip.Is6() { return 0, 0, ip, fmt.Errorf( "error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database", ip.String(), ) } - bitCount := uint(len(ip) * 8) - - var node uint - if bitCount == 32 { - node = r.ipv4Start - } - node, prefixLength := r.traverseTree(ip, node, bitCount) + node, prefixLength := r.traverseTree(ip, 0, 128) nodeCount := r.Metadata.NodeCount if node == nodeCount { @@ -274,12 +268,18 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) { return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree") } -func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) { +func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) { + i := 0 + if ip.Is4() { + i = r.ipv4StartBitDepth + node = r.ipv4Start + } nodeCount := r.Metadata.NodeCount - i := uint(0) - for ; i < bitCount && node < nodeCount; i++ { - bit := uint(1) & (uint(ip[i>>3]) >> (7 - (i % 8))) + ip16 := ip.As16() + 
+ for ; i < stopBit && node < nodeCount; i++ { + bit := uint(1) & (uint(ip16[i>>3]) >> (7 - (i % 8))) offset := node * r.nodeOffsetMult if bit == 0 { @@ -289,7 +289,7 @@ func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) { } } - return node, int(i) + return node, i } func (r *Reader) retrieveData(pointer uint, result any) error { diff --git a/reader_test.go b/reader_test.go index c9d287d..1817db5 100644 --- a/reader_test.go +++ b/reader_test.go @@ -6,6 +6,7 @@ import ( "math/big" "math/rand" "net" + "net/netip" "os" "path/filepath" "testing" @@ -19,19 +20,21 @@ func TestReader(t *testing.T) { for _, recordSize := range []uint{24, 28, 32} { for _, ipVersion := range []uint{4, 6} { fileName := fmt.Sprintf( - testFile("MaxMind-DB-test-ipv%d-%d.mmdb"), + "MaxMind-DB-test-ipv%d-%d.mmdb", ipVersion, recordSize, ) - reader, err := Open(fileName) - require.NoError(t, err, "unexpected error while opening database: %v", err) - checkMetadata(t, reader, ipVersion, recordSize) - - if ipVersion == 4 { - checkIpv4(t, reader) - } else { - checkIpv6(t, reader) - } + t.Run(fileName, func(t *testing.T) { + reader, err := Open(testFile(fileName)) + require.NoError(t, err, "unexpected error while opening database: %v", err) + checkMetadata(t, reader, ipVersion, recordSize) + + if ipVersion == 4 { + checkIpv4(t, reader) + } else { + checkIpv6(t, reader) + } + }) } } } @@ -97,91 +100,91 @@ func TestLookupNetwork(t *testing.T) { } tests := []struct { - IP net.IP + IP netip.Addr DBFile string ExpectedCIDR string ExpectedRecord any ExpectedOK bool }{ { - IP: net.ParseIP("1.1.1.1"), + IP: netip.MustParseAddr("1.1.1.1"), DBFile: "MaxMind-DB-test-ipv6-32.mmdb", ExpectedCIDR: "1.0.0.0/8", ExpectedRecord: nil, ExpectedOK: false, }, { - IP: net.ParseIP("::1:ffff:ffff"), + IP: netip.MustParseAddr("::1:ffff:ffff"), DBFile: "MaxMind-DB-test-ipv6-24.mmdb", ExpectedCIDR: "::1:ffff:ffff/128", ExpectedRecord: map[string]any{"ip": "::1:ffff:ffff"}, ExpectedOK: true, }, { - IP: 
net.ParseIP("::2:0:1"), + IP: netip.MustParseAddr("::2:0:1"), DBFile: "MaxMind-DB-test-ipv6-24.mmdb", ExpectedCIDR: "::2:0:0/122", ExpectedRecord: map[string]any{"ip": "::2:0:0"}, ExpectedOK: true, }, { - IP: net.ParseIP("1.1.1.1"), + IP: netip.MustParseAddr("1.1.1.1"), DBFile: "MaxMind-DB-test-ipv4-24.mmdb", ExpectedCIDR: "1.1.1.1/32", ExpectedRecord: map[string]any{"ip": "1.1.1.1"}, ExpectedOK: true, }, { - IP: net.ParseIP("1.1.1.3"), + IP: netip.MustParseAddr("1.1.1.3"), DBFile: "MaxMind-DB-test-ipv4-24.mmdb", ExpectedCIDR: "1.1.1.2/31", ExpectedRecord: map[string]any{"ip": "1.1.1.2"}, ExpectedOK: true, }, { - IP: net.ParseIP("1.1.1.3"), + IP: netip.MustParseAddr("1.1.1.3"), DBFile: "MaxMind-DB-test-decoder.mmdb", ExpectedCIDR: "1.1.1.0/24", ExpectedRecord: decoderRecord, ExpectedOK: true, }, { - IP: net.ParseIP("::ffff:1.1.1.128"), + IP: netip.MustParseAddr("::ffff:1.1.1.128"), DBFile: "MaxMind-DB-test-decoder.mmdb", - ExpectedCIDR: "1.1.1.0/24", + ExpectedCIDR: "::ffff:1.1.1.0/120", ExpectedRecord: decoderRecord, ExpectedOK: true, }, { - IP: net.ParseIP("::1.1.1.128"), + IP: netip.MustParseAddr("::1.1.1.128"), DBFile: "MaxMind-DB-test-decoder.mmdb", ExpectedCIDR: "::101:100/120", ExpectedRecord: decoderRecord, ExpectedOK: true, }, { - IP: net.ParseIP("200.0.2.1"), + IP: netip.MustParseAddr("200.0.2.1"), DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", ExpectedCIDR: "::/64", ExpectedRecord: "::0/64", ExpectedOK: true, }, { - IP: net.ParseIP("::200.0.2.1"), + IP: netip.MustParseAddr("::200.0.2.1"), DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", ExpectedCIDR: "::/64", ExpectedRecord: "::0/64", ExpectedOK: true, }, { - IP: net.ParseIP("0:0:0:0:ffff:ffff:ffff:ffff"), + IP: netip.MustParseAddr("0:0:0:0:ffff:ffff:ffff:ffff"), DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", ExpectedCIDR: "::/64", ExpectedRecord: "::0/64", ExpectedOK: true, }, { - IP: net.ParseIP("ef00::"), + IP: netip.MustParseAddr("ef00::"), DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", ExpectedCIDR: 
"8000::/1", ExpectedRecord: nil, @@ -209,7 +212,7 @@ func TestDecodingToInterface(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var recordInterface any - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &recordInterface) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &recordInterface) require.NoError(t, err, "unexpected error while doing lookup: %v", err) checkDecodingToInterface(t, recordInterface) @@ -296,13 +299,13 @@ func TestDecoder(t *testing.T) { { // Directly lookup and decode. var result TestType - require.NoError(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)) verify(result) } { // Lookup record offset, then Decode. var result TestType - offset, err := reader.LookupOffset(net.ParseIP("::1.1.1.0")) + offset, err := reader.LookupOffset(netip.MustParseAddr("::1.1.1.0")) require.NoError(t, err) assert.NotEqual(t, NotFound, offset) @@ -327,7 +330,7 @@ func TestStructInterface(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) require.NoError(t, err) - require.NoError(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)) assert.True(t, result.method()) } @@ -338,7 +341,7 @@ func TestNonEmptyNilInterface(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) require.NoError(t, err) - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result) assert.Equal( t, "maxminddb: cannot unmarshal map into type maxminddb.TestInterface", @@ -361,7 +364,7 @@ func TestEmbeddedStructAsInterface(t *testing.T) { db, err := Open(testFile("GeoIP2-ISP-Test.mmdb")) require.NoError(t, err) - require.NoError(t, db.Lookup(net.ParseIP("1.128.0.0"), &result)) + require.NoError(t, db.Lookup(netip.MustParseAddr("1.128.0.0"), &result)) } type BoolInterface interface { @@ 
-387,7 +390,7 @@ func TestValueTypeInterface(t *testing.T) { // although it would be nice to support cases like this, I am not sure it // is possible to do so in a general way. - assert.Error(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + assert.Error(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)) } type NestedMapX struct { @@ -429,7 +432,7 @@ func TestComplexStructWithNestingAndPointer(t *testing.T) { var result TestPointerType - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result) require.NoError(t, err) assert.Equal(t, []uint{uint(1), uint(2), uint(3)}, *result.Array) @@ -461,7 +464,7 @@ func TestNestedMapDecode(t *testing.T) { var r map[string]map[string]any - require.NoError(t, db.Lookup(net.ParseIP("89.160.20.128"), &r)) + require.NoError(t, db.Lookup(netip.MustParseAddr("89.160.20.128"), &r)) assert.Equal( t, @@ -519,7 +522,7 @@ func TestNestedOffsetDecode(t *testing.T) { db, err := Open(testFile("GeoIP2-City-Test.mmdb")) require.NoError(t, err) - off, err := db.LookupOffset(net.ParseIP("81.2.69.142")) + off, err := db.LookupOffset(netip.MustParseAddr("81.2.69.142")) assert.NotEqual(t, NotFound, off) require.NoError(t, err) @@ -561,7 +564,7 @@ func TestDecodingUint16IntoInt(t *testing.T) { var result struct { Uint16 int `maxminddb:"uint16"` } - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result) require.NoError(t, err) assert.Equal(t, 100, result.Uint16) @@ -572,7 +575,7 @@ func TestIpv6inIpv4(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var result TestType - err = reader.Lookup(net.ParseIP("2001::"), &result) + err = reader.Lookup(netip.MustParseAddr("2001::"), &result) var emptyResult TestType assert.Equal(t, emptyResult, result) @@ -589,7 +592,7 @@ func TestBrokenDoubleDatabase(t *testing.T) { require.NoError(t, err, "unexpected error while opening 
database: %v", err) var result any - err = reader.Lookup(net.ParseIP("2001:220::"), &result) + err = reader.Lookup(netip.MustParseAddr("2001:220::"), &result) expected := newInvalidDatabaseError( "the MaxMind DB file's data section contains bad data (float 64 size of 2)", @@ -622,20 +625,20 @@ func TestDecodingToNonPointer(t *testing.T) { require.NoError(t, err) var recordInterface any - err = reader.Lookup(net.ParseIP("::1.1.1.0"), recordInterface) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), recordInterface) assert.Equal(t, "result param must be a pointer", err.Error()) require.NoError(t, reader.Close(), "error on close") } -func TestNilLookup(t *testing.T) { - reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) - require.NoError(t, err) +// func TestNilLookup(t *testing.T) { +// reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) +// require.NoError(t, err) - var recordInterface any - err = reader.Lookup(nil, recordInterface) - assert.Equal(t, "IP passed to Lookup cannot be nil", err.Error()) - require.NoError(t, reader.Close(), "error on close") -} +// var recordInterface any +// err = reader.Lookup(nil, recordInterface) +// assert.Equal(t, "IP passed to Lookup cannot be nil", err.Error()) +// require.NoError(t, reader.Close(), "error on close") +// } func TestUsingClosedDatabase(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) @@ -643,11 +646,11 @@ func TestUsingClosedDatabase(t *testing.T) { require.NoError(t, reader.Close()) var recordInterface any - - err = reader.Lookup(nil, recordInterface) + addr := netip.MustParseAddr("::") + err = reader.Lookup(addr, recordInterface) assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) - _, err = reader.LookupOffset(nil) + _, err = reader.LookupOffset(addr) assert.Equal(t, "cannot call LookupOffset on a closed database", err.Error()) err = reader.Decode(0, recordInterface) @@ -682,7 +685,7 @@ func checkMetadata(t *testing.T, reader *Reader, 
ipVersion, recordSize uint) { func checkIpv4(t *testing.T, reader *Reader) { for i := uint(0); i < 6; i++ { address := fmt.Sprintf("1.1.1.%d", uint(1)<> 24) ip[1] = byte(num >> 16) ip[2] = byte(num >> 8) ip[3] = byte(num) + v, _ := netip.AddrFromSlice(ip) + return v } func testFile(file string) string { diff --git a/traverse.go b/traverse.go index 90073e2..40a53ee 100644 --- a/traverse.go +++ b/traverse.go @@ -2,12 +2,12 @@ package maxminddb import ( "fmt" - "net" + "net/netip" ) // Internal structure used to keep track of nodes we still need to visit. type netNode struct { - ip net.IP + ip netip.Addr bit uint pointer uint } @@ -22,8 +22,8 @@ type Networks struct { } var ( - allIPv4 = &net.IPNet{IP: make(net.IP, 4), Mask: net.CIDRMask(0, 32)} - allIPv6 = &net.IPNet{IP: make(net.IP, 16), Mask: net.CIDRMask(0, 128)} + allIPv4 = netip.MustParsePrefix("0.0.0.0/0") + allIPv6 = netip.MustParsePrefix("::/0") ) // NetworksOption are options for Networks and NetworksWithin. @@ -58,21 +58,21 @@ func (r *Reader) Networks(options ...NetworksOption) *Networks { } // NetworksWithin returns an iterator that can be used to traverse all networks -// in the database which are contained in a given network. +// in the database which are contained in a given prefix. // // Please note that a MaxMind DB may map IPv4 networks into several locations // in an IPv6 database. This iterator will iterate over all of these locations // separately. To only iterate over the IPv4 networks once, use the // SkipAliasedNetworks option. // -// If the provided network is contained within a network in the database, the +// If the provided prefix is contained within a network in the database, the // iterator will iterate over exactly one network, the containing network. 
-func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) *Networks { - if r.Metadata.IPVersion == 4 && network.IP.To4() == nil { +func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) *Networks { + if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() { return &Networks{ err: fmt.Errorf( "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", - network.String(), + prefix, ), } } @@ -82,29 +82,24 @@ func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) * option(networks) } - ip := network.IP - prefixLength, _ := network.Mask.Size() - - if r.Metadata.IPVersion == 6 && len(ip) == net.IPv4len { - if networks.skipAliasedNetworks { - ip = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ip[0], ip[1], ip[2], ip[3]} - } else { - ip = ip.To16() - } - prefixLength += 96 + ip := prefix.Addr() + netIP := ip + stopBit := prefix.Bits() + if ip.Is4() { + netIP = v4ToV16(ip) + stopBit += 96 } - pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) + pointer, bit := r.traverseTree(ip, 0, stopBit) + + prefix, err := netIP.Prefix(bit) + if err != nil { + networks.err = fmt.Errorf("prefixing %s with %d", netIP, bit) + } - // We could skip this when bit >= prefixLength if we assume that the network - // passed in is in canonical form. However, given that this may not be the - // case, it is safest to always take the mask. If this is hot code at some - // point, we could eliminate the allocation of the net.IPMask by zeroing - // out the bits in ip directly. 
- ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) networks.nodes = []netNode{ { - ip: ip, + ip: prefix.Addr(), bit: uint(bit), pointer: pointer, }, @@ -136,11 +131,17 @@ func (n *Networks) Next() bool { n.lastNode = node return true } - ipRight := make(net.IP, len(node.ip)) - copy(ipRight, node.ip) + ipRight := node.ip.As16() if len(ipRight) <= int(node.bit>>3) { + displayAddr := node.ip + displayBits := node.bit + if isInIPv4Subtree(node.ip) { + displayAddr = v6ToV4(displayAddr) + displayBits -= 96 + } + n.err = newInvalidDatabaseError( - "invalid search tree at %v/%v", ipRight, node.bit) + "invalid search tree at %s/%d", displayAddr, displayBits) return false } ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) @@ -151,7 +152,7 @@ func (n *Networks) Next() bool { node.bit++ n.nodes = append(n.nodes, netNode{ pointer: rightPointer, - ip: ipRight, + ip: netip.AddrFrom16(ipRight), bit: node.bit, }) @@ -165,30 +166,22 @@ func (n *Networks) Next() bool { // Network returns the current network or an error if there is a problem // decoding the data for the network. It takes a pointer to a result value to // decode the network's data into. -func (n *Networks) Network(result any) (*net.IPNet, error) { +func (n *Networks) Network(result any) (netip.Prefix, error) { if n.err != nil { - return nil, n.err + return netip.Prefix{}, n.err } if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil { - return nil, err + return netip.Prefix{}, err } ip := n.lastNode.ip prefixLength := int(n.lastNode.bit) - - // We do this because uses of SkipAliasedNetworks expect the IPv4 networks - // to be returned as IPv4 networks. If we are not skipping aliased - // networks, then the user will get IPv4 networks from the ::FFFF:0:0/96 - // network as Go automatically converts those. 
- if n.skipAliasedNetworks && isInIPv4Subtree(ip) { - ip = ip[12:] + if isInIPv4Subtree(ip) { + ip = v6ToV4(ip) prefixLength -= 96 } - return &net.IPNet{ - IP: ip, - Mask: net.CIDRMask(prefixLength, len(ip)*8), - }, nil + return netip.PrefixFrom(ip, prefixLength), nil } // Err returns an error, if any, that was encountered during iteration. @@ -196,16 +189,24 @@ func (n *Networks) Err() error { return n.err } -// isInIPv4Subtree returns true if the IP is an IPv6 address in the database's -// IPv4 subtree. -func isInIPv4Subtree(ip net.IP) bool { - if len(ip) != 16 { - return false - } - for i := 0; i < 12; i++ { - if ip[i] != 0 { - return false - } - } - return true +var ipv4SubtreeBoundary = netip.MustParseAddr("::255.255.255.255").Next() + +// isInIPv4Subtree returns true if the IP is in the database's IPv4 subtree. +func isInIPv4Subtree(ip netip.Addr) bool { + return ip.Is4() || ip.Less(ipv4SubtreeBoundary) +} + +// We store IPv4 addresses at ::/96 for unclear reasons. +func v4ToV16(ip netip.Addr) netip.Addr { + b4 := ip.As4() + var b16 [16]byte + copy(b16[12:], b4[:]) + return netip.AddrFrom16(b16) +} + +// Converts an IPv4 address embedded in IPv6 to IPv4. 
+func v6ToV4(ip netip.Addr) netip.Addr { + b := ip.As16() + v, _ := netip.AddrFromSlice(b[12:]) + return v } diff --git a/traverse_test.go b/traverse_test.go index 00edfce..c95c33d 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -2,7 +2,7 @@ package maxminddb import ( "fmt" - "net" + "net/netip" "strconv" "strings" "testing" @@ -27,8 +27,8 @@ func TestNetworks(t *testing.T) { }{} network, err := n.Network(&record) require.NoError(t, err) - assert.Equal(t, record.IP, network.IP.String(), - "expected %s got %s", record.IP, network.IP.String(), + assert.Equal(t, record.IP, network.Addr().String(), + "expected %s got %s", record.IP, network.Addr().String(), ) } require.NoError(t, n.Err()) @@ -198,23 +198,23 @@ var tests = []networkTest{ Network: "::/0", Database: "mixed", Expected: []string{ - "::101:101/128", - "::101:102/127", - "::101:104/126", - "::101:108/125", - "::101:110/124", - "::101:120/128", - "::1:ffff:ffff/128", - "::2:0:0/122", - "::2:0:40/124", - "::2:0:50/125", - "::2:0:58/127", "1.1.1.1/32", "1.1.1.2/31", "1.1.1.4/30", "1.1.1.8/29", "1.1.1.16/28", "1.1.1.32/32", + "::1:ffff:ffff/128", + "::2:0:0/122", + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + "::ffff:1.1.1.1/128", + "::ffff:1.1.1.2/127", + "::ffff:1.1.1.4/126", + "::ffff:1.1.1.8/125", + "::ffff:1.1.1.16/124", + "::ffff:1.1.1.32/128", "2001:0:101:101::/64", "2001:0:101:102::/63", "2001:0:101:104::/62", @@ -281,17 +281,12 @@ func TestNetworksWithin(t *testing.T) { // We are purposely not using net.ParseCIDR so that we can pass in // values that aren't in canonical form. 
parts := strings.Split(v.Network, "/") - ip := net.ParseIP(parts[0]) - if v := ip.To4(); v != nil { - ip = v - } + ip, err := netip.ParseAddr(parts[0]) + require.NoError(t, err) prefixLength, err := strconv.Atoi(parts[1]) require.NoError(t, err) - mask := net.CIDRMask(prefixLength, len(ip)*8) - network := &net.IPNet{ - IP: ip, - Mask: mask, - } + network, err := ip.Prefix(prefixLength) + require.NoError(t, err) require.NoError(t, err) n := reader.NetworksWithin(network, v.Options...) @@ -333,9 +328,9 @@ func TestGeoIPNetworksWithin(t *testing.T) { reader, err := Open(fileName) require.NoError(t, err, "unexpected error while opening database: %v", err) - _, network, err := net.ParseCIDR(v.Network) + prefix, err := netip.ParsePrefix(v.Network) require.NoError(t, err) - n := reader.NetworksWithin(network) + n := reader.NetworksWithin(prefix) var innerIPs []string for n.Next() { From 4a5a9d6250e40a19e64d3c9bd384f9f0a1550208 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Thu, 4 Jul 2024 12:32:10 -0700 Subject: [PATCH 02/16] Skip aliased networks by default --- example_test.go | 4 ++-- traverse.go | 29 +++++++++++++---------------- traverse_test.go | 6 +----- 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/example_test.go b/example_test.go index fc0917f..de49cb6 100644 --- a/example_test.go +++ b/example_test.go @@ -63,7 +63,7 @@ func ExampleReader_Networks() { } defer db.Close() - networks := db.Networks(maxminddb.SkipAliasedNetworks) + networks := db.Networks() for networks.Next() { record := struct { Domain string `maxminddb:"connection_type"` @@ -119,7 +119,7 @@ func ExampleReader_NetworksWithin() { log.Panic(err) } - networks := db.NetworksWithin(prefix, maxminddb.SkipAliasedNetworks) + networks := db.NetworksWithin(prefix) for networks.Next() { record := struct { Domain string `maxminddb:"connection_type"` diff --git a/traverse.go b/traverse.go index 40a53ee..67a5148 100644 --- a/traverse.go +++ b/traverse.go @@ -14,11 +14,11 @@ type netNode 
struct { // Networks represents a set of subnets that we are iterating over. type Networks struct { - err error - reader *Reader - nodes []netNode - lastNode netNode - skipAliasedNetworks bool + err error + reader *Reader + nodes []netNode + lastNode netNode + includeAliasedNetworks bool } var ( @@ -29,23 +29,20 @@ var ( // NetworksOption are options for Networks and NetworksWithin. type NetworksOption func(*Networks) -// SkipAliasedNetworks is an option for Networks and NetworksWithin that -// makes them not iterate over aliases of the IPv4 subtree in an IPv6 +// IncludeAliasedNetworks is an option for Networks and NetworksWithin +// that makes them iterate over aliases of the IPv4 subtree in an IPv6 // database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16. -// -// You most likely want to set this. The only reason it isn't the default -// behavior is to provide backwards compatibility to existing users. -func SkipAliasedNetworks(networks *Networks) { - networks.skipAliasedNetworks = true +func IncludeAliasedNetworks(networks *Networks) { + networks.includeAliasedNetworks = true } // Networks returns an iterator that can be used to traverse all networks in // the database. // // Please note that a MaxMind DB may map IPv4 networks into several locations -// in an IPv6 database. This iterator will iterate over all of these locations -// separately. To only iterate over the IPv4 networks once, use the -// SkipAliasedNetworks option. +// in an IPv6 database. This iterator will only iterate over these once by +// default. To iterate over all the IPv4 network locations, use the +// IncludeAliasedNetworks option. func (r *Reader) Networks(options ...NetworksOption) *Networks { var networks *Networks if r.Metadata.IPVersion == 6 { @@ -122,7 +119,7 @@ func (n *Networks) Next() bool { for node.pointer != n.reader.Metadata.NodeCount { // This skips IPv4 aliases without hardcoding the networks that the writer // currently aliases. 
- if n.skipAliasedNetworks && n.reader.ipv4Start != 0 && + if !n.includeAliasedNetworks && n.reader.ipv4Start != 0 && node.pointer == n.reader.ipv4Start && !isInIPv4Subtree(node.ip) { break } diff --git a/traverse_test.go b/traverse_test.go index c95c33d..382659a 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -140,7 +140,6 @@ var tests = []networkTest{ Expected: []string{ "::1:ffff:ffff/128", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "::/0", @@ -152,7 +151,6 @@ var tests = []networkTest{ "::2:0:50/125", "::2:0:58/127", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "::2:0:40/123", @@ -162,7 +160,6 @@ var tests = []networkTest{ "::2:0:50/125", "::2:0:58/127", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "0:0:0:0:0:ffff:ffff:ff00/120", @@ -192,7 +189,6 @@ var tests = []networkTest{ "1.1.1.16/28", "1.1.1.32/32", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "::/0", @@ -228,6 +224,7 @@ var tests = []networkTest{ "2002:101:110::/44", "2002:101:120::/48", }, + Options: []NetworksOption{IncludeAliasedNetworks}, }, { Network: "::/0", @@ -245,7 +242,6 @@ var tests = []networkTest{ "::2:0:50/125", "::2:0:58/127", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "1.1.1.16/28", From b46d987749b1d96d5f5c7847484107445b55bbd9 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Fri, 5 Jul 2024 13:17:04 -0700 Subject: [PATCH 03/16] Make Lookup return a Result This makes it easier to extend without adding many different lookup methods. 
--- decoder_test.go | 4 +-- deserializer_test.go | 2 +- example_test.go | 4 +-- reader.go | 48 ++++++++++++++--------------------- reader_test.go | 46 ++++++++++++++++----------------- result.go | 60 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 107 insertions(+), 57 deletions(-) create mode 100644 result.go diff --git a/decoder_test.go b/decoder_test.go index 68d8903..c03e380 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -207,7 +207,7 @@ func validateDecoding(t *testing.T, tests map[string]any) { for inputStr, expected := range tests { inputBytes, err := hex.DecodeString(inputStr) require.NoError(t, err) - d := decoder{inputBytes} + d := decoder{buffer: inputBytes} var result any _, err = d.decode(0, reflect.ValueOf(&result), 0) @@ -223,7 +223,7 @@ func validateDecoding(t *testing.T, tests map[string]any) { func TestPointers(t *testing.T) { bytes, err := os.ReadFile(testFile("maps-with-pointers.raw")) require.NoError(t, err) - d := decoder{bytes} + d := decoder{buffer: bytes} expected := map[uint]map[string]string{ 0: {"long_key": "long_value1"}, diff --git a/deserializer_test.go b/deserializer_test.go index 1a483d8..a6e3b70 100644 --- a/deserializer_test.go +++ b/deserializer_test.go @@ -13,7 +13,7 @@ func TestDecodingToDeserializer(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) dser := testDeserializer{} - err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &dser) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&dser) require.NoError(t, err, "unexpected error while doing lookup: %v", err) checkDecodingToInterface(t, dser.rv) diff --git a/example_test.go b/example_test.go index de49cb6..dfad270 100644 --- a/example_test.go +++ b/example_test.go @@ -24,7 +24,7 @@ func ExampleReader_Lookup_struct() { } `maxminddb:"country"` } // Or any appropriate struct - err = db.Lookup(addr, &record) + err = db.Lookup(addr).Decode(&record) if err != nil { log.Panic(err) } @@ -44,7 +44,7 @@ func 
ExampleReader_Lookup_interface() { addr := netip.MustParseAddr("81.2.69.142") var record any - err = db.Lookup(addr, &record) + err = db.Lookup(addr).Decode(&record) if err != nil { log.Panic(err) } diff --git a/reader.go b/reader.go index e41558f..b29d5c6 100644 --- a/reader.go +++ b/reader.go @@ -61,7 +61,7 @@ func FromBytes(buffer []byte) (*Reader, error) { } metadataStart += len(metadataStartMarker) - metadataDecoder := decoder{buffer[metadataStart:]} + metadataDecoder := decoder{buffer: buffer[metadataStart:]} var metadata Metadata @@ -78,7 +78,7 @@ func FromBytes(buffer []byte) (*Reader, error) { return nil, newInvalidDatabaseError("the MaxMind DB contains invalid metadata") } d := decoder{ - buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)], + buffer: buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)], } nodeBuffer := buffer[:searchTreeSize] @@ -125,21 +125,25 @@ func (r *Reader) setIPv4Start() { r.ipv4StartBitDepth = i } -// Lookup retrieves the database record for ip and stores it in the value -// pointed to by result. If result is nil or not a pointer, an error is -// returned. If the data in the database record cannot be stored in result -// because of type differences, an UnmarshalTypeError is returned. If the -// database is invalid or otherwise cannot be read, an InvalidDatabaseError -// is returned. -func (r *Reader) Lookup(ip netip.Addr, result any) error { +// Lookup retrieves the database record for ip and returns Result, which can +// be used to decode the data.
+func (r *Reader) Lookup(ip netip.Addr) Result { if r.buffer == nil { - return errors.New("cannot call Lookup on a closed database") + return Result{err: errors.New("cannot call Lookup on a closed database")} } pointer, _, _, err := r.lookupPointer(ip) - if pointer == 0 || err != nil { - return err + if err != nil { + return Result{err: err} + } + if pointer == 0 { + return Result{offset: notFound} + } + offset, err := r.resolveDataPointer(pointer) + return Result{ + decoder: r.decoder, + offset: uint(offset), + err: err, } - return r.retrieveData(pointer, result) } // LookupNetwork retrieves the database record for ip and stores it in the @@ -229,22 +233,8 @@ func (r *Reader) Decode(offset uintptr, result any) error { if r.buffer == nil { return errors.New("cannot call Decode on a closed database") } - return r.decode(offset, result) -} - -func (r *Reader) decode(offset uintptr, result any) error { - rv := reflect.ValueOf(result) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return errors.New("result param must be a pointer") - } - - if dser, ok := result.(deserializer); ok { - _, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0, false) - return err - } - _, err := r.decoder.decode(uint(offset), rv, 0) - return err + return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result) } func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, netip.Addr, error) { @@ -297,7 +287,7 @@ func (r *Reader) retrieveData(pointer uint, result any) error { if err != nil { return err } - return r.decode(offset, result) + return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result) } func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) { diff --git a/reader_test.go b/reader_test.go index 1817db5..e0f5a30 100644 --- a/reader_test.go +++ b/reader_test.go @@ -212,7 +212,7 @@ func TestDecodingToInterface(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var recordInterface any - err = 
reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &recordInterface) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&recordInterface) require.NoError(t, err, "unexpected error while doing lookup: %v", err) checkDecodingToInterface(t, recordInterface) @@ -299,7 +299,7 @@ func TestDecoder(t *testing.T) { { // Directly lookup and decode. var result TestType - require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)) + require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)) verify(result) } { @@ -330,7 +330,7 @@ func TestStructInterface(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) require.NoError(t, err) - require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)) + require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)) assert.True(t, result.method()) } @@ -341,7 +341,7 @@ func TestNonEmptyNilInterface(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) require.NoError(t, err) - err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result) assert.Equal( t, "maxminddb: cannot unmarshal map into type maxminddb.TestInterface", @@ -364,7 +364,7 @@ func TestEmbeddedStructAsInterface(t *testing.T) { db, err := Open(testFile("GeoIP2-ISP-Test.mmdb")) require.NoError(t, err) - require.NoError(t, db.Lookup(netip.MustParseAddr("1.128.0.0"), &result)) + require.NoError(t, db.Lookup(netip.MustParseAddr("1.128.0.0")).Decode(&result)) } type BoolInterface interface { @@ -390,7 +390,7 @@ func TestValueTypeInterface(t *testing.T) { // although it would be nice to support cases like this, I am not sure it // is possible to do so in a general way. 
- assert.Error(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)) + assert.Error(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)) } type NestedMapX struct { @@ -432,7 +432,7 @@ func TestComplexStructWithNestingAndPointer(t *testing.T) { var result TestPointerType - err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result) require.NoError(t, err) assert.Equal(t, []uint{uint(1), uint(2), uint(3)}, *result.Array) @@ -464,7 +464,7 @@ func TestNestedMapDecode(t *testing.T) { var r map[string]map[string]any - require.NoError(t, db.Lookup(netip.MustParseAddr("89.160.20.128"), &r)) + require.NoError(t, db.Lookup(netip.MustParseAddr("89.160.20.128")).Decode(&r)) assert.Equal( t, @@ -564,7 +564,7 @@ func TestDecodingUint16IntoInt(t *testing.T) { var result struct { Uint16 int `maxminddb:"uint16"` } - err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result) require.NoError(t, err) assert.Equal(t, 100, result.Uint16) @@ -575,7 +575,7 @@ func TestIpv6inIpv4(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var result TestType - err = reader.Lookup(netip.MustParseAddr("2001::"), &result) + err = reader.Lookup(netip.MustParseAddr("2001::")).Decode(&result) var emptyResult TestType assert.Equal(t, emptyResult, result) @@ -592,7 +592,7 @@ func TestBrokenDoubleDatabase(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var result any - err = reader.Lookup(netip.MustParseAddr("2001:220::"), &result) + err = reader.Lookup(netip.MustParseAddr("2001:220::")).Decode(&result) expected := newInvalidDatabaseError( "the MaxMind DB file's data section contains bad data (float 64 size of 2)", @@ -625,7 +625,7 @@ func TestDecodingToNonPointer(t *testing.T) { require.NoError(t, err) var recordInterface any - err = 
reader.Lookup(netip.MustParseAddr("::1.1.1.0"), recordInterface) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(recordInterface) assert.Equal(t, "result param must be a pointer", err.Error()) require.NoError(t, reader.Close(), "error on close") } @@ -635,7 +635,7 @@ func TestDecodingToNonPointer(t *testing.T) { // require.NoError(t, err) // var recordInterface any -// err = reader.Lookup(nil, recordInterface) +// err = reader.Lookup(nil).Decode( recordInterface) // assert.Equal(t, "IP passed to Lookup cannot be nil", err.Error()) // require.NoError(t, reader.Close(), "error on close") // } @@ -647,7 +647,7 @@ func TestUsingClosedDatabase(t *testing.T) { var recordInterface any addr := netip.MustParseAddr("::") - err = reader.Lookup(addr, recordInterface) + err = reader.Lookup(addr).Decode(recordInterface) assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) _, err = reader.LookupOffset(addr) @@ -688,7 +688,7 @@ func checkIpv4(t *testing.T, reader *Reader) { ip := netip.MustParseAddr(address) var result map[string]string - err := reader.Lookup(ip, &result) + err := reader.Lookup(ip).Decode(&result) require.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Equal(t, map[string]string{"ip": address}, result) } @@ -708,7 +708,7 @@ func checkIpv4(t *testing.T, reader *Reader) { ip := netip.MustParseAddr(keyAddress) var result map[string]string - err := reader.Lookup(ip, &result) + err := reader.Lookup(ip).Decode(&result) require.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Equal(t, data, result) } @@ -717,7 +717,7 @@ func checkIpv4(t *testing.T, reader *Reader) { ip := netip.MustParseAddr(address) var result map[string]string - err := reader.Lookup(ip, &result) + err := reader.Lookup(ip).Decode(&result) require.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Nil(t, result) } @@ -731,7 +731,7 @@ func checkIpv6(t *testing.T, reader *Reader) { for _, address := range 
subnets { var result map[string]string - err := reader.Lookup(netip.MustParseAddr(address), &result) + err := reader.Lookup(netip.MustParseAddr(address)).Decode(&result) require.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Equal(t, map[string]string{"ip": address}, result) } @@ -750,14 +750,14 @@ func checkIpv6(t *testing.T, reader *Reader) { for keyAddress, valueAddress := range pairs { data := map[string]string{"ip": valueAddress} var result map[string]string - err := reader.Lookup(netip.MustParseAddr(keyAddress), &result) + err := reader.Lookup(netip.MustParseAddr(keyAddress)).Decode(&result) require.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Equal(t, data, result) } for _, address := range []string{"1.1.1.33", "255.254.253.123", "89fa::"} { var result map[string]string - err := reader.Lookup(netip.MustParseAddr(address), &result) + err := reader.Lookup(netip.MustParseAddr(address)).Decode(&result) require.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Nil(t, result) } @@ -787,7 +787,7 @@ func BenchmarkInterfaceLookup(b *testing.B) { s := make(net.IP, 4) for i := 0; i < b.N; i++ { ip := randomIPv4Address(r, s) - err = db.Lookup(ip, &result) + err = db.Lookup(ip).Decode(&result) if err != nil { b.Error(err) } @@ -875,7 +875,7 @@ func BenchmarkCityLookup(b *testing.B) { s := make(net.IP, 4) for i := 0; i < b.N; i++ { ip := randomIPv4Address(r, s) - err = db.Lookup(ip, &result) + err = db.Lookup(ip).Decode(&result) if err != nil { b.Error(err) } @@ -919,7 +919,7 @@ func BenchmarkCountryCode(b *testing.B) { s := make(net.IP, 4) for i := 0; i < b.N; i++ { ip := randomIPv4Address(r, s) - err = db.Lookup(ip, &result) + err = db.Lookup(ip).Decode(&result) if err != nil { b.Error(err) } diff --git a/result.go b/result.go new file mode 100644 index 0000000..e4426bb --- /dev/null +++ b/result.go @@ -0,0 +1,60 @@ +package maxminddb + +import ( + "errors" + "math" + "reflect" +) + +const notFound 
uint = math.MaxUint + +type Result struct { + err error + decoder decoder + offset uint +} + +// Decode unmarshals the data from the data section into the value pointed to +// by v. If v is nil or not a pointer, an error is returned. If the data in +// the database record cannot be stored in v because of type differences, an +// UnmarshalTypeError is returned. If the database is invalid or otherwise +// cannot be read, an InvalidDatabaseError is returned. +// +// An error will also be returned if there was an error during the +// Reader.Lookup call. +// +// If the Reader.Lookup call did not find a value for the IP address, no error +// will be returned and v will be unchanged. +func (r Result) Decode(v any) error { + if r.err != nil { + return r.err + } + if r.offset == notFound { + return nil + } + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("result param must be a pointer") + } + + if dser, ok := v.(deserializer); ok { + _, err := r.decoder.decodeToDeserializer(r.offset, dser, 0, false) + return err + } + + _, err := r.decoder.decode(r.offset, rv, 0) + return err +} + +// Err provides a way to check whether there was an error during the lookup +// without clling Result.Decode. If there was an error, it will also be +// returned from Result.Decode. +func (r Result) Err() error { + return r.err +} + +// Found will return true if the IP was found in the search tree. It will +// return false if the IP was not found or if there was an error. 
+func (r Result) Found() bool { + return r.err == nil && r.offset != notFound +} From 8da779a9b89a60231fe24642167b1c67c574a0c7 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sat, 6 Jul 2024 16:21:42 -0700 Subject: [PATCH 04/16] Add the ability to decode a single path This is more ergonomic than creating a struct for a single value and the performance is better as well due to less reflection: BenchmarkDecodeCountryCodeWithStruct-8 1347441 882.4 ns/op 1 B/op 0 allocs/op BenchmarkDecodePathCountryCode-8 2708011 445.1 ns/op 1 B/op 0 allocs/op --- decoder.go | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++ reader_test.go | 47 ++++++++++++++++++++++++++++++- result.go | 45 ++++++++++++++++++++++++++++++ 3 files changed, 167 insertions(+), 1 deletion(-) diff --git a/decoder.go b/decoder.go index 435591e..db941fc 100644 --- a/decoder.go +++ b/decoder.go @@ -89,6 +89,82 @@ func (d *decoder) decodeToDeserializer( return d.decodeFromTypeToDeserializer(typeNum, size, newOffset, dser, depth+1) } +func (d *decoder) decodePath( + offset uint, + path []any, + result reflect.Value, +) error { +PATH: + for i, v := range path { + var ( + typeNum dataType + size uint + err error + ) + typeNum, size, offset, err = d.decodeCtrlData(offset) + if err != nil { + return err + } + + if typeNum == _Pointer { + pointer, _, err := d.decodePointer(size, offset) + if err != nil { + return err + } + + typeNum, size, offset, err = d.decodeCtrlData(pointer) + if err != nil { + return err + } + } + + switch v := v.(type) { + case string: + // We are expecting a map + if typeNum != _Map { + // XXX - use type names in errors. + return fmt.Errorf("expected a map for %s but found %d", v, typeNum) + } + for i := uint(0); i < size; i++ { + var key []byte + key, offset, err = d.decodeKey(offset) + if err != nil { + return err + } + if string(key) == v { + continue PATH + } + offset, err = d.nextValueOffset(offset, 1) + if err != nil { + return err + } + } + // Not found. 
Maybe return a boolean? + return nil + case int: + // We are expecting an array + if typeNum != _Slice { + // XXX - use type names in errors. + return fmt.Errorf("expected a slice for %d but found %d", v, typeNum) + } + if size < uint(v) { + // Slice is smaller than index, not found + return nil + } + // TODO: support negative indexes? Seems useful for subdivisions in + // particular. + offset, err = d.nextValueOffset(offset, uint(v)) + if err != nil { + return err + } + default: + return fmt.Errorf("unexpected type for %d value in path, %v: %T", i, v, v) + } + } + _, err := d.decode(offset, result, len(path)) + return err +} + func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) { newOffset := offset + 1 if offset >= uint(len(d.buffer)) { diff --git a/reader_test.go b/reader_test.go index e0f5a30..1c222f6 100644 --- a/reader_test.go +++ b/reader_test.go @@ -316,6 +316,30 @@ func TestDecoder(t *testing.T) { require.NoError(t, reader.Close()) } +func TestDecodePath(t *testing.T) { + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) + + result := reader.Lookup(netip.MustParseAddr("::1.1.1.0")) + require.NoError(t, result.Err()) + + var u16 uint16 + + require.NoError(t, result.DecodePath(&u16, "uint16")) + + assert.Equal(t, uint16(100), u16) + + var u uint + require.NoError(t, result.DecodePath(&u, "array", 0)) + assert.Equal(t, uint(1), u) + + require.NoError(t, result.DecodePath(&u, "array", 2)) + assert.Equal(t, uint(3), u) + + require.NoError(t, result.DecodePath(&u, "map", "mapX", "arrayX", 1)) + assert.Equal(t, uint(8), u) +} + type TestInterface interface { method() bool } @@ -902,7 +926,7 @@ func BenchmarkCityLookupNetwork(b *testing.B) { require.NoError(b, db.Close(), "error on close") } -func BenchmarkCountryCode(b *testing.B) { +func BenchmarkDecodeCountryCodeWithStruct(b *testing.B) { db, err := Open("GeoLite2-City.mmdb") require.NoError(b, err) @@ -927,6 +951,27 @@ func BenchmarkCountryCode(b 
*testing.B) { require.NoError(b, db.Close(), "error on close") } +func BenchmarkDecodePathCountryCode(b *testing.B) { + db, err := Open("GeoLite2-City.mmdb") + require.NoError(b, err) + + path := []any{"country", "iso_code"} + + //nolint:gosec // this is a test + r := rand.New(rand.NewSource(0)) + var result string + + s := make(net.IP, 4) + for i := 0; i < b.N; i++ { + ip := randomIPv4Address(r, s) + err = db.Lookup(ip).DecodePath(&result, path...) + if err != nil { + b.Error(err) + } + } + require.NoError(b, db.Close(), "error on close") +} + func randomIPv4Address(r *rand.Rand, ip []byte) netip.Addr { num := r.Uint32() ip[0] = byte(num >> 24) diff --git a/result.go b/result.go index e4426bb..ba66f68 100644 --- a/result.go +++ b/result.go @@ -46,6 +46,51 @@ func (r Result) Decode(v any) error { return err } +// DecodePath unmarshals a value from data section into v, following the +// specified path. +// +// The v parameter should be a pointer to the value where the decoded data +// will be stored. If v is nil or not a pointer, an error is returned. If the +// data in the database record cannot be stored in v because of type +// differences, an UnmarshalTypeError is returned. +// +// The path is a variadic list of keys (strings) and/or indices (ints) that +// describe the nested structure to traverse in the data to reach the desired +// value. +// +// For maps, string path elements are used as keys. +// For arrays, int path elements are used as indices. +// +// If the path is empty, the entire data structure is decoded into v. 
+// +// Returns an error if: +// - the path is invalid +// - the data cannot be decoded into the type of v +// - v is not a pointer or the database record cannot be stored in v due to +// type mismatch +// - the Result does not contain valid data +// +// Example usage: +// +// var city string +// err := result.DecodePath(&city, "location", "city", "names", "en") +// +// var geonameID int +// err := result.DecodePath(&geonameID, "subdivisions", 0, "geoname_id") +func (r Result) DecodePath(v any, path ...any) error { + if r.err != nil { + return r.err + } + if r.offset == notFound { + return nil + } + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("result param must be a pointer") + } + return r.decoder.decodePath(r.offset, path, rv) +} + // Err provides a way to check whether there was an error during the lookup // without clling Result.Decode. If there was an error, it will also be // returned from Result.Decode. From 800e54ab23afd4531bca5f947b55eeedf75a8807 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sat, 6 Jul 2024 16:25:41 -0700 Subject: [PATCH 05/16] Disable gocognit linter It is somewhat arbitrary, especially when dealing with switch statements and code may be inlined for performance reasons. 
--- .golangci.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/.golangci.toml b/.golangci.toml index 799416c..0abb5aa 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -18,6 +18,7 @@ disable = [ "forcetypeassert", "funlen", "gochecknoglobals", + "gocognit", "godox", "gomnd", "inamedparam", From 3987a4b3b2ec19b5c7237e5e0f3a0677deda5a6d Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 7 Jul 2024 13:23:14 -0700 Subject: [PATCH 06/16] Replace LookupNetwork with Network on Result --- reader.go | 83 +++++++------------------- reader_test.go | 157 +++++++++++++++++++++++++------------------------ result.go | 33 ++++++++++- 3 files changed, 132 insertions(+), 141 deletions(-) diff --git a/reader.go b/reader.go index b29d5c6..6893b36 100644 --- a/reader.go +++ b/reader.go @@ -131,46 +131,29 @@ func (r *Reader) Lookup(ip netip.Addr) Result { if r.buffer == nil { return Result{err: errors.New("cannot call Lookup on a closed database")} } - pointer, _, _, err := r.lookupPointer(ip) + pointer, prefixLen, err := r.lookupPointer(ip) if err != nil { - return Result{err: err} + return Result{ + ip: ip, + prefixLen: uint8(prefixLen), + err: err, + } } if pointer == 0 { - return Result{offset: notFound} + return Result{ + ip: ip, + prefixLen: uint8(prefixLen), + offset: notFound, + } } offset, err := r.resolveDataPointer(pointer) return Result{ - decoder: r.decoder, - offset: uint(offset), - err: err, - } -} - -// LookupNetwork retrieves the database record for ip and stores it in the -// value pointed to by result. The prefix returned is the network associated -// with the data record in the database. The ok return value indicates whether -// the database contained a record for the ip. -// -// If result is nil or not a pointer, an error is returned. If the data in the -// database record cannot be stored in result because of type differences, an -// UnmarshalTypeError is returned. 
If the database is invalid or otherwise -// cannot be read, an InvalidDatabaseError is returned. -func (r *Reader) LookupNetwork( - ip netip.Addr, - result any, -) (prefix netip.Prefix, ok bool, err error) { - if r.buffer == nil { - return netip.Prefix{}, false, errors.New("cannot call Lookup on a closed database") - } - pointer, prefixLength, ip, err := r.lookupPointer(ip) - // We return this error below as we want to return the prefix it is for - - prefix, errP := r.cidr(ip, prefixLength) - if pointer == 0 || err != nil || errP != nil { - return prefix, false, errors.Join(err, errP) + decoder: r.decoder, + ip: ip, + offset: uint(offset), + prefixLen: uint8(prefixLen), + err: err, } - - return prefix, true, r.retrieveData(pointer, result) } // LookupOffset maps an argument net.IP to a corresponding record offset in the @@ -182,7 +165,7 @@ func (r *Reader) LookupOffset(ip netip.Addr) (uintptr, error) { if r.buffer == nil { return 0, errors.New("cannot call LookupOffset on a closed database") } - pointer, _, _, err := r.lookupPointer(ip) + pointer, _, err := r.lookupPointer(ip) if pointer == 0 || err != nil { return NotFound, err } @@ -191,28 +174,6 @@ func (r *Reader) LookupOffset(ip netip.Addr) (uintptr, error) { var zeroIP = netip.MustParseAddr("::") -func (r *Reader) cidr(ip netip.Addr, prefixLength int) (netip.Prefix, error) { - if ip.Is4() { - // This is necessary as the node that the IPv4 start is at may - // be at a bit depth that is less that 96, i.e., ipv4Start points - // to a leaf node. For instance, if a record was inserted at ::/8, - // the ipv4Start would point directly at the leaf node for the - // record and would have a bit depth of 8. This would not happen - // with databases currently distributed by MaxMind as all of them - // have an IPv4 subtree that is greater than a single node. 
- if r.Metadata.IPVersion == 6 && r.ipv4StartBitDepth != 96 { - return netip.PrefixFrom(zeroIP, r.ipv4StartBitDepth), nil - } - prefixLength -= 96 - } - - prefix, err := ip.Prefix(prefixLength) - if err != nil { - return netip.Prefix{}, fmt.Errorf("creating prefix from %s/%d: %w", ip, prefixLength, err) - } - return prefix, nil -} - // Decode the record at |offset| into |result|. The result value pointed to // must be a data value that corresponds to a record in the database. This may // include a struct representation of the data, a map capable of holding the @@ -237,9 +198,9 @@ func (r *Reader) Decode(offset uintptr, result any) error { return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result) } -func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, netip.Addr, error) { +func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) { if r.Metadata.IPVersion == 4 && ip.Is6() { - return 0, 0, ip, fmt.Errorf( + return 0, 0, fmt.Errorf( "error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database", ip.String(), ) @@ -250,12 +211,12 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, netip.Addr, error) { nodeCount := r.Metadata.NodeCount if node == nodeCount { // Record is empty - return 0, prefixLength, ip, nil + return 0, prefixLength, nil } else if node > nodeCount { - return node, prefixLength, ip, nil + return node, prefixLength, nil } - return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree") + return 0, prefixLength, newInvalidDatabaseError("invalid node in search tree") } func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) { diff --git a/reader_test.go b/reader_test.go index 1c222f6..ea48b5f 100644 --- a/reader_test.go +++ b/reader_test.go @@ -100,95 +100,95 @@ func TestLookupNetwork(t *testing.T) { } tests := []struct { - IP netip.Addr - DBFile string - ExpectedCIDR string - ExpectedRecord any - ExpectedOK bool + IP netip.Addr + DBFile string + 
ExpectedNetwork string + ExpectedRecord any + ExpectedFound bool }{ { - IP: netip.MustParseAddr("1.1.1.1"), - DBFile: "MaxMind-DB-test-ipv6-32.mmdb", - ExpectedCIDR: "1.0.0.0/8", - ExpectedRecord: nil, - ExpectedOK: false, + IP: netip.MustParseAddr("1.1.1.1"), + DBFile: "MaxMind-DB-test-ipv6-32.mmdb", + ExpectedNetwork: "1.0.0.0/8", + ExpectedRecord: nil, + ExpectedFound: false, }, { - IP: netip.MustParseAddr("::1:ffff:ffff"), - DBFile: "MaxMind-DB-test-ipv6-24.mmdb", - ExpectedCIDR: "::1:ffff:ffff/128", - ExpectedRecord: map[string]any{"ip": "::1:ffff:ffff"}, - ExpectedOK: true, + IP: netip.MustParseAddr("::1:ffff:ffff"), + DBFile: "MaxMind-DB-test-ipv6-24.mmdb", + ExpectedNetwork: "::1:ffff:ffff/128", + ExpectedRecord: map[string]any{"ip": "::1:ffff:ffff"}, + ExpectedFound: true, }, { - IP: netip.MustParseAddr("::2:0:1"), - DBFile: "MaxMind-DB-test-ipv6-24.mmdb", - ExpectedCIDR: "::2:0:0/122", - ExpectedRecord: map[string]any{"ip": "::2:0:0"}, - ExpectedOK: true, + IP: netip.MustParseAddr("::2:0:1"), + DBFile: "MaxMind-DB-test-ipv6-24.mmdb", + ExpectedNetwork: "::2:0:0/122", + ExpectedRecord: map[string]any{"ip": "::2:0:0"}, + ExpectedFound: true, }, { - IP: netip.MustParseAddr("1.1.1.1"), - DBFile: "MaxMind-DB-test-ipv4-24.mmdb", - ExpectedCIDR: "1.1.1.1/32", - ExpectedRecord: map[string]any{"ip": "1.1.1.1"}, - ExpectedOK: true, + IP: netip.MustParseAddr("1.1.1.1"), + DBFile: "MaxMind-DB-test-ipv4-24.mmdb", + ExpectedNetwork: "1.1.1.1/32", + ExpectedRecord: map[string]any{"ip": "1.1.1.1"}, + ExpectedFound: true, }, { - IP: netip.MustParseAddr("1.1.1.3"), - DBFile: "MaxMind-DB-test-ipv4-24.mmdb", - ExpectedCIDR: "1.1.1.2/31", - ExpectedRecord: map[string]any{"ip": "1.1.1.2"}, - ExpectedOK: true, + IP: netip.MustParseAddr("1.1.1.3"), + DBFile: "MaxMind-DB-test-ipv4-24.mmdb", + ExpectedNetwork: "1.1.1.2/31", + ExpectedRecord: map[string]any{"ip": "1.1.1.2"}, + ExpectedFound: true, }, { - IP: netip.MustParseAddr("1.1.1.3"), - DBFile: "MaxMind-DB-test-decoder.mmdb", 
- ExpectedCIDR: "1.1.1.0/24", - ExpectedRecord: decoderRecord, - ExpectedOK: true, + IP: netip.MustParseAddr("1.1.1.3"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedNetwork: "1.1.1.0/24", + ExpectedRecord: decoderRecord, + ExpectedFound: true, }, { - IP: netip.MustParseAddr("::ffff:1.1.1.128"), - DBFile: "MaxMind-DB-test-decoder.mmdb", - ExpectedCIDR: "::ffff:1.1.1.0/120", - ExpectedRecord: decoderRecord, - ExpectedOK: true, + IP: netip.MustParseAddr("::ffff:1.1.1.128"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedNetwork: "::ffff:1.1.1.0/120", + ExpectedRecord: decoderRecord, + ExpectedFound: true, }, { - IP: netip.MustParseAddr("::1.1.1.128"), - DBFile: "MaxMind-DB-test-decoder.mmdb", - ExpectedCIDR: "::101:100/120", - ExpectedRecord: decoderRecord, - ExpectedOK: true, + IP: netip.MustParseAddr("::1.1.1.128"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedNetwork: "::101:100/120", + ExpectedRecord: decoderRecord, + ExpectedFound: true, }, { - IP: netip.MustParseAddr("200.0.2.1"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "::/64", - ExpectedRecord: "::0/64", - ExpectedOK: true, + IP: netip.MustParseAddr("200.0.2.1"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "::/64", + ExpectedRecord: "::0/64", + ExpectedFound: true, }, { - IP: netip.MustParseAddr("::200.0.2.1"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "::/64", - ExpectedRecord: "::0/64", - ExpectedOK: true, + IP: netip.MustParseAddr("::200.0.2.1"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "::/64", + ExpectedRecord: "::0/64", + ExpectedFound: true, }, { - IP: netip.MustParseAddr("0:0:0:0:ffff:ffff:ffff:ffff"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "::/64", - ExpectedRecord: "::0/64", - ExpectedOK: true, + IP: netip.MustParseAddr("0:0:0:0:ffff:ffff:ffff:ffff"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "::/64", + ExpectedRecord: "::0/64", + 
ExpectedFound: true, }, { - IP: netip.MustParseAddr("ef00::"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "8000::/1", - ExpectedRecord: nil, - ExpectedOK: false, + IP: netip.MustParseAddr("ef00::"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "8000::/1", + ExpectedRecord: nil, + ExpectedFound: false, }, } @@ -198,10 +198,12 @@ func TestLookupNetwork(t *testing.T) { reader, err := Open(testFile(test.DBFile)) require.NoError(t, err) - network, ok, err := reader.LookupNetwork(test.IP, &record) - require.NoError(t, err) - assert.Equal(t, test.ExpectedOK, ok) - assert.Equal(t, test.ExpectedCIDR, network.String()) + result := reader.Lookup(test.IP) + require.NoError(t, result.Err()) + assert.Equal(t, test.ExpectedFound, result.Found()) + assert.Equal(t, test.ExpectedNetwork, result.Network().String()) + + require.NoError(t, result.Decode(&record)) assert.Equal(t, test.ExpectedRecord, record) }) } @@ -819,21 +821,23 @@ func BenchmarkInterfaceLookup(b *testing.B) { require.NoError(b, db.Close(), "error on close") } -func BenchmarkInterfaceLookupNetwork(b *testing.B) { +func BenchmarkLookupNetwork(b *testing.B) { db, err := Open("GeoLite2-City.mmdb") require.NoError(b, err) //nolint:gosec // this is a test r := rand.New(rand.NewSource(time.Now().UnixNano())) - var result any s := make(net.IP, 4) for i := 0; i < b.N; i++ { ip := randomIPv4Address(r, s) - _, _, err = db.LookupNetwork(ip, &result) - if err != nil { + res := db.Lookup(ip) + if err := res.Err(); err != nil { b.Error(err) } + if !res.Network().IsValid() { + b.Fatalf("invalid network for %s", ip) + } } require.NoError(b, db.Close(), "error on close") } @@ -907,19 +911,18 @@ func BenchmarkCityLookup(b *testing.B) { require.NoError(b, db.Close(), "error on close") } -func BenchmarkCityLookupNetwork(b *testing.B) { +func BenchmarkCityLookupOnly(b *testing.B) { db, err := Open("GeoLite2-City.mmdb") require.NoError(b, err) //nolint:gosec // this is a test r := 
rand.New(rand.NewSource(time.Now().UnixNano())) - var result fullCity s := make(net.IP, 4) for i := 0; i < b.N; i++ { ip := randomIPv4Address(r, s) - _, _, err = db.LookupNetwork(ip, &result) - if err != nil { + result := db.Lookup(ip) + if err := result.Err(); err != nil { b.Error(err) } } diff --git a/result.go b/result.go index ba66f68..4375289 100644 --- a/result.go +++ b/result.go @@ -3,15 +3,18 @@ package maxminddb import ( "errors" "math" + "net/netip" "reflect" ) const notFound uint = math.MaxUint type Result struct { - err error - decoder decoder - offset uint + ip netip.Addr + err error + decoder decoder + offset uint + prefixLen uint8 } // Decode unmarshals the data from the data section into the value pointed to @@ -103,3 +106,27 @@ func (r Result) Err() error { func (r Result) Found() bool { return r.err == nil && r.offset != notFound } + +// Network returns the netip.Prefix representing the network associated with +// the data record in the database. +func (r Result) Network() netip.Prefix { + ip := r.ip + prefixLen := int(r.prefixLen) + + if ip.Is4() { + // This is necessary as the node that the IPv4 start is at may + // be at a bit depth that is less than 96, i.e., ipv4Start points + // to a leaf node. For instance, if a record was inserted at ::/8, + // the ipv4Start would point directly at the leaf node for the + // record and would have a bit depth of 8. This would not happen + // with databases currently distributed by MaxMind as all of them + // have an IPv4 subtree that is greater than a single node.
+ if prefixLen < 96 { + return netip.PrefixFrom(zeroIP, prefixLen) + } + prefixLen -= 96 + } + + prefix, _ := ip.Prefix(prefixLen) + return prefix +} From 0a0e493d820c25ea2301a6dc0eea2bf61f30a06b Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 7 Jul 2024 13:32:57 -0700 Subject: [PATCH 07/16] Replace LookupOffset with RecordOffset --- reader.go | 28 +++------------------------- reader_test.go | 35 ++++++++++++++++++----------------- result.go | 13 ++++++++++++- 3 files changed, 33 insertions(+), 43 deletions(-) diff --git a/reader.go b/reader.go index 6893b36..c26f28d 100644 --- a/reader.go +++ b/reader.go @@ -9,13 +9,7 @@ import ( "reflect" ) -const ( - // NotFound is returned by LookupOffset when a matched root record offset - // cannot be found. - NotFound = ^uintptr(0) - - dataSectionSeparatorSize = 16 -) +const dataSectionSeparatorSize = 16 var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com") @@ -156,24 +150,6 @@ func (r *Reader) Lookup(ip netip.Addr) Result { } } -// LookupOffset maps an argument net.IP to a corresponding record offset in the -// database. NotFound is returned if no such record is found, and a record may -// otherwise be extracted by passing the returned offset to Decode. LookupOffset -// is an advanced API, which exists to provide clients with a means to cache -// previously-decoded records. -func (r *Reader) LookupOffset(ip netip.Addr) (uintptr, error) { - if r.buffer == nil { - return 0, errors.New("cannot call LookupOffset on a closed database") - } - pointer, _, err := r.lookupPointer(ip) - if pointer == 0 || err != nil { - return NotFound, err - } - return r.resolveDataPointer(pointer) -} - -var zeroIP = netip.MustParseAddr("::") - // Decode the record at |offset| into |result|. The result value pointed to // must be a data value that corresponds to a record in the database. 
This may // include a struct representation of the data, a map capable of holding the @@ -198,6 +174,8 @@ func (r *Reader) Decode(offset uintptr, result any) error { return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result) } +var zeroIP = netip.MustParseAddr("::") + func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) { if r.Metadata.IPVersion == 4 && ip.Is6() { return 0, 0, fmt.Errorf( diff --git a/reader_test.go b/reader_test.go index ea48b5f..3976b84 100644 --- a/reader_test.go +++ b/reader_test.go @@ -300,19 +300,19 @@ func TestDecoder(t *testing.T) { { // Directly lookup and decode. - var result TestType - require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)) - verify(result) + var testV TestType + require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&testV)) + verify(testV) } { // Lookup record offset, then Decode. - var result TestType - offset, err := reader.LookupOffset(netip.MustParseAddr("::1.1.1.0")) - require.NoError(t, err) - assert.NotEqual(t, NotFound, offset) + var testV TestType + result := reader.Lookup(netip.MustParseAddr("::1.1.1.0")) + require.NoError(t, result.Err()) + require.True(t, result.Found()) - require.NoError(t, reader.Decode(offset, &result)) - verify(result) + require.NoError(t, reader.Decode(result.RecordOffset(), &testV)) + verify(testV) } require.NoError(t, reader.Close()) @@ -548,9 +548,9 @@ func TestNestedOffsetDecode(t *testing.T) { db, err := Open(testFile("GeoIP2-City-Test.mmdb")) require.NoError(t, err) - off, err := db.LookupOffset(netip.MustParseAddr("81.2.69.142")) - assert.NotEqual(t, NotFound, off) - require.NoError(t, err) + result := db.Lookup(netip.MustParseAddr("81.2.69.142")) + require.NoError(t, result.Err()) + require.True(t, result.Found()) var root struct { CountryOffset uintptr `maxminddb:"country"` @@ -563,7 +563,7 @@ func TestNestedOffsetDecode(t *testing.T) { TimeZoneOffset uintptr `maxminddb:"time_zone"` } `maxminddb:"location"` 
} - require.NoError(t, db.Decode(off, &root)) + require.NoError(t, db.Decode(result.RecordOffset(), &root)) assert.InEpsilon(t, 51.5142, root.Location.Latitude, 1e-10) var longitude float64 @@ -671,14 +671,15 @@ func TestUsingClosedDatabase(t *testing.T) { require.NoError(t, err) require.NoError(t, reader.Close()) - var recordInterface any addr := netip.MustParseAddr("::") + + result := reader.Lookup(addr) + assert.Equal(t, "cannot call Lookup on a closed database", result.Err().Error()) + + var recordInterface any err = reader.Lookup(addr).Decode(recordInterface) assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) - _, err = reader.LookupOffset(addr) - assert.Equal(t, "cannot call LookupOffset on a closed database", err.Error()) - err = reader.Decode(0, recordInterface) assert.Equal(t, "cannot call Decode on a closed database", err.Error()) } diff --git a/result.go b/result.go index 4375289..0d142d2 100644 --- a/result.go +++ b/result.go @@ -95,7 +95,7 @@ func (r Result) DecodePath(v any, path ...any) error { } // Err provides a way to check whether there was an error during the lookup -// without clling Result.Decode. If there was an error, it will also be +// without calling Result.Decode. If there was an error, it will also be // returned from Result.Decode. func (r Result) Err() error { return r.err @@ -107,6 +107,17 @@ func (r Result) Found() bool { return r.err == nil && r.offset != notFound } +// RecordOffset returns the offset of the record in the database. This can be +// passed to (*Reader).Decode. It can also be used as a unique identifier for the +// data record in the particular database to cache the data record across +// lookups. Note that while the offset uniquely identifies the data record, +// other data in Result may differ between lookups. The offset is only valid +// for the current database version. If you update the database file, you must +// invalidate any cache associated with the previous version.
+func (r Result) RecordOffset() uintptr { + return uintptr(r.offset) +} + // Network returns the netip.Prefix representing the network associated with // the data record in the database. func (r Result) Network() netip.Prefix { From 62022ad5533134170d74b4dce18f2494f2977b61 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Tue, 9 Jul 2024 20:11:27 -0700 Subject: [PATCH 08/16] Switch Networks methods to iterators --- .github/workflows/go.yml | 2 +- example_test.go | 21 +++------- go.mod | 2 +- reader.go | 8 ---- traverse.go | 88 +++++++++++++++++++++------------------- traverse_test.go | 35 ++++++++-------- verifier.go | 11 ++--- 7 files changed, 73 insertions(+), 94 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0481be5..b45f167 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,7 +8,7 @@ jobs: name: Build strategy: matrix: - go-version: [1.21.x, 1.22.x] + go-version: [1.23.0-rc.1] platform: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.platform }} steps: diff --git a/example_test.go b/example_test.go index dfad270..1a19df4 100644 --- a/example_test.go +++ b/example_test.go @@ -63,20 +63,16 @@ func ExampleReader_Networks() { } defer db.Close() - networks := db.Networks() - for networks.Next() { + for result := range db.Networks() { record := struct { Domain string `maxminddb:"connection_type"` }{} - subnet, err := networks.Network(&record) + err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", subnet.String(), record.Domain) - } - if networks.Err() != nil { - log.Panic(networks.Err()) + fmt.Printf("%s: %s\n", result.Network(), record.Domain) } // Output: // 1.0.0.0/24: Cable/DSL @@ -119,20 +115,15 @@ func ExampleReader_NetworksWithin() { log.Panic(err) } - networks := db.NetworksWithin(prefix) - for networks.Next() { + for result := range db.NetworksWithin(prefix) { record := struct { Domain string `maxminddb:"connection_type"` }{} - - subnet, err 
:= networks.Network(&record) + err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", subnet.String(), record.Domain) - } - if networks.Err() != nil { - log.Panic(networks.Err()) + fmt.Printf("%s: %s\n", result.Network(), record.Domain) } // Output: diff --git a/go.mod b/go.mod index 11b1b10..fc7d1c9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/oschwald/maxminddb-golang/v2 -go 1.21 +go 1.23 require ( github.com/stretchr/testify v1.9.0 diff --git a/reader.go b/reader.go index c26f28d..e9f76c3 100644 --- a/reader.go +++ b/reader.go @@ -221,14 +221,6 @@ func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) return node, i } -func (r *Reader) retrieveData(pointer uint, result any) error { - offset, err := r.resolveDataPointer(pointer) - if err != nil { - return err - } - return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result) -} - func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) { resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize) diff --git a/traverse.go b/traverse.go index 67a5148..9054c21 100644 --- a/traverse.go +++ b/traverse.go @@ -3,6 +3,8 @@ package maxminddb import ( "fmt" "net/netip" + + "iter" ) // Internal structure used to keep track of nodes we still need to visit. @@ -12,8 +14,8 @@ type netNode struct { pointer uint } -// Networks represents a set of subnets that we are iterating over. -type Networks struct { +// networks represents a set of subnets that we are iterating over. +type networks struct { err error reader *Reader nodes []netNode @@ -27,12 +29,12 @@ var ( ) // NetworksOption are options for Networks and NetworksWithin. -type NetworksOption func(*Networks) +type NetworksOption func(*networks) // IncludeAliasedNetworks is an option for Networks and NetworksWithin // that makes them iterate over aliases of the IPv4 subtree in an IPv6 // database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16. 
-func IncludeAliasedNetworks(networks *Networks) { +func IncludeAliasedNetworks(networks *networks) { networks.includeAliasedNetworks = true } @@ -43,15 +45,11 @@ func IncludeAliasedNetworks(networks *Networks) { // in an IPv6 database. This iterator will only iterate over these once by // default. To iterate over all the IPv4 network locations, use the // IncludeAliasedNetworks option. -func (r *Reader) Networks(options ...NetworksOption) *Networks { - var networks *Networks +func (r *Reader) Networks(options ...NetworksOption) iter.Seq[Result] { if r.Metadata.IPVersion == 6 { - networks = r.NetworksWithin(allIPv6, options...) - } else { - networks = r.NetworksWithin(allIPv4, options...) + return r.NetworksWithin(allIPv6, options...) } - - return networks + return r.NetworksWithin(allIPv4, options...) } // NetworksWithin returns an iterator that can be used to traverse all networks @@ -64,9 +62,41 @@ func (r *Reader) Networks(options ...NetworksOption) *Networks { // // If the provided prefix is contained within a network in the database, the // iterator will iterate over exactly one network, the containing network. -func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) *Networks { +func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) iter.Seq[Result] { + n := r.networksWithin(prefix, options...) 
+ return func(yield func(Result) bool) { + for n.next() { + if n.err != nil { + yield(Result{err: n.err}) + return + } + + ip := n.lastNode.ip + if isInIPv4Subtree(ip) { + ip = v6ToV4(ip) + } + + offset, err := r.resolveDataPointer(n.lastNode.pointer) + ok := yield(Result{ + decoder: r.decoder, + ip: ip, + offset: uint(offset), + prefixLen: uint8(n.lastNode.bit), + err: err, + }) + if !ok { + return + } + } + if n.err != nil { + yield(Result{err: n.err}) + } + } +} + +func (r *Reader) networksWithin(prefix netip.Prefix, options ...NetworksOption) *networks { if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() { - return &Networks{ + return &networks{ err: fmt.Errorf( "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", prefix, @@ -74,7 +104,7 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) } } - networks := &Networks{reader: r} + networks := &networks{reader: r} for _, option := range options { option(networks) } @@ -105,10 +135,10 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) return networks } -// Next prepares the next network for reading with the Network method. It +// next prepares the next network for reading with the Network method. It // returns true if there is another network to be processed and false if there // are no more networks or if there is an error. -func (n *Networks) Next() bool { +func (n *networks) next() bool { if n.err != nil { return false } @@ -160,32 +190,6 @@ func (n *Networks) Next() bool { return false } -// Network returns the current network or an error if there is a problem -// decoding the data for the network. It takes a pointer to a result value to -// decode the network's data into. 
-func (n *Networks) Network(result any) (netip.Prefix, error) { - if n.err != nil { - return netip.Prefix{}, n.err - } - if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil { - return netip.Prefix{}, err - } - - ip := n.lastNode.ip - prefixLength := int(n.lastNode.bit) - if isInIPv4Subtree(ip) { - ip = v6ToV4(ip) - prefixLength -= 96 - } - - return netip.PrefixFrom(ip, prefixLength), nil -} - -// Err returns an error, if any, that was encountered during iteration. -func (n *Networks) Err() error { - return n.err -} - var ipv4SubtreeBoundary = netip.MustParseAddr("::255.255.255.255").Next() // isInIPv4Subtree returns true if the IP is in the database's IPv4 subtree. diff --git a/traverse_test.go b/traverse_test.go index 382659a..df4bcc8 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -20,18 +20,18 @@ func TestNetworks(t *testing.T) { reader, err := Open(fileName) require.NoError(t, err, "unexpected error while opening database: %v", err) - n := reader.Networks() - for n.Next() { + for result := range reader.Networks() { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) + + network := result.Network() assert.Equal(t, record.IP, network.Addr().String(), "expected %s got %s", record.IP, network.Addr().String(), ) } - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) } } @@ -41,13 +41,14 @@ func TestNetworksWithInvalidSearchTree(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-broken-search-tree-24.mmdb")) require.NoError(t, err, "unexpected error while opening database: %v", err) - n := reader.Networks() - for n.Next() { + for result := range reader.Networks() { var record any - _, err := n.Network(&record) - require.NoError(t, err) + err = result.Decode(&record) + if err != nil { + break + } } - require.EqualError(t, n.Err(), "invalid search tree at 128.128.128.128/32") + require.EqualError(t, err, "invalid search tree at 
128.128.128.128/32") require.NoError(t, reader.Close()) } @@ -285,20 +286,18 @@ func TestNetworksWithin(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - n := reader.NetworksWithin(network, v.Options...) var innerIPs []string - for n.Next() { + for result := range reader.NetworksWithin(network, v.Options...) { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) + innerIPs = append(innerIPs, result.Network().String()) } assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) }) @@ -326,20 +325,18 @@ func TestGeoIPNetworksWithin(t *testing.T) { prefix, err := netip.ParsePrefix(v.Network) require.NoError(t, err) - n := reader.NetworksWithin(prefix) var innerIPs []string - for n.Next() { + for result := range reader.NetworksWithin(prefix) { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) + innerIPs = append(innerIPs, result.Network().String()) } assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) } diff --git a/verifier.go b/verifier.go index b14b3e4..335cb1b 100644 --- a/verifier.go +++ b/verifier.go @@ -102,16 +102,11 @@ func (v *verifier) verifyDatabase() error { func (v *verifier) verifySearchTree() (map[uint]bool, error) { offsets := make(map[uint]bool) - it := v.reader.Networks() - for it.Next() { - offset, err := v.reader.resolveDataPointer(it.lastNode.pointer) - if err != nil { + for result := range v.reader.Networks() { + if err := result.Err(); err != nil { return nil, err } - offsets[uint(offset)] = true - } - if err := it.Err(); err != nil { - return nil, err + offsets[result.offset] = true } return offsets, nil } From 9541b7117d087dd1f21b12b38181268274e8a7e8 Mon Sep 
17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 14 Jul 2024 13:24:02 -0700 Subject: [PATCH 09/16] Inline traversal functions In order to make it easier to refactor and optimize. --- traverse.go | 205 +++++++++++++++++++++++----------------------------- 1 file changed, 90 insertions(+), 115 deletions(-) diff --git a/traverse.go b/traverse.go index 9054c21..a81ed37 100644 --- a/traverse.go +++ b/traverse.go @@ -14,12 +14,7 @@ type netNode struct { pointer uint } -// networks represents a set of subnets that we are iterating over. -type networks struct { - err error - reader *Reader - nodes []netNode - lastNode netNode +type networkOptions struct { includeAliasedNetworks bool } @@ -29,12 +24,12 @@ var ( ) // NetworksOption are options for Networks and NetworksWithin. -type NetworksOption func(*networks) +type NetworksOption func(*networkOptions) // IncludeAliasedNetworks is an option for Networks and NetworksWithin // that makes them iterate over aliases of the IPv4 subtree in an IPv6 // database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16. -func IncludeAliasedNetworks(networks *networks) { +func IncludeAliasedNetworks(networks *networkOptions) { networks.includeAliasedNetworks = true } @@ -63,131 +58,111 @@ func (r *Reader) Networks(options ...NetworksOption) iter.Seq[Result] { // If the provided prefix is contained within a network in the database, the // iterator will iterate over exactly one network, the containing network. func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) iter.Seq[Result] { - n := r.networksWithin(prefix, options...) 
return func(yield func(Result) bool) { - for n.next() { - if n.err != nil { - yield(Result{err: n.err}) - return - } - - ip := n.lastNode.ip - if isInIPv4Subtree(ip) { - ip = v6ToV4(ip) - } - - offset, err := r.resolveDataPointer(n.lastNode.pointer) - ok := yield(Result{ - decoder: r.decoder, - ip: ip, - offset: uint(offset), - prefixLen: uint8(n.lastNode.bit), - err: err, + if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() { + yield(Result{ + err: fmt.Errorf( + "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", + prefix, + ), }) - if !ok { - return - } - } - if n.err != nil { - yield(Result{err: n.err}) + return } - } -} -func (r *Reader) networksWithin(prefix netip.Prefix, options ...NetworksOption) *networks { - if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() { - return &networks{ - err: fmt.Errorf( - "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", - prefix, - ), + n := &networkOptions{} + for _, option := range options { + option(n) } - } - - networks := &networks{reader: r} - for _, option := range options { - option(networks) - } - ip := prefix.Addr() - netIP := ip - stopBit := prefix.Bits() - if ip.Is4() { - netIP = v4ToV16(ip) - stopBit += 96 - } - - pointer, bit := r.traverseTree(ip, 0, stopBit) + ip := prefix.Addr() + netIP := ip + stopBit := prefix.Bits() + if ip.Is4() { + netIP = v4ToV16(ip) + stopBit += 96 + } - prefix, err := netIP.Prefix(bit) - if err != nil { - networks.err = fmt.Errorf("prefixing %s with %d", netIP, bit) - } + pointer, bit := r.traverseTree(ip, 0, stopBit) - networks.nodes = []netNode{ - { - ip: prefix.Addr(), - bit: uint(bit), - pointer: pointer, - }, - } + prefix, err := netIP.Prefix(bit) + if err != nil { + yield(Result{ + err: fmt.Errorf("prefixing %s with %d", netIP, bit), + }) + } - return networks -} + nodes := []netNode{ + { + ip: prefix.Addr(), + bit: uint(bit), + pointer: pointer, + }, + } -// next prepares the next 
network for reading with the Network method. It -// returns true if there is another network to be processed and false if there -// are no more networks or if there is an error. -func (n *networks) next() bool { - if n.err != nil { - return false - } - for len(n.nodes) > 0 { - node := n.nodes[len(n.nodes)-1] - n.nodes = n.nodes[:len(n.nodes)-1] - - for node.pointer != n.reader.Metadata.NodeCount { - // This skips IPv4 aliases without hardcoding the networks that the writer - // currently aliases. - if !n.includeAliasedNetworks && n.reader.ipv4Start != 0 && - node.pointer == n.reader.ipv4Start && !isInIPv4Subtree(node.ip) { - break - } + for len(nodes) > 0 { + node := nodes[len(nodes)-1] + nodes = nodes[:len(nodes)-1] - if node.pointer > n.reader.Metadata.NodeCount { - n.lastNode = node - return true - } - ipRight := node.ip.As16() - if len(ipRight) <= int(node.bit>>3) { - displayAddr := node.ip - displayBits := node.bit - if isInIPv4Subtree(node.ip) { - displayAddr = v6ToV4(displayAddr) - displayBits -= 96 + for node.pointer != r.Metadata.NodeCount { + // This skips IPv4 aliases without hardcoding the networks that the writer + // currently aliases. 
+ if !n.includeAliasedNetworks && r.ipv4Start != 0 && + node.pointer == r.ipv4Start && !isInIPv4Subtree(node.ip) { + break } - n.err = newInvalidDatabaseError( - "invalid search tree at %s/%d", displayAddr, displayBits) - return false - } - ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) + if node.pointer > r.Metadata.NodeCount { + ip := node.ip + if isInIPv4Subtree(ip) { + ip = v6ToV4(ip) + } + + offset, err := r.resolveDataPointer(node.pointer) + ok := yield(Result{ + decoder: r.decoder, + ip: ip, + offset: uint(offset), + prefixLen: uint8(node.bit), + err: err, + }) + if !ok { + return + } + break + } + ipRight := node.ip.As16() + if len(ipRight) <= int(node.bit>>3) { + displayAddr := node.ip + displayBits := node.bit + if isInIPv4Subtree(node.ip) { + displayAddr = v6ToV4(displayAddr) + displayBits -= 96 + } + + yield(Result{ + ip: displayAddr, + prefixLen: uint8(node.bit), + err: newInvalidDatabaseError( + "invalid search tree at %s/%d", displayAddr, displayBits), + }) + return + } + ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) - offset := node.pointer * n.reader.nodeOffsetMult - rightPointer := n.reader.nodeReader.readRight(offset) + offset := node.pointer * r.nodeOffsetMult + rightPointer := r.nodeReader.readRight(offset) - node.bit++ - n.nodes = append(n.nodes, netNode{ - pointer: rightPointer, - ip: netip.AddrFrom16(ipRight), - bit: node.bit, - }) + node.bit++ + nodes = append(nodes, netNode{ + pointer: rightPointer, + ip: netip.AddrFrom16(ipRight), + bit: node.bit, + }) - node.pointer = n.reader.nodeReader.readLeft(offset) + node.pointer = r.nodeReader.readLeft(offset) + } } } - - return false } var ipv4SubtreeBoundary = netip.MustParseAddr("::255.255.255.255").Next() From 3c14af0def1b8d6fdd37c1a2656ecfed943c9999 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 14 Jul 2024 13:51:56 -0700 Subject: [PATCH 10/16] Add benchmarks for Networks --- traverse_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git 
a/traverse_test.go b/traverse_test.go index df4bcc8..554c2b2 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -341,3 +341,19 @@ func TestGeoIPNetworksWithin(t *testing.T) { require.NoError(t, reader.Close()) } } + +func BenchmarkNetworks(b *testing.B) { + db, err := Open(testFile("GeoIP2-Country-Test.mmdb")) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + for r := range db.Networks() { + var rec struct{} + err = r.Decode(&rec) + if err != nil { + b.Error(err) + } + } + } + require.NoError(b, db.Close(), "error on close") +} From 21b9508ef82c56ff7a648ca834cd114c8070b288 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 14 Jul 2024 13:59:52 -0700 Subject: [PATCH 11/16] Simplify error handling --- traverse.go | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/traverse.go b/traverse.go index a81ed37..26cd9d9 100644 --- a/traverse.go +++ b/traverse.go @@ -87,7 +87,9 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) prefix, err := netIP.Prefix(bit) if err != nil { yield(Result{ - err: fmt.Errorf("prefixing %s with %d", netIP, bit), + ip: ip, + prefixLen: uint8(bit), + err: fmt.Errorf("prefixing %s with %d", netIP, bit), }) } @@ -133,18 +135,19 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) ipRight := node.ip.As16() if len(ipRight) <= int(node.bit>>3) { displayAddr := node.ip - displayBits := node.bit if isInIPv4Subtree(node.ip) { displayAddr = v6ToV4(displayAddr) - displayBits -= 96 } - yield(Result{ + res := Result{ ip: displayAddr, prefixLen: uint8(node.bit), - err: newInvalidDatabaseError( - "invalid search tree at %s/%d", displayAddr, displayBits), - }) + } + res.err = newInvalidDatabaseError( + "invalid search tree at %s", res.Network()) + + yield(res) + return } ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) From 13ed83bef4f89c2a495333505a0c41ca8e4d19ab Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 14 Jul 2024 14:04:04 
-0700 Subject: [PATCH 12/16] Reduce allocations in Networks A slice capacity of 64 is used as most IPv6 trees will have that as their maximum depth. --- traverse.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/traverse.go b/traverse.go index 26cd9d9..4ad60df 100644 --- a/traverse.go +++ b/traverse.go @@ -93,13 +93,14 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) }) } - nodes := []netNode{ - { + nodes := make([]netNode, 0, 64) + nodes = append(nodes, + netNode{ ip: prefix.Addr(), bit: uint(bit), pointer: pointer, }, - } + ) for len(nodes) > 0 { node := nodes[len(nodes)-1] From a584c6104921f798c25ff020e566d420b0eaba73 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Wed, 14 Aug 2024 19:50:24 -0700 Subject: [PATCH 13/16] Fix lints with new golangci-lint The comment on "iter" is because the linter was not stable and kept on shifting it every time it ran. --- decoder.go | 4 ++-- traverse.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/decoder.go b/decoder.go index db941fc..5864873 100644 --- a/decoder.go +++ b/decoder.go @@ -30,8 +30,8 @@ const ( _Slice // We don't use the next two. They are placeholders. See the spec // for more details. - _Container //nolint: deadcode, varcheck // above - _Marker //nolint: deadcode, varcheck // above + _Container //nolint:deadcode,varcheck // above + _Marker //nolint:deadcode,varcheck // above _Bool _Float32 ) diff --git a/traverse.go b/traverse.go index 4ad60df..9e4dc3d 100644 --- a/traverse.go +++ b/traverse.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" + // comment to prevent gofumpt from randomly moving iter.
"iter" ) From c075ec6c8ab1bb4a274dc2ef59bb19016737d4ea Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 18 Aug 2024 16:09:18 -0700 Subject: [PATCH 14/16] Remove (*Reader).Decode And replace it with (*Reader).LookupOffset --- reader.go | 24 +++++------------------- reader_test.go | 17 +++++++++++------ 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/reader.go b/reader.go index e9f76c3..c52dedc 100644 --- a/reader.go +++ b/reader.go @@ -150,28 +150,14 @@ func (r *Reader) Lookup(ip netip.Addr) Result { } } -// Decode the record at |offset| into |result|. The result value pointed to -// must be a data value that corresponds to a record in the database. This may -// include a struct representation of the data, a map capable of holding the -// data or an empty any value. -// -// If result is a pointer to a struct, the struct need not include a field -// for every value that may be in the database. If a field is not present in -// the structure, the decoder will not decode that field, reducing the time -// required to decode the record. -// -// As a special case, a struct field of type uintptr will be used to capture -// the offset of the value. Decode may later be used to extract the stored -// value from the offset. MaxMind DBs are highly normalized: for example in -// the City database, all records of the same country will reference a -// single representative record for that country. This uintptr behavior allows -// clients to leverage this normalization in their own sub-record caching. -func (r *Reader) Decode(offset uintptr, result any) error { +// LookupOffset returns the Result for the specified offset. Note that +// netip.Prefix returned by Networks will be invalid when using LookupOffset. 
+func (r *Reader) LookupOffset(offset uintptr) Result { if r.buffer == nil { - return errors.New("cannot call Decode on a closed database") + return Result{err: errors.New("cannot call Decode on a closed database")} } - return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result) + return Result{decoder: r.decoder, offset: uint(offset)} } var zeroIP = netip.MustParseAddr("::") diff --git a/reader_test.go b/reader_test.go index 3976b84..968bbdd 100644 --- a/reader_test.go +++ b/reader_test.go @@ -311,7 +311,8 @@ func TestDecoder(t *testing.T) { require.NoError(t, result.Err()) require.True(t, result.Found()) - require.NoError(t, reader.Decode(result.RecordOffset(), &testV)) + res := reader.LookupOffset(result.RecordOffset()) + require.NoError(t, res.Decode(&testV)) verify(testV) } @@ -563,21 +564,25 @@ func TestNestedOffsetDecode(t *testing.T) { TimeZoneOffset uintptr `maxminddb:"time_zone"` } `maxminddb:"location"` } - require.NoError(t, db.Decode(result.RecordOffset(), &root)) + res := db.LookupOffset(result.RecordOffset()) + require.NoError(t, res.Decode(&root)) assert.InEpsilon(t, 51.5142, root.Location.Latitude, 1e-10) var longitude float64 - require.NoError(t, db.Decode(root.Location.LongitudeOffset, &longitude)) + res = db.LookupOffset(root.Location.LongitudeOffset) + require.NoError(t, res.Decode(&longitude)) assert.InEpsilon(t, -0.0931, longitude, 1e-10) var timeZone string - require.NoError(t, db.Decode(root.Location.TimeZoneOffset, &timeZone)) + res = db.LookupOffset(root.Location.TimeZoneOffset) + require.NoError(t, res.Decode(&timeZone)) assert.Equal(t, "Europe/London", timeZone) var country struct { IsoCode string `maxminddb:"iso_code"` } - require.NoError(t, db.Decode(root.CountryOffset, &country)) + res = db.LookupOffset(root.CountryOffset) + require.NoError(t, res.Decode(&country)) assert.Equal(t, "GB", country.IsoCode) require.NoError(t, db.Close()) @@ -680,7 +685,7 @@ func TestUsingClosedDatabase(t *testing.T) { err = 
reader.Lookup(addr).Decode(recordInterface) assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) - err = reader.Decode(0, recordInterface) + err = reader.LookupOffset(0).Decode(recordInterface) assert.Equal(t, "cannot call Decode on a closed database", err.Error()) } From 791db996c8497316889c7a8a0e6030b4bf0dba2e Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 18 Aug 2024 16:14:55 -0700 Subject: [PATCH 15/16] Rename RecordOffset to Offset --- reader_test.go | 4 ++-- result.go | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/reader_test.go b/reader_test.go index 968bbdd..4b4bb53 100644 --- a/reader_test.go +++ b/reader_test.go @@ -311,7 +311,7 @@ func TestDecoder(t *testing.T) { require.NoError(t, result.Err()) require.True(t, result.Found()) - res := reader.LookupOffset(result.RecordOffset()) + res := reader.LookupOffset(result.Offset()) require.NoError(t, res.Decode(&testV)) verify(testV) } @@ -564,7 +564,7 @@ func TestNestedOffsetDecode(t *testing.T) { TimeZoneOffset uintptr `maxminddb:"time_zone"` } `maxminddb:"location"` } - res := db.LookupOffset(result.RecordOffset()) + res := db.LookupOffset(result.Offset()) require.NoError(t, res.Decode(&root)) assert.InEpsilon(t, 51.5142, root.Location.Latitude, 1e-10) diff --git a/result.go b/result.go index 0d142d2..c06540a 100644 --- a/result.go +++ b/result.go @@ -107,14 +107,14 @@ func (r Result) Found() bool { return r.err == nil && r.offset != notFound } -// RecordOffset returns the offset of the record in the database. This can be -// passed to ReaderDecode. It can also be used as a unique identifier for the -// data record in the particular database to cache the data record across -// lookups. Note that while the offset uniquely identifies the data record, -// other data in Result may differ between lookups. The offset is only valid -// for the current database version. 
If you update the database file, you must -// invalidate any cache associated with the previous version. -func (r Result) RecordOffset() uintptr { +// Offset returns the offset of the record in the database. This can be +// passed to (*Reader).LookupOffset. It can also be used as a unique +// identifier for the data record in the particular database to cache the data +// record across lookups. Note that while the offset uniquely identifies the +// data record, other data in Result may differ between lookups. The offset +// is only valid for the current database version. If you update the database +// file, you must invalidate any cache associated with the previous version. +func (r Result) Offset() uintptr { return uintptr(r.offset) } From cb27d1e7784af1c313a8c83ccce633da492cbbc3 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sun, 18 Aug 2024 16:16:50 -0700 Subject: [PATCH 16/16] Rename Network to Prefix To more closely match net/netip. Also, it may reduce confusion with Networks and NetworksWithin, which refer to more than just the network. 
--- example_test.go | 4 ++-- reader_test.go | 4 ++-- result.go | 4 ++-- traverse.go | 2 +- traverse_test.go | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/example_test.go b/example_test.go index 1a19df4..0f15c1e 100644 --- a/example_test.go +++ b/example_test.go @@ -72,7 +72,7 @@ func ExampleReader_Networks() { if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", result.Network(), record.Domain) + fmt.Printf("%s: %s\n", result.Prefix(), record.Domain) } // Output: // 1.0.0.0/24: Cable/DSL @@ -123,7 +123,7 @@ func ExampleReader_NetworksWithin() { if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", result.Network(), record.Domain) + fmt.Printf("%s: %s\n", result.Prefix(), record.Domain) } // Output: diff --git a/reader_test.go b/reader_test.go index 4b4bb53..1908635 100644 --- a/reader_test.go +++ b/reader_test.go @@ -201,7 +201,7 @@ func TestLookupNetwork(t *testing.T) { result := reader.Lookup(test.IP) require.NoError(t, result.Err()) assert.Equal(t, test.ExpectedFound, result.Found()) - assert.Equal(t, test.ExpectedNetwork, result.Network().String()) + assert.Equal(t, test.ExpectedNetwork, result.Prefix().String()) require.NoError(t, result.Decode(&record)) assert.Equal(t, test.ExpectedRecord, record) @@ -841,7 +841,7 @@ func BenchmarkLookupNetwork(b *testing.B) { if err := res.Err(); err != nil { b.Error(err) } - if !res.Network().IsValid() { + if !res.Prefix().IsValid() { b.Fatalf("invalid network for %s", ip) } } diff --git a/result.go b/result.go index c06540a..e50bb80 100644 --- a/result.go +++ b/result.go @@ -118,9 +118,9 @@ func (r Result) Offset() uintptr { return uintptr(r.offset) } -// Network returns the netip.Prefix representing the network associated with +// Prefix returns the netip.Prefix representing the network associated with // the data record in the database. 
-func (r Result) Network() netip.Prefix { +func (r Result) Prefix() netip.Prefix { ip := r.ip prefixLen := int(r.prefixLen) diff --git a/traverse.go b/traverse.go index 9e4dc3d..6a748db 100644 --- a/traverse.go +++ b/traverse.go @@ -146,7 +146,7 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) prefixLen: uint8(node.bit), } res.err = newInvalidDatabaseError( - "invalid search tree at %s", res.Network()) + "invalid search tree at %s", res.Prefix()) yield(res) diff --git a/traverse_test.go b/traverse_test.go index 554c2b2..963d710 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -27,7 +27,7 @@ func TestNetworks(t *testing.T) { err := result.Decode(&record) require.NoError(t, err) - network := result.Network() + network := result.Prefix() assert.Equal(t, record.IP, network.Addr().String(), "expected %s got %s", record.IP, network.Addr().String(), ) @@ -294,7 +294,7 @@ func TestNetworksWithin(t *testing.T) { }{} err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, result.Network().String()) + innerIPs = append(innerIPs, result.Prefix().String()) } assert.Equal(t, v.Expected, innerIPs) @@ -333,7 +333,7 @@ func TestGeoIPNetworksWithin(t *testing.T) { }{} err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, result.Network().String()) + innerIPs = append(innerIPs, result.Prefix().String()) } assert.Equal(t, v.Expected, innerIPs)