diff --git a/examples/tailscale-custom-domain-dns.toml b/examples/tailscale-custom-domain-dns.toml index 43a185f..6fe2912 100644 --- a/examples/tailscale-custom-domain-dns.toml +++ b/examples/tailscale-custom-domain-dns.toml @@ -55,3 +55,8 @@ port = 53 [fetcher] # How frequently the server will fetch the list of devices from your organization. interval = "1h" + +[aliases] +# Alias for the root domain without any subdomain +# Example: root = "machine-one" +root = "" diff --git a/server/dns.go b/server/dns.go index 586bffc..b35a83a 100644 --- a/server/dns.go +++ b/server/dns.go @@ -28,11 +28,19 @@ func buildRR(name string, question dns.Question, address netip.Addr, host string var err error if (question.Qtype == dns.TypeA || question.Qtype == dns.TypeANY) && address.Is4() { - rr, err = dns.NewRR(fmt.Sprintf("%s.%s A %s", name, host, address.String())) + if name == "" { + rr, err = dns.NewRR(fmt.Sprintf("%s A %s", host, address.String())) + } else { + rr, err = dns.NewRR(fmt.Sprintf("%s.%s A %s", name, host, address.String())) + } } if (question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypeANY) && address.Is6() { - rr, err = dns.NewRR(fmt.Sprintf("%s.%s AAAA %s", name, host, address.String())) + if name == "" { + rr, err = dns.NewRR(fmt.Sprintf("%s AAAA %s", host, address.String())) + } else { + rr, err = dns.NewRR(fmt.Sprintf("%s.%s AAAA %s", name, host, address.String())) + } } if err != nil { @@ -112,9 +120,28 @@ func makeHandler(readDevices chan ReadDevicesOp, host string) DnsHandler { // Get just the subdomain name from the request name := strings.ReplaceAll(question.Name, "."+host, "") - // Respond if a device with the hostname exists - if device, ok := deviceMap[name]; ok { - rrs := constructResponses(name, device, question, host) + var rrs []dns.RR + if question.Name == host { + // If the hostname is bare, check for a root alias + + if viper.IsSet("aliases.root") { + name := viper.GetString("aliases.root") + if device, ok := deviceMap[name]; ok { + log. + Debug(). + Str("host", name). + Msgf(`Serving root alias for hostname "%s"`, name) + + rrs = constructResponses("", device, question, host) + m.Answer = rrs + } + } + } else if device, ok := deviceMap[name]; ok { + // Respond if a device with the hostname exists + rrs = constructResponses(name, device, question, host) + } + + if len(rrs) != 0 { log.Trace().Any("records", rrs).Msg("Sending records to client") m.Answer = rrs }