diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0481be5..b45f167 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,7 +8,7 @@ jobs: name: Build strategy: matrix: - go-version: [1.21.x, 1.22.x] + go-version: [1.23.0-rc.1] platform: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.platform }} steps: diff --git a/.golangci.toml b/.golangci.toml index 799416c..0abb5aa 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -18,6 +18,7 @@ disable = [ "forcetypeassert", "funlen", "gochecknoglobals", + "gocognit", "godox", "gomnd", "inamedparam", diff --git a/decoder.go b/decoder.go index 435591e..5864873 100644 --- a/decoder.go +++ b/decoder.go @@ -30,8 +30,8 @@ const ( _Slice // We don't use the next two. They are placeholders. See the spec // for more details. - _Container //nolint: deadcode, varcheck // above - _Marker //nolint: deadcode, varcheck // above + _Container //nolint:deadcode,varcheck // above + _Marker //nolint:deadcode,varcheck // above _Bool _Float32 ) @@ -89,6 +89,82 @@ func (d *decoder) decodeToDeserializer( return d.decodeFromTypeToDeserializer(typeNum, size, newOffset, dser, depth+1) } +func (d *decoder) decodePath( + offset uint, + path []any, + result reflect.Value, +) error { +PATH: + for i, v := range path { + var ( + typeNum dataType + size uint + err error + ) + typeNum, size, offset, err = d.decodeCtrlData(offset) + if err != nil { + return err + } + + if typeNum == _Pointer { + pointer, _, err := d.decodePointer(size, offset) + if err != nil { + return err + } + + typeNum, size, offset, err = d.decodeCtrlData(pointer) + if err != nil { + return err + } + } + + switch v := v.(type) { + case string: + // We are expecting a map + if typeNum != _Map { + // XXX - use type names in errors. + return fmt.Errorf("expected a map for %s but found %d", v, typeNum) + } + for i := uint(0); i < size; i++ { + var key []byte + key, offset, err = d.decodeKey(offset) + if err != nil { + return err + } + if string(key) == v { + continue PATH + } + offset, err = d.nextValueOffset(offset, 1) + if err != nil { + return err + } + } + // Not found. Maybe return a boolean? + return nil + case int: + // We are expecting an array + if typeNum != _Slice { + // XXX - use type names in errors. + return fmt.Errorf("expected a slice for %d but found %d", v, typeNum) + } + if size < uint(v) { + // Slice is smaller than index, not found + return nil + } + // TODO: support negative indexes? Seems useful for subdivisions in + // particular. + offset, err = d.nextValueOffset(offset, uint(v)) + if err != nil { + return err + } + default: + return fmt.Errorf("unexpected type for %d value in path, %v: %T", i, v, v) + } + } + _, err := d.decode(offset, result, len(path)) + return err +} + func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) { newOffset := offset + 1 if offset >= uint(len(d.buffer)) { diff --git a/decoder_test.go b/decoder_test.go index 68d8903..c03e380 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -207,7 +207,7 @@ func validateDecoding(t *testing.T, tests map[string]any) { for inputStr, expected := range tests { inputBytes, err := hex.DecodeString(inputStr) require.NoError(t, err) - d := decoder{inputBytes} + d := decoder{buffer: inputBytes} var result any _, err = d.decode(0, reflect.ValueOf(&result), 0) @@ -223,7 +223,7 @@ func validateDecoding(t *testing.T, tests map[string]any) { func TestPointers(t *testing.T) { bytes, err := os.ReadFile(testFile("maps-with-pointers.raw")) require.NoError(t, err) - d := decoder{bytes} + d := decoder{buffer: bytes} expected := map[uint]map[string]string{ 0: {"long_key": "long_value1"}, diff --git a/deserializer_test.go b/deserializer_test.go index c68a9d5..a6e3b70 100644 --- a/deserializer_test.go +++ b/deserializer_test.go @@ -2,7 +2,7 @@ package maxminddb import ( "math/big" - "net" + "net/netip" "testing" "github.com/stretchr/testify/require" @@ -13,7 +13,7 @@ func TestDecodingToDeserializer(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) dser := testDeserializer{} - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &dser) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&dser) require.NoError(t, err, "unexpected error while doing lookup: %v", err) checkDecodingToInterface(t, dser.rv) diff --git a/example_test.go b/example_test.go index 52e878e..0f15c1e 100644 --- a/example_test.go +++ b/example_test.go @@ -3,9 +3,9 @@ package maxminddb_test import ( "fmt" "log" - "net" + "net/netip" - "github.com/oschwald/maxminddb-golang" + "github.com/oschwald/maxminddb-golang/v2" ) // This example shows how to decode to a struct. @@ -16,7 +16,7 @@ func ExampleReader_Lookup_struct() { } defer db.Close() - ip := net.ParseIP("81.2.69.142") + addr := netip.MustParseAddr("81.2.69.142") var record struct { Country struct { @@ -24,7 +24,7 @@ func ExampleReader_Lookup_struct() { } `maxminddb:"country"` } // Or any appropriate struct - err = db.Lookup(ip, &record) + err = db.Lookup(addr).Decode(&record) if err != nil { log.Panic(err) } @@ -41,10 +41,10 @@ func ExampleReader_Lookup_interface() { } defer db.Close() - ip := net.ParseIP("81.2.69.142") + addr := netip.MustParseAddr("81.2.69.142") var record any - err = db.Lookup(ip, &record) + err = db.Lookup(addr).Decode(&record) if err != nil { log.Panic(err) } @@ -63,20 +63,16 @@ func ExampleReader_Networks() { } defer db.Close() - networks := db.Networks(maxminddb.SkipAliasedNetworks) - for networks.Next() { + for result := range db.Networks() { record := struct { Domain string `maxminddb:"connection_type"` }{} - subnet, err := networks.Network(&record) + err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", subnet.String(), record.Domain) - } - if networks.Err() != nil { - log.Panic(networks.Err()) + fmt.Printf("%s: %s\n", result.Prefix(), record.Domain) } // Output: // 1.0.0.0/24: Cable/DSL @@ -114,25 +110,20 @@ func ExampleReader_NetworksWithin() { } defer db.Close() - _, network, err := net.ParseCIDR("1.0.0.0/8") + prefix, err := netip.ParsePrefix("1.0.0.0/8") if err != nil { log.Panic(err) } - networks := db.NetworksWithin(network, maxminddb.SkipAliasedNetworks) - for networks.Next() { + for result := range db.NetworksWithin(prefix) { record := struct { Domain string `maxminddb:"connection_type"` }{} - - subnet, err := networks.Network(&record) + err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", subnet.String(), record.Domain) - } - if networks.Err() != nil { - log.Panic(networks.Err()) + fmt.Printf("%s: %s\n", result.Prefix(), record.Domain) } // Output: diff --git a/go.mod b/go.mod index d56901d..fc7d1c9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ -module github.com/oschwald/maxminddb-golang +module github.com/oschwald/maxminddb-golang/v2 -go 1.21 +go 1.23 require ( github.com/stretchr/testify v1.9.0 diff --git a/reader.go b/reader.go index 2dc712f..c52dedc 100644 --- a/reader.go +++ b/reader.go @@ -5,17 +5,11 @@ import ( "bytes" "errors" "fmt" - "net" + "net/netip" "reflect" ) -const ( - // NotFound is returned by LookupOffset when a matched root record offset - // cannot be found. - NotFound = ^uintptr(0) - - dataSectionSeparatorSize = 16 -) +const dataSectionSeparatorSize = 16 var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com") @@ -61,7 +55,7 @@ func FromBytes(buffer []byte) (*Reader, error) { } metadataStart += len(metadataStartMarker) - metadataDecoder := decoder{buffer[metadataStart:]} + metadataDecoder := decoder{buffer: buffer[metadataStart:]} var metadata Metadata @@ -78,7 +72,7 @@ func FromBytes(buffer []byte) (*Reader, error) { return nil, newInvalidDatabaseError("the MaxMind DB contains invalid metadata") } d := decoder{ - buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)], + buffer: buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)], } nodeBuffer := buffer[:searchTreeSize] @@ -110,6 +104,7 @@ func FromBytes(buffer []byte) (*Reader, error) { func (r *Reader) setIPv4Start() { if r.Metadata.IPVersion != 6 { + r.ipv4StartBitDepth = 96 return } @@ -124,162 +119,82 @@ func (r *Reader) setIPv4Start() { r.ipv4StartBitDepth = i } -// Lookup retrieves the database record for ip and stores it in the value -// pointed to by result. If result is nil or not a pointer, an error is -// returned. If the data in the database record cannot be stored in result -// because of type differences, an UnmarshalTypeError is returned. If the -// database is invalid or otherwise cannot be read, an InvalidDatabaseError -// is returned. -func (r *Reader) Lookup(ip net.IP, result any) error { +// Lookup retrieves the database record for ip and returns Result, which can +// be used to decode the data.. +func (r *Reader) Lookup(ip netip.Addr) Result { if r.buffer == nil { - return errors.New("cannot call Lookup on a closed database") - } - pointer, _, _, err := r.lookupPointer(ip) - if pointer == 0 || err != nil { - return err + return Result{err: errors.New("cannot call Lookup on a closed database")} } - return r.retrieveData(pointer, result) -} - -// LookupNetwork retrieves the database record for ip and stores it in the -// value pointed to by result. The network returned is the network associated -// with the data record in the database. The ok return value indicates whether -// the database contained a record for the ip. -// -// If result is nil or not a pointer, an error is returned. If the data in the -// database record cannot be stored in result because of type differences, an -// UnmarshalTypeError is returned. If the database is invalid or otherwise -// cannot be read, an InvalidDatabaseError is returned. -func (r *Reader) LookupNetwork( - ip net.IP, - result any, -) (network *net.IPNet, ok bool, err error) { - if r.buffer == nil { - return nil, false, errors.New("cannot call Lookup on a closed database") - } - pointer, prefixLength, ip, err := r.lookupPointer(ip) - - network = r.cidr(ip, prefixLength) - if pointer == 0 || err != nil { - return network, false, err - } - - return network, true, r.retrieveData(pointer, result) -} - -// LookupOffset maps an argument net.IP to a corresponding record offset in the -// database. NotFound is returned if no such record is found, and a record may -// otherwise be extracted by passing the returned offset to Decode. LookupOffset -// is an advanced API, which exists to provide clients with a means to cache -// previously-decoded records. -func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) { - if r.buffer == nil { - return 0, errors.New("cannot call LookupOffset on a closed database") + pointer, prefixLen, err := r.lookupPointer(ip) + if err != nil { + return Result{ + ip: ip, + prefixLen: uint8(prefixLen), + err: err, + } } - pointer, _, _, err := r.lookupPointer(ip) - if pointer == 0 || err != nil { - return NotFound, err + if pointer == 0 { + return Result{ + ip: ip, + prefixLen: uint8(prefixLen), + offset: notFound, + } } - return r.resolveDataPointer(pointer) -} - -func (r *Reader) cidr(ip net.IP, prefixLength int) *net.IPNet { - // This is necessary as the node that the IPv4 start is at may - // be at a bit depth that is less that 96, i.e., ipv4Start points - // to a leaf node. For instance, if a record was inserted at ::/8, - // the ipv4Start would point directly at the leaf node for the - // record and would have a bit depth of 8. This would not happen - // with databases currently distributed by MaxMind as all of them - // have an IPv4 subtree that is greater than a single node. - if r.Metadata.IPVersion == 6 && - len(ip) == net.IPv4len && - r.ipv4StartBitDepth != 96 { - return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(r.ipv4StartBitDepth, 128)} + offset, err := r.resolveDataPointer(pointer) + return Result{ + decoder: r.decoder, + ip: ip, + offset: uint(offset), + prefixLen: uint8(prefixLen), + err: err, } - - mask := net.CIDRMask(prefixLength, len(ip)*8) - return &net.IPNet{IP: ip.Mask(mask), Mask: mask} } -// Decode the record at |offset| into |result|. The result value pointed to -// must be a data value that corresponds to a record in the database. This may -// include a struct representation of the data, a map capable of holding the -// data or an empty any value. -// -// If result is a pointer to a struct, the struct need not include a field -// for every value that may be in the database. If a field is not present in -// the structure, the decoder will not decode that field, reducing the time -// required to decode the record. -// -// As a special case, a struct field of type uintptr will be used to capture -// the offset of the value. Decode may later be used to extract the stored -// value from the offset. MaxMind DBs are highly normalized: for example in -// the City database, all records of the same country will reference a -// single representative record for that country. This uintptr behavior allows -// clients to leverage this normalization in their own sub-record caching. -func (r *Reader) Decode(offset uintptr, result any) error { +// LookupOffset returns the Result for the specified offset. Note that +// netip.Prefix returned by Networks will be invalid when using LookupOffset. +func (r *Reader) LookupOffset(offset uintptr) Result { if r.buffer == nil { - return errors.New("cannot call Decode on a closed database") - } - return r.decode(offset, result) -} - -func (r *Reader) decode(offset uintptr, result any) error { - rv := reflect.ValueOf(result) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return errors.New("result param must be a pointer") + return Result{err: errors.New("cannot call Decode on a closed database")} } - if dser, ok := result.(deserializer); ok { - _, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0, false) - return err - } - - _, err := r.decoder.decode(uint(offset), rv, 0) - return err + return Result{decoder: r.decoder, offset: uint(offset)} } -func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) { - if ip == nil { - return 0, 0, nil, errors.New("IP passed to Lookup cannot be nil") - } +var zeroIP = netip.MustParseAddr("::") - ipV4Address := ip.To4() - if ipV4Address != nil { - ip = ipV4Address - } - if len(ip) == 16 && r.Metadata.IPVersion == 4 { - return 0, 0, ip, fmt.Errorf( +func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) { + if r.Metadata.IPVersion == 4 && ip.Is6() { + return 0, 0, fmt.Errorf( "error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database", ip.String(), ) } - bitCount := uint(len(ip) * 8) - - var node uint - if bitCount == 32 { - node = r.ipv4Start - } - node, prefixLength := r.traverseTree(ip, node, bitCount) + node, prefixLength := r.traverseTree(ip, 0, 128) nodeCount := r.Metadata.NodeCount if node == nodeCount { // Record is empty - return 0, prefixLength, ip, nil + return 0, prefixLength, nil } else if node > nodeCount { - return node, prefixLength, ip, nil + return node, prefixLength, nil } - return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree") + return 0, prefixLength, newInvalidDatabaseError("invalid node in search tree") } -func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) { +func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) { + i := 0 + if ip.Is4() { + i = r.ipv4StartBitDepth + node = r.ipv4Start + } nodeCount := r.Metadata.NodeCount - i := uint(0) - for ; i < bitCount && node < nodeCount; i++ { - bit := uint(1) & (uint(ip[i>>3]) >> (7 - (i % 8))) + ip16 := ip.As16() + + for ; i < stopBit && node < nodeCount; i++ { + bit := uint(1) & (uint(ip16[i>>3]) >> (7 - (i % 8))) offset := node * r.nodeOffsetMult if bit == 0 { @@ -289,15 +204,7 @@ func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) { } } - return node, int(i) -} - -func (r *Reader) retrieveData(pointer uint, result any) error { - offset, err := r.resolveDataPointer(pointer) - if err != nil { - return err - } - return r.decode(offset, result) + return node, i } func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) { diff --git a/reader_test.go b/reader_test.go index c9d287d..1908635 100644 --- a/reader_test.go +++ b/reader_test.go @@ -6,6 +6,7 @@ import ( "math/big" "math/rand" "net" + "net/netip" "os" "path/filepath" "testing" @@ -19,19 +20,21 @@ func TestReader(t *testing.T) { for _, recordSize := range []uint{24, 28, 32} { for _, ipVersion := range []uint{4, 6} { fileName := fmt.Sprintf( - testFile("MaxMind-DB-test-ipv%d-%d.mmdb"), + "MaxMind-DB-test-ipv%d-%d.mmdb", ipVersion, recordSize, ) - reader, err := Open(fileName) - require.NoError(t, err, "unexpected error while opening database: %v", err) - checkMetadata(t, reader, ipVersion, recordSize) - - if ipVersion == 4 { - checkIpv4(t, reader) - } else { - checkIpv6(t, reader) - } + t.Run(fileName, func(t *testing.T) { + reader, err := Open(testFile(fileName)) + require.NoError(t, err, "unexpected error while opening database: %v", err) + checkMetadata(t, reader, ipVersion, recordSize) + + if ipVersion == 4 { + checkIpv4(t, reader) + } else { + checkIpv6(t, reader) + } + }) } } } @@ -97,95 +100,95 @@ func TestLookupNetwork(t *testing.T) { } tests := []struct { - IP net.IP - DBFile string - ExpectedCIDR string - ExpectedRecord any - ExpectedOK bool + IP netip.Addr + DBFile string + ExpectedNetwork string + ExpectedRecord any + ExpectedFound bool }{ { - IP: net.ParseIP("1.1.1.1"), - DBFile: "MaxMind-DB-test-ipv6-32.mmdb", - ExpectedCIDR: "1.0.0.0/8", - ExpectedRecord: nil, - ExpectedOK: false, + IP: netip.MustParseAddr("1.1.1.1"), + DBFile: "MaxMind-DB-test-ipv6-32.mmdb", + ExpectedNetwork: "1.0.0.0/8", + ExpectedRecord: nil, + ExpectedFound: false, }, { - IP: net.ParseIP("::1:ffff:ffff"), - DBFile: "MaxMind-DB-test-ipv6-24.mmdb", - ExpectedCIDR: "::1:ffff:ffff/128", - ExpectedRecord: map[string]any{"ip": "::1:ffff:ffff"}, - ExpectedOK: true, + IP: netip.MustParseAddr("::1:ffff:ffff"), + DBFile: "MaxMind-DB-test-ipv6-24.mmdb", + ExpectedNetwork: "::1:ffff:ffff/128", + ExpectedRecord: map[string]any{"ip": "::1:ffff:ffff"}, + ExpectedFound: true, }, { - IP: net.ParseIP("::2:0:1"), - DBFile: "MaxMind-DB-test-ipv6-24.mmdb", - ExpectedCIDR: "::2:0:0/122", - ExpectedRecord: map[string]any{"ip": "::2:0:0"}, - ExpectedOK: true, + IP: netip.MustParseAddr("::2:0:1"), + DBFile: "MaxMind-DB-test-ipv6-24.mmdb", + ExpectedNetwork: "::2:0:0/122", + ExpectedRecord: map[string]any{"ip": "::2:0:0"}, + ExpectedFound: true, }, { - IP: net.ParseIP("1.1.1.1"), - DBFile: "MaxMind-DB-test-ipv4-24.mmdb", - ExpectedCIDR: "1.1.1.1/32", - ExpectedRecord: map[string]any{"ip": "1.1.1.1"}, - ExpectedOK: true, + IP: netip.MustParseAddr("1.1.1.1"), + DBFile: "MaxMind-DB-test-ipv4-24.mmdb", + ExpectedNetwork: "1.1.1.1/32", + ExpectedRecord: map[string]any{"ip": "1.1.1.1"}, + ExpectedFound: true, }, { - IP: net.ParseIP("1.1.1.3"), - DBFile: "MaxMind-DB-test-ipv4-24.mmdb", - ExpectedCIDR: "1.1.1.2/31", - ExpectedRecord: map[string]any{"ip": "1.1.1.2"}, - ExpectedOK: true, + IP: netip.MustParseAddr("1.1.1.3"), + DBFile: "MaxMind-DB-test-ipv4-24.mmdb", + ExpectedNetwork: "1.1.1.2/31", + ExpectedRecord: map[string]any{"ip": "1.1.1.2"}, + ExpectedFound: true, }, { - IP: net.ParseIP("1.1.1.3"), - DBFile: "MaxMind-DB-test-decoder.mmdb", - ExpectedCIDR: "1.1.1.0/24", - ExpectedRecord: decoderRecord, - ExpectedOK: true, + IP: netip.MustParseAddr("1.1.1.3"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedNetwork: "1.1.1.0/24", + ExpectedRecord: decoderRecord, + ExpectedFound: true, }, { - IP: net.ParseIP("::ffff:1.1.1.128"), - DBFile: "MaxMind-DB-test-decoder.mmdb", - ExpectedCIDR: "1.1.1.0/24", - ExpectedRecord: decoderRecord, - ExpectedOK: true, + IP: netip.MustParseAddr("::ffff:1.1.1.128"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedNetwork: "::ffff:1.1.1.0/120", + ExpectedRecord: decoderRecord, + ExpectedFound: true, }, { - IP: net.ParseIP("::1.1.1.128"), - DBFile: "MaxMind-DB-test-decoder.mmdb", - ExpectedCIDR: "::101:100/120", - ExpectedRecord: decoderRecord, - ExpectedOK: true, + IP: netip.MustParseAddr("::1.1.1.128"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedNetwork: "::101:100/120", + ExpectedRecord: decoderRecord, + ExpectedFound: true, }, { - IP: net.ParseIP("200.0.2.1"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "::/64", - ExpectedRecord: "::0/64", - ExpectedOK: true, + IP: netip.MustParseAddr("200.0.2.1"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "::/64", + ExpectedRecord: "::0/64", + ExpectedFound: true, }, { - IP: net.ParseIP("::200.0.2.1"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "::/64", - ExpectedRecord: "::0/64", - ExpectedOK: true, + IP: netip.MustParseAddr("::200.0.2.1"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "::/64", + ExpectedRecord: "::0/64", + ExpectedFound: true, }, { - IP: net.ParseIP("0:0:0:0:ffff:ffff:ffff:ffff"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "::/64", - ExpectedRecord: "::0/64", - ExpectedOK: true, + IP: netip.MustParseAddr("0:0:0:0:ffff:ffff:ffff:ffff"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "::/64", + ExpectedRecord: "::0/64", + ExpectedFound: true, }, { - IP: net.ParseIP("ef00::"), - DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", - ExpectedCIDR: "8000::/1", - ExpectedRecord: nil, - ExpectedOK: false, + IP: netip.MustParseAddr("ef00::"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedNetwork: "8000::/1", + ExpectedRecord: nil, + ExpectedFound: false, }, } @@ -195,10 +198,12 @@ func TestLookupNetwork(t *testing.T) { reader, err := Open(testFile(test.DBFile)) require.NoError(t, err) - network, ok, err := reader.LookupNetwork(test.IP, &record) - require.NoError(t, err) - assert.Equal(t, test.ExpectedOK, ok) - assert.Equal(t, test.ExpectedCIDR, network.String()) + result := reader.Lookup(test.IP) + require.NoError(t, result.Err()) + assert.Equal(t, test.ExpectedFound, result.Found()) + assert.Equal(t, test.ExpectedNetwork, result.Prefix().String()) + + require.NoError(t, result.Decode(&record)) assert.Equal(t, test.ExpectedRecord, record) }) } @@ -209,7 +214,7 @@ func TestDecodingToInterface(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var recordInterface any - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &recordInterface) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&recordInterface) require.NoError(t, err, "unexpected error while doing lookup: %v", err) checkDecodingToInterface(t, recordInterface) @@ -295,24 +300,49 @@ func TestDecoder(t *testing.T) { { // Directly lookup and decode. - var result TestType - require.NoError(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) - verify(result) + var testV TestType + require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&testV)) + verify(testV) } { // Lookup record offset, then Decode. - var result TestType - offset, err := reader.LookupOffset(net.ParseIP("::1.1.1.0")) - require.NoError(t, err) - assert.NotEqual(t, NotFound, offset) - - require.NoError(t, reader.Decode(offset, &result)) - verify(result) + var testV TestType + result := reader.Lookup(netip.MustParseAddr("::1.1.1.0")) + require.NoError(t, result.Err()) + require.True(t, result.Found()) + + res := reader.LookupOffset(result.Offset()) + require.NoError(t, res.Decode(&testV)) + verify(testV) } require.NoError(t, reader.Close()) } +func TestDecodePath(t *testing.T) { + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) + + result := reader.Lookup(netip.MustParseAddr("::1.1.1.0")) + require.NoError(t, result.Err()) + + var u16 uint16 + + require.NoError(t, result.DecodePath(&u16, "uint16")) + + assert.Equal(t, uint16(100), u16) + + var u uint + require.NoError(t, result.DecodePath(&u, "array", 0)) + assert.Equal(t, uint(1), u) + + require.NoError(t, result.DecodePath(&u, "array", 2)) + assert.Equal(t, uint(3), u) + + require.NoError(t, result.DecodePath(&u, "map", "mapX", "arrayX", 1)) + assert.Equal(t, uint(8), u) +} + type TestInterface interface { method() bool } @@ -327,7 +357,7 @@ func TestStructInterface(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) require.NoError(t, err) - require.NoError(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)) assert.True(t, result.method()) } @@ -338,7 +368,7 @@ func TestNonEmptyNilInterface(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) require.NoError(t, err) - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result) assert.Equal( t, "maxminddb: cannot unmarshal map into type maxminddb.TestInterface", @@ -361,7 +391,7 @@ func TestEmbeddedStructAsInterface(t *testing.T) { db, err := Open(testFile("GeoIP2-ISP-Test.mmdb")) require.NoError(t, err) - require.NoError(t, db.Lookup(net.ParseIP("1.128.0.0"), &result)) + require.NoError(t, db.Lookup(netip.MustParseAddr("1.128.0.0")).Decode(&result)) } type BoolInterface interface { @@ -387,7 +417,7 @@ func TestValueTypeInterface(t *testing.T) { // although it would be nice to support cases like this, I am not sure it // is possible to do so in a general way. - assert.Error(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + assert.Error(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)) } type NestedMapX struct { @@ -429,7 +459,7 @@ func TestComplexStructWithNestingAndPointer(t *testing.T) { var result TestPointerType - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result) require.NoError(t, err) assert.Equal(t, []uint{uint(1), uint(2), uint(3)}, *result.Array) @@ -461,7 +491,7 @@ func TestNestedMapDecode(t *testing.T) { var r map[string]map[string]any - require.NoError(t, db.Lookup(net.ParseIP("89.160.20.128"), &r)) + require.NoError(t, db.Lookup(netip.MustParseAddr("89.160.20.128")).Decode(&r)) assert.Equal( t, @@ -519,9 +549,9 @@ func TestNestedOffsetDecode(t *testing.T) { db, err := Open(testFile("GeoIP2-City-Test.mmdb")) require.NoError(t, err) - off, err := db.LookupOffset(net.ParseIP("81.2.69.142")) - assert.NotEqual(t, NotFound, off) - require.NoError(t, err) + result := db.Lookup(netip.MustParseAddr("81.2.69.142")) + require.NoError(t, result.Err()) + require.True(t, result.Found()) var root struct { CountryOffset uintptr `maxminddb:"country"` @@ -534,21 +564,25 @@ func TestNestedOffsetDecode(t *testing.T) { TimeZoneOffset uintptr `maxminddb:"time_zone"` } `maxminddb:"location"` } - require.NoError(t, db.Decode(off, &root)) + res := db.LookupOffset(result.Offset()) + require.NoError(t, res.Decode(&root)) assert.InEpsilon(t, 51.5142, root.Location.Latitude, 1e-10) var longitude float64 - require.NoError(t, db.Decode(root.Location.LongitudeOffset, &longitude)) + res = db.LookupOffset(root.Location.LongitudeOffset) + require.NoError(t, res.Decode(&longitude)) assert.InEpsilon(t, -0.0931, longitude, 1e-10) var timeZone string - require.NoError(t, db.Decode(root.Location.TimeZoneOffset, &timeZone)) + res = db.LookupOffset(root.Location.TimeZoneOffset) + require.NoError(t, res.Decode(&timeZone)) assert.Equal(t, "Europe/London", timeZone) var country struct { IsoCode string `maxminddb:"iso_code"` } - require.NoError(t, db.Decode(root.CountryOffset, &country)) + res = db.LookupOffset(root.CountryOffset) + require.NoError(t, res.Decode(&country)) assert.Equal(t, "GB", country.IsoCode) require.NoError(t, db.Close()) @@ -561,7 +595,7 @@ func TestDecodingUint16IntoInt(t *testing.T) { var result struct { Uint16 int `maxminddb:"uint16"` } - err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result) require.NoError(t, err) assert.Equal(t, 100, result.Uint16) @@ -572,7 +606,7 @@ func TestIpv6inIpv4(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var result TestType - err = reader.Lookup(net.ParseIP("2001::"), &result) + err = reader.Lookup(netip.MustParseAddr("2001::")).Decode(&result) var emptyResult TestType assert.Equal(t, emptyResult, result) @@ -589,7 +623,7 @@ func TestBrokenDoubleDatabase(t *testing.T) { require.NoError(t, err, "unexpected error while opening database: %v", err) var result any - err = reader.Lookup(net.ParseIP("2001:220::"), &result) + err = reader.Lookup(netip.MustParseAddr("2001:220::")).Decode(&result) expected := newInvalidDatabaseError( "the MaxMind DB file's data section contains bad data (float 64 size of 2)", @@ -622,35 +656,36 @@ func TestDecodingToNonPointer(t *testing.T) { require.NoError(t, err) var recordInterface any - err = reader.Lookup(net.ParseIP("::1.1.1.0"), recordInterface) + err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(recordInterface) assert.Equal(t, "result param must be a pointer", err.Error()) require.NoError(t, reader.Close(), "error on close") } -func TestNilLookup(t *testing.T) { - reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) - require.NoError(t, err) +// func TestNilLookup(t *testing.T) { +// reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) +// require.NoError(t, err) - var recordInterface any - err = reader.Lookup(nil, recordInterface) - assert.Equal(t, "IP passed to Lookup cannot be nil", err.Error()) - require.NoError(t, reader.Close(), "error on close") -} +// var recordInterface any +// err = reader.Lookup(nil).Decode( recordInterface) +// assert.Equal(t, "IP passed to Lookup cannot be nil", err.Error()) +// require.NoError(t, reader.Close(), "error on close") +// } func TestUsingClosedDatabase(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) require.NoError(t, err) require.NoError(t, reader.Close()) - var recordInterface any + addr := netip.MustParseAddr("::") - err = reader.Lookup(nil, recordInterface) - assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) + result := reader.Lookup(addr) + assert.Equal(t, "cannot call Lookup on a closed database", result.Err().Error()) - _, err = reader.LookupOffset(nil) - assert.Equal(t, "cannot call LookupOffset on a closed database", err.Error()) + var recordInterface any + err = reader.Lookup(addr).Decode(recordInterface) + assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) - err = reader.Decode(0, recordInterface) + err = reader.LookupOffset(0).Decode(recordInterface) assert.Equal(t, "cannot call Decode on a closed database", err.Error()) } @@ -682,10 +717,10 @@ func checkMetadata(t *testing.T, reader *Reader, ipVersion, recordSize uint) { func checkIpv4(t *testing.T, reader *Reader) { for i := uint(0); i < 6; i++ { address := fmt.Sprintf("1.1.1.%d", uint(1)<> 24) ip[1] = byte(num >> 16) ip[2] = byte(num >> 8) ip[3] = byte(num) + v, _ := netip.AddrFromSlice(ip) + return v } func testFile(file string) string { diff --git a/result.go b/result.go new file mode 100644 index 0000000..e50bb80 --- /dev/null +++ b/result.go @@ -0,0 +1,143 @@ +package maxminddb + +import ( + "errors" + "math" + "net/netip" + "reflect" +) + +const notFound uint = math.MaxUint + +type Result struct { + ip netip.Addr + err error + decoder decoder + offset uint + prefixLen uint8 +} + +// Decode unmarshals the data from the data section into the value pointed to +// by v. If v is nil or not a pointer, an error is returned. If the data in +// the database record cannot be stored in v because of type differences, an +// UnmarshalTypeError is returned. If the database is invalid or otherwise +// cannot be read, an InvalidDatabaseError is returned. +// +// An error will also be returned if there was an error during the +// Reader.Lookup call. +// +// If the Reader.Lookup call did not find a value for the IP address, no error +// will be returned and v will be unchanged. +func (r Result) Decode(v any) error { + if r.err != nil { + return r.err + } + if r.offset == notFound { + return nil + } + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("result param must be a pointer") + } + + if dser, ok := v.(deserializer); ok { + _, err := r.decoder.decodeToDeserializer(r.offset, dser, 0, false) + return err + } + + _, err := r.decoder.decode(r.offset, rv, 0) + return err +} + +// DecodePath unmarshals a value from data section into v, following the +// specified path. +// +// The v parameter should be a pointer to the value where the decoded data +// will be stored. If v is nil or not a pointer, an error is returned. If the +// data in the database record cannot be stored in v because of type +// differences, an UnmarshalTypeError is returned. +// +// The path is a variadic list of keys (strings) and/or indices (ints) that +// describe the nested structure to traverse in the data to reach the desired +// value. +// +// For maps, string path elements are used as keys. +// For arrays, int path elements are used as indices. +// +// If the path is empty, the entire data structure is decoded into v. +// +// Returns an error if: +// - the path is invalid +// - the data cannot be decoded into the type of v +// - v is not a pointer or the database record cannot be stored in v due to +// type mismatch +// - the Result does not contain valid data +// +// Example usage: +// +// var city string +// err := result.DecodePath(&city, "location", "city", "names", "en") +// +// var geonameID int +// err := result.DecodePath(&geonameID, "subdivisions", 0, "geoname_id") +func (r Result) DecodePath(v any, path ...any) error { + if r.err != nil { + return r.err + } + if r.offset == notFound { + return nil + } + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("result param must be a pointer") + } + return r.decoder.decodePath(r.offset, path, rv) +} + +// Err provides a way to check whether there was an error during the lookup +// without calling Result.Decode. If there was an error, it will also be +// returned from Result.Decode. +func (r Result) Err() error { + return r.err +} + +// Found will return true if the IP was found in the search tree. It will +// return false if the IP was not found or if there was an error. +func (r Result) Found() bool { + return r.err == nil && r.offset != notFound +} + +// Offset returns the offset of the record in the database. This can be +// passed to (*Reader).LookupOffset. It can also be used as a unique +// identifier for the data record in the particular database to cache the data +// record across lookups. Note that while the offset uniquely identifies the +// data record, other data in Result may differ between lookups. The offset +// is only valid for the current database version. If you update the database +// file, you must invalidate any cache associated with the previous version. +func (r Result) Offset() uintptr { + return uintptr(r.offset) +} + +// Prefix returns the netip.Prefix representing the network associated with +// the data record in the database. +func (r Result) Prefix() netip.Prefix { + ip := r.ip + prefixLen := int(r.prefixLen) + + if ip.Is4() { + // This is necessary as the node that the IPv4 start is at may + // be at a bit depth that is less that 96, i.e., ipv4Start points + // to a leaf node. For instance, if a record was inserted at ::/8, + // the ipv4Start would point directly at the leaf node for the + // record and would have a bit depth of 8. This would not happen + // with databases currently distributed by MaxMind as all of them + // have an IPv4 subtree that is greater than a single node. + if prefixLen < 96 { + return netip.PrefixFrom(zeroIP, prefixLen) + } + prefixLen -= 96 + } + + prefix, _ := ip.Prefix(prefixLen) + return prefix +} diff --git a/traverse.go b/traverse.go index 90073e2..6a748db 100644 --- a/traverse.go +++ b/traverse.go @@ -2,210 +2,192 @@ package maxminddb import ( "fmt" - "net" + "net/netip" + + // comment to prevent gofumpt from randomly moving iter. + "iter" ) // Internal structure used to keep track of nodes we still need to visit. type netNode struct { - ip net.IP + ip netip.Addr bit uint pointer uint } -// Networks represents a set of subnets that we are iterating over. -type Networks struct { - err error - reader *Reader - nodes []netNode - lastNode netNode - skipAliasedNetworks bool +type networkOptions struct { + includeAliasedNetworks bool } var ( - allIPv4 = &net.IPNet{IP: make(net.IP, 4), Mask: net.CIDRMask(0, 32)} - allIPv6 = &net.IPNet{IP: make(net.IP, 16), Mask: net.CIDRMask(0, 128)} + allIPv4 = netip.MustParsePrefix("0.0.0.0/0") + allIPv6 = netip.MustParsePrefix("::/0") ) // NetworksOption are options for Networks and NetworksWithin. -type NetworksOption func(*Networks) +type NetworksOption func(*networkOptions) -// SkipAliasedNetworks is an option for Networks and NetworksWithin that -// makes them not iterate over aliases of the IPv4 subtree in an IPv6 +// IncludeAliasedNetworks is an option for Networks and NetworksWithin +// that makes them iterate over aliases of the IPv4 subtree in an IPv6 // database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16. -// -// You most likely want to set this. The only reason it isn't the default -// behavior is to provide backwards compatibility to existing users. -func SkipAliasedNetworks(networks *Networks) { - networks.skipAliasedNetworks = true +func IncludeAliasedNetworks(networks *networkOptions) { + networks.includeAliasedNetworks = true } // Networks returns an iterator that can be used to traverse all networks in // the database. // // Please note that a MaxMind DB may map IPv4 networks into several locations -// in an IPv6 database. This iterator will iterate over all of these locations -// separately. To only iterate over the IPv4 networks once, use the -// SkipAliasedNetworks option. -func (r *Reader) Networks(options ...NetworksOption) *Networks { - var networks *Networks +// in an IPv6 database. This iterator will only iterate over these once by +// default. To iterate over all the IPv4 network locations, use the +// IncludeAliasedNetworks option. +func (r *Reader) Networks(options ...NetworksOption) iter.Seq[Result] { if r.Metadata.IPVersion == 6 { - networks = r.NetworksWithin(allIPv6, options...) - } else { - networks = r.NetworksWithin(allIPv4, options...) + return r.NetworksWithin(allIPv6, options...) } - - return networks + return r.NetworksWithin(allIPv4, options...) } // NetworksWithin returns an iterator that can be used to traverse all networks -// in the database which are contained in a given network. +// in the database which are contained in a given prefix. // // Please note that a MaxMind DB may map IPv4 networks into several locations // in an IPv6 database. This iterator will iterate over all of these locations // separately. To only iterate over the IPv4 networks once, use the // SkipAliasedNetworks option. // -// If the provided network is contained within a network in the database, the +// If the provided prefix is contained within a network in the database, the // iterator will iterate over exactly one network, the containing network. -func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) *Networks { - if r.Metadata.IPVersion == 4 && network.IP.To4() == nil { - return &Networks{ - err: fmt.Errorf( - "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", - network.String(), - ), +func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) iter.Seq[Result] { + return func(yield func(Result) bool) { + if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() { + yield(Result{ + err: fmt.Errorf( + "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", + prefix, + ), + }) + return } - } - networks := &Networks{reader: r} - for _, option := range options { - option(networks) - } - - ip := network.IP - prefixLength, _ := network.Mask.Size() - - if r.Metadata.IPVersion == 6 && len(ip) == net.IPv4len { - if networks.skipAliasedNetworks { - ip = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ip[0], ip[1], ip[2], ip[3]} - } else { - ip = ip.To16() + n := &networkOptions{} + for _, option := range options { + option(n) } - prefixLength += 96 - } - - pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) - - // We could skip this when bit >= prefixLength if we assume that the network - // passed in is in canonical form. However, given that this may not be the - // case, it is safest to always take the mask. If this is hot code at some - // point, we could eliminate the allocation of the net.IPMask by zeroing - // out the bits in ip directly. - ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) - networks.nodes = []netNode{ - { - ip: ip, - bit: uint(bit), - pointer: pointer, - }, - } - - return networks -} -// Next prepares the next network for reading with the Network method. It -// returns true if there is another network to be processed and false if there -// are no more networks or if there is an error. -func (n *Networks) Next() bool { - if n.err != nil { - return false - } - for len(n.nodes) > 0 { - node := n.nodes[len(n.nodes)-1] - n.nodes = n.nodes[:len(n.nodes)-1] - - for node.pointer != n.reader.Metadata.NodeCount { - // This skips IPv4 aliases without hardcoding the networks that the writer - // currently aliases. - if n.skipAliasedNetworks && n.reader.ipv4Start != 0 && - node.pointer == n.reader.ipv4Start && !isInIPv4Subtree(node.ip) { - break - } - - if node.pointer > n.reader.Metadata.NodeCount { - n.lastNode = node - return true - } - ipRight := make(net.IP, len(node.ip)) - copy(ipRight, node.ip) - if len(ipRight) <= int(node.bit>>3) { - n.err = newInvalidDatabaseError( - "invalid search tree at %v/%v", ipRight, node.bit) - return false - } - ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) + ip := prefix.Addr() + netIP := ip + stopBit := prefix.Bits() + if ip.Is4() { + netIP = v4ToV16(ip) + stopBit += 96 + } - offset := node.pointer * n.reader.nodeOffsetMult - rightPointer := n.reader.nodeReader.readRight(offset) + pointer, bit := r.traverseTree(ip, 0, stopBit) - node.bit++ - n.nodes = append(n.nodes, netNode{ - pointer: rightPointer, - ip: ipRight, - bit: node.bit, + prefix, err := netIP.Prefix(bit) + if err != nil { + yield(Result{ + ip: ip, + prefixLen: uint8(bit), + err: fmt.Errorf("prefixing %s with %d", netIP, bit), }) + } - node.pointer = n.reader.nodeReader.readLeft(offset) + nodes := make([]netNode, 0, 64) + nodes = append(nodes, + netNode{ + ip: prefix.Addr(), + bit: uint(bit), + pointer: pointer, + }, + ) + + for len(nodes) > 0 { + node := nodes[len(nodes)-1] + nodes = nodes[:len(nodes)-1] + + for node.pointer != r.Metadata.NodeCount { + // This skips IPv4 aliases without hardcoding the networks that the writer + // currently aliases. + if !n.includeAliasedNetworks && r.ipv4Start != 0 && + node.pointer == r.ipv4Start && !isInIPv4Subtree(node.ip) { + break + } + + if node.pointer > r.Metadata.NodeCount { + ip := node.ip + if isInIPv4Subtree(ip) { + ip = v6ToV4(ip) + } + + offset, err := r.resolveDataPointer(node.pointer) + ok := yield(Result{ + decoder: r.decoder, + ip: ip, + offset: uint(offset), + prefixLen: uint8(node.bit), + err: err, + }) + if !ok { + return + } + break + } + ipRight := node.ip.As16() + if len(ipRight) <= int(node.bit>>3) { + displayAddr := node.ip + if isInIPv4Subtree(node.ip) { + displayAddr = v6ToV4(displayAddr) + } + + res := Result{ + ip: displayAddr, + prefixLen: uint8(node.bit), + } + res.err = newInvalidDatabaseError( + "invalid search tree at %s", res.Prefix()) + + yield(res) + + return + } + ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) + + offset := node.pointer * r.nodeOffsetMult + rightPointer := r.nodeReader.readRight(offset) + + node.bit++ + nodes = append(nodes, netNode{ + pointer: rightPointer, + ip: netip.AddrFrom16(ipRight), + bit: node.bit, + }) + + node.pointer = r.nodeReader.readLeft(offset) + } } } - - return false } -// Network returns the current network or an error if there is a problem -// decoding the data for the network. It takes a pointer to a result value to -// decode the network's data into. -func (n *Networks) Network(result any) (*net.IPNet, error) { - if n.err != nil { - return nil, n.err - } - if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil { - return nil, err - } - - ip := n.lastNode.ip - prefixLength := int(n.lastNode.bit) - - // We do this because uses of SkipAliasedNetworks expect the IPv4 networks - // to be returned as IPv4 networks. If we are not skipping aliased - // networks, then the user will get IPv4 networks from the ::FFFF:0:0/96 - // network as Go automatically converts those. - if n.skipAliasedNetworks && isInIPv4Subtree(ip) { - ip = ip[12:] - prefixLength -= 96 - } +var ipv4SubtreeBoundary = netip.MustParseAddr("::255.255.255.255").Next() - return &net.IPNet{ - IP: ip, - Mask: net.CIDRMask(prefixLength, len(ip)*8), - }, nil +// isInIPv4Subtree returns true if the IP is in the database's IPv4 subtree. +func isInIPv4Subtree(ip netip.Addr) bool { + return ip.Is4() || ip.Less(ipv4SubtreeBoundary) } -// Err returns an error, if any, that was encountered during iteration. -func (n *Networks) Err() error { - return n.err +// We store IPv4 addresses at ::/96 for unclear reasons. +func v4ToV16(ip netip.Addr) netip.Addr { + b4 := ip.As4() + var b16 [16]byte + copy(b16[12:], b4[:]) + return netip.AddrFrom16(b16) } -// isInIPv4Subtree returns true if the IP is an IPv6 address in the database's -// IPv4 subtree. -func isInIPv4Subtree(ip net.IP) bool { - if len(ip) != 16 { - return false - } - for i := 0; i < 12; i++ { - if ip[i] != 0 { - return false - } - } - return true +// Converts an IPv4 address embedded in IPv6 to IPv4. +func v6ToV4(ip netip.Addr) netip.Addr { + b := ip.As16() + v, _ := netip.AddrFromSlice(b[12:]) + return v } diff --git a/traverse_test.go b/traverse_test.go index 00edfce..963d710 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -2,7 +2,7 @@ package maxminddb import ( "fmt" - "net" + "net/netip" "strconv" "strings" "testing" @@ -20,18 +20,18 @@ func TestNetworks(t *testing.T) { reader, err := Open(fileName) require.NoError(t, err, "unexpected error while opening database: %v", err) - n := reader.Networks() - for n.Next() { + for result := range reader.Networks() { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) - assert.Equal(t, record.IP, network.IP.String(), - "expected %s got %s", record.IP, network.IP.String(), + + network := result.Prefix() + assert.Equal(t, record.IP, network.Addr().String(), + "expected %s got %s", record.IP, network.Addr().String(), ) } - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) } } @@ -41,13 +41,14 @@ func TestNetworksWithInvalidSearchTree(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-broken-search-tree-24.mmdb")) require.NoError(t, err, "unexpected error while opening database: %v", err) - n := reader.Networks() - for n.Next() { + for result := range reader.Networks() { var record any - _, err := n.Network(&record) - require.NoError(t, err) + err = result.Decode(&record) + if err != nil { + break + } } - require.EqualError(t, n.Err(), "invalid search tree at 128.128.128.128/32") + require.EqualError(t, err, "invalid search tree at 128.128.128.128/32") require.NoError(t, reader.Close()) } @@ -140,7 +141,6 @@ var tests = []networkTest{ Expected: []string{ "::1:ffff:ffff/128", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "::/0", @@ -152,7 +152,6 @@ var tests = []networkTest{ "::2:0:50/125", "::2:0:58/127", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "::2:0:40/123", @@ -162,7 +161,6 @@ var tests = []networkTest{ "::2:0:50/125", "::2:0:58/127", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "0:0:0:0:0:ffff:ffff:ff00/120", @@ -192,29 +190,28 @@ var tests = []networkTest{ "1.1.1.16/28", "1.1.1.32/32", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "::/0", Database: "mixed", Expected: []string{ - "::101:101/128", - "::101:102/127", - "::101:104/126", - "::101:108/125", - "::101:110/124", - "::101:120/128", - "::1:ffff:ffff/128", - "::2:0:0/122", - "::2:0:40/124", - "::2:0:50/125", - "::2:0:58/127", "1.1.1.1/32", "1.1.1.2/31", "1.1.1.4/30", "1.1.1.8/29", "1.1.1.16/28", "1.1.1.32/32", + "::1:ffff:ffff/128", + "::2:0:0/122", + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + "::ffff:1.1.1.1/128", + "::ffff:1.1.1.2/127", + "::ffff:1.1.1.4/126", + "::ffff:1.1.1.8/125", + "::ffff:1.1.1.16/124", + "::ffff:1.1.1.32/128", "2001:0:101:101::/64", "2001:0:101:102::/63", "2001:0:101:104::/62", @@ -228,6 +225,7 @@ var tests = []networkTest{ "2002:101:110::/44", "2002:101:120::/48", }, + Options: []NetworksOption{IncludeAliasedNetworks}, }, { Network: "::/0", @@ -245,7 +243,6 @@ var tests = []networkTest{ "::2:0:50/125", "::2:0:58/127", }, - Options: []NetworksOption{SkipAliasedNetworks}, }, { Network: "1.1.1.16/28", @@ -281,33 +278,26 @@ func TestNetworksWithin(t *testing.T) { // We are purposely not using net.ParseCIDR so that we can pass in // values that aren't in canonical form. parts := strings.Split(v.Network, "/") - ip := net.ParseIP(parts[0]) - if v := ip.To4(); v != nil { - ip = v - } + ip, err := netip.ParseAddr(parts[0]) + require.NoError(t, err) prefixLength, err := strconv.Atoi(parts[1]) require.NoError(t, err) - mask := net.CIDRMask(prefixLength, len(ip)*8) - network := &net.IPNet{ - IP: ip, - Mask: mask, - } + network, err := ip.Prefix(prefixLength) + require.NoError(t, err) require.NoError(t, err) - n := reader.NetworksWithin(network, v.Options...) var innerIPs []string - for n.Next() { + for result := range reader.NetworksWithin(network, v.Options...) { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) + innerIPs = append(innerIPs, result.Prefix().String()) } assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) }) @@ -333,23 +323,37 @@ func TestGeoIPNetworksWithin(t *testing.T) { reader, err := Open(fileName) require.NoError(t, err, "unexpected error while opening database: %v", err) - _, network, err := net.ParseCIDR(v.Network) + prefix, err := netip.ParsePrefix(v.Network) require.NoError(t, err) - n := reader.NetworksWithin(network) var innerIPs []string - for n.Next() { + for result := range reader.NetworksWithin(prefix) { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) + innerIPs = append(innerIPs, result.Prefix().String()) } assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) } } + +func BenchmarkNetworks(b *testing.B) { + db, err := Open(testFile("GeoIP2-Country-Test.mmdb")) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + for r := range db.Networks() { + var rec struct{} + err = r.Decode(&rec) + if err != nil { + b.Error(err) + } + } + } + require.NoError(b, db.Close(), "error on close") +} diff --git a/verifier.go b/verifier.go index b14b3e4..335cb1b 100644 --- a/verifier.go +++ b/verifier.go @@ -102,16 +102,11 @@ func (v *verifier) verifyDatabase() error { func (v *verifier) verifySearchTree() (map[uint]bool, error) { offsets := make(map[uint]bool) - it := v.reader.Networks() - for it.Next() { - offset, err := v.reader.resolveDataPointer(it.lastNode.pointer) - if err != nil { + for result := range v.reader.Networks() { + if err := result.Err(); err != nil { return nil, err } - offsets[uint(offset)] = true - } - if err := it.Err(); err != nil { - return nil, err + offsets[result.offset] = true } return offsets, nil }