Skip to content

Commit

Permalink
DESEC: Fix init (#3017)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom Limoncelli <[email protected]>
  • Loading branch information
JenswBE and tlimoncelli authored Jul 1, 2024
1 parent 2f155ce commit c22f20d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 84 deletions.
43 changes: 18 additions & 25 deletions providers/desec/desecProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"strings"

"github.com/StackExchange/dnscontrol/v4/models"
"github.com/StackExchange/dnscontrol/v4/pkg/diff"
Expand All @@ -21,18 +22,10 @@ Info required in `creds.json`:
// NewDeSec creates the provider.
func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) {
c := &desecProvider{}
c.creds.token = m["auth-token"]
if c.creds.token == "" {
c.token = strings.TrimSpace(m["auth-token"])
if c.token == "" {
return nil, fmt.Errorf("missing deSEC auth-token")
}
if err := c.authenticate(); err != nil {
return nil, fmt.Errorf("authentication failed")
}
//DomainIndex is used for corrections (minttl) and domain creation
if err := c.initializeDomainIndex(); err != nil {
return nil, err
}

return c, nil
}

Expand All @@ -41,7 +34,7 @@ var features = providers.DocumentationNotes{
// See providers/capabilities.go for the entire list of capabilities.
providers.CanAutoDNSSEC: providers.Can("deSEC always signs all records. When trying to disable, a notice is printed."),
providers.CanGetZones: providers.Can(),
providers.CanConcur: providers.Cannot(),
providers.CanConcur: providers.Can(),
providers.CanUseAlias: providers.Unimplemented("Apex aliasing is supported via new SVCB and HTTPS record types. For details, check the deSEC docs."),
providers.CanUseCAA: providers.Can(),
providers.CanUseDS: providers.Can(),
Expand Down Expand Up @@ -119,9 +112,13 @@ func (c *desecProvider) GetZoneRecords(domain string, meta map[string]string) (m

// EnsureZoneExists creates a zone if it does not exist
func (c *desecProvider) EnsureZoneExists(domain string) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if _, ok := c.domainIndex[domain]; ok {
_, ok, err := c.searchDomainIndex(domain)
if err != nil {
return err
}

if ok {
// Domain already exists
return nil
}
return c.createDomain(domain)
Expand Down Expand Up @@ -155,14 +152,14 @@ func PrepDesiredRecords(dc *models.DomainConfig, minTTL uint32) {

// GetZoneRecordsCorrections returns a list of corrections that will turn existing records into dc.Records.
func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, existing models.Records) ([]*models.Correction, error) {
var minTTL uint32
c.mutex.Lock()
if ttl, ok := c.domainIndex[dc.Name]; !ok {
minTTL, ok, err := c.searchDomainIndex(dc.Name)
if err != nil {
return nil, err
}
if !ok {
minTTL = 3600
} else {
minTTL = ttl
}
c.mutex.Unlock()

PrepDesiredRecords(dc, minTTL)

keysToUpdate, toReport, err := diff.NewCompat(dc).ChangedGroups(existing)
Expand Down Expand Up @@ -250,9 +247,5 @@ func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, exist

// ListZones return all the zones in the account
func (c *desecProvider) ListZones() ([]string, error) {
var domains []string
for domain := range c.domainIndex {
domains = append(domains, domain)
}
return domains, nil
return c.listDomainIndex()
}
132 changes: 73 additions & 59 deletions providers/desec/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,9 @@ const apiBase = "https://desec.io/api/v1"

// Api layer for desec
type desecProvider struct {
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
creds struct {
tokenid string
token string
user string
password string
}
mutex sync.Mutex
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
domainIndexLock sync.Mutex
token string
}

type domainObject struct {
Expand Down Expand Up @@ -66,86 +61,105 @@ type nonFieldError struct {
Errors []string `json:"non_field_errors"`
}

func (c *desecProvider) authenticate() error {
endpoint := "/auth/account/"
var _, resp, err = c.get(endpoint, "GET")
//restricted tokens are valid, but get 403 on /auth/account
//invalid tokens get 401
if resp.StatusCode == 403 {
return nil
}
if err != nil {
return err
// withDomainIndex checks if the domain index is initialized. If not, it's fetched from the deSEC API.
// Next, the provided readFn function is executed to extract data from the domain index.
func (c *desecProvider) withDomainIndex(readFn func(domainIndex map[string]uint32)) error {
// Lock index
c.domainIndexLock.Lock()
defer c.domainIndexLock.Unlock()

// Init index if needed
if c.domainIndex == nil {
printer.Debugf("Domain index not yet populated, fetching now\n")
var err error
c.domainIndex, err = c.fetchDomainIndex()
if err != nil {
return fmt.Errorf("failed to fetch domain index: %w", err)
}
}

// Execute handler on index
readFn(c.domainIndex)
return nil
}
func (c *desecProvider) initializeDomainIndex() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.domainIndex != nil {
return nil
}

// listDomainIndex lists all the available domains in the domain index
func (c *desecProvider) listDomainIndex() (domains []string, err error) {
err = c.withDomainIndex(func(domainIndex map[string]uint32) {
domains = make([]string, 0, len(domainIndex))
for domain := range domainIndex {
domains = append(domains, domain)
}
})
return
}

// searchDomainIndex performs a lookup to the domain index for the TTL of the domain
func (c *desecProvider) searchDomainIndex(domain string) (ttl uint32, found bool, err error) {
err = c.withDomainIndex(func(domainIndex map[string]uint32) {
ttl, found = domainIndex[domain]
})
return
}

func (c *desecProvider) fetchDomainIndex() (map[string]uint32, error) {
endpoint := "/domains/"
var domainIndex map[string]uint32
var bodyString, resp, err = c.get(endpoint, "GET")
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
//pagination is required
links := c.convertLinks(resp.Header.Get("Link"))
links := convertLinks(resp.Header.Get("Link"))
endpoint = links["first"]
printer.Debugf("initial endpoint %s\n", endpoint)
for endpoint != "" {
bodyString, resp, err = c.get(endpoint, "GET")
if err != nil {
if resp.StatusCode == 404 {
return nil
}
return fmt.Errorf("failed fetching domains: %s", err)
return nil, fmt.Errorf("failed fetching domains: %s", err)
}
err = c.buildIndexFromResponse(bodyString)
domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString)
if err != nil {
return fmt.Errorf("failed fetching domains: %s", err)
return nil, fmt.Errorf("failed fetching domains: %s", err)
}
links = c.convertLinks(resp.Header.Get("Link"))
links = convertLinks(resp.Header.Get("Link"))
endpoint = links["next"]
printer.Debugf("next endpoint %s\n", endpoint)
}
printer.Debugf("Domain Index initilized with pagination (%d domains)\n", len(c.domainIndex))
return nil //domainIndex was build using pagination without errors
printer.Debugf("Domain Index fetched with pagination (%d domains)\n", len(domainIndex))
return domainIndex, nil //domainIndex was build using pagination without errors
}

//no pagination required
if err != nil && resp.StatusCode != 400 {
if resp.StatusCode == 404 {
return nil
}
return fmt.Errorf("failed fetching domains: %s", err)
return nil, fmt.Errorf("failed fetching domains: %s", err)
}
err = c.buildIndexFromResponse(bodyString)
if err == nil {
printer.Debugf("Domain Index initilized without pagination (%d domains)\n", len(c.domainIndex))
domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString)
if err != nil {
return nil, err
}
return err
printer.Debugf("Domain Index fetched without pagination (%d domains)\n", len(domainIndex))
return domainIndex, nil
}

// buildIndexFromResponse takes the bodyString from initializeDomainIndex and builds the domainIndex
func (c *desecProvider) buildIndexFromResponse(bodyString []byte) error {
if c.domainIndex == nil {
c.domainIndex = map[string]uint32{}
}
func appendDomainIndexFromResponse(domainIndex map[string]uint32, bodyString []byte) (map[string]uint32, error) {
var dr []domainObject
err := json.Unmarshal(bodyString, &dr)
if err != nil {
return err
return nil, err
}

if domainIndex == nil {
domainIndex = make(map[string]uint32, len(dr))
}
for _, domain := range dr {
//deSEC allows different minimum ttls per domain
//we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors.
c.domainIndex[domain.Name] = domain.MinimumTTL
domainIndex[domain.Name] = domain.MinimumTTL
}
return nil
return domainIndex, nil
}

// Parses the Link Header into a map (https://github.com/desec-io/desec-tools/blob/main/fetch_zone.py#L13)
func (c *desecProvider) convertLinks(links string) map[string]string {
func convertLinks(links string) map[string]string {
mapping := make(map[string]string)
printer.Debugf("Header: %s\n", links)
for _, link := range strings.Split(links, ", ") {
Expand Down Expand Up @@ -173,7 +187,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
var bodyString, resp, err = c.get(fmt.Sprintf(endpoint, domain), "GET")
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
//pagination required
links := c.convertLinks(resp.Header.Get("Link"))
links := convertLinks(resp.Header.Get("Link"))
endpoint = links["first"]
printer.Debugf("getRecords: initial endpoint %s\n", fmt.Sprintf(endpoint, domain))
for endpoint != "" {
Expand All @@ -189,7 +203,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err)
}
rrsNew = append(rrsNew, tmp...)
links = c.convertLinks(resp.Header.Get("Link"))
links = convertLinks(resp.Header.Get("Link"))
endpoint = links["next"]
printer.Debugf("getRecords: next endpoint %s\n", endpoint)
}
Expand Down Expand Up @@ -282,7 +296,7 @@ retry:
client := &http.Client{}
req, _ := http.NewRequest(method, endpoint, nil)
q := req.URL.Query()
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token))
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token))

req.URL.RawQuery = q.Encode()

Expand All @@ -304,12 +318,12 @@ retry:
if wait > 180 {
return []byte{}, resp, fmt.Errorf("rate limiting exceeded")
}
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor)
time.Sleep(time.Duration(wait+1) * time.Second)
goto retry
}
}
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n")
time.Sleep(500 * time.Millisecond)
goto retry
}
Expand Down Expand Up @@ -346,7 +360,7 @@ retry:
}
q := req.URL.Query()
if endpoint != "/auth/login/" {
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token))
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token))
}
req.Header.Set("Content-Type", "application/json")

Expand All @@ -371,12 +385,12 @@ retry:
if wait > 180 {
return []byte{}, fmt.Errorf("rate limiting exceeded")
}
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor)
time.Sleep(time.Duration(wait+1) * time.Second)
goto retry
}
}
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n")
time.Sleep(500 * time.Millisecond)
goto retry
}
Expand Down

0 comments on commit c22f20d

Please sign in to comment.