refactor of query.go (#27)
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
Includes a fix for #26 Co-authored-by: Sam Therapy <sam@samtherapy.net> Reviewed-on: #27 Co-authored-by: grumbulon <grumbulon@grumbulon.xyz> Co-committed-by: grumbulon <grumbulon@grumbulon.xyz>
This commit is contained in:
parent
85a00bffbf
commit
ed4d74bb96
9 changed files with 291 additions and 150 deletions
29
cli.go
29
cli.go
|
@ -9,7 +9,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"git.froth.zone/sam/awl/conf"
|
||||
"git.froth.zone/sam/awl/util"
|
||||
"git.froth.zone/sam/awl/query"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
@ -148,9 +148,9 @@ func prepareCLI() *cli.App {
|
|||
}
|
||||
|
||||
// Parse the wildcard arguments, drill style
|
||||
func parseArgs(args []string) (util.Answers, error) {
|
||||
func parseArgs(args []string, opts query.Options) (query.Answers, error) {
|
||||
var (
|
||||
resp util.Response
|
||||
resp query.Response
|
||||
err error
|
||||
)
|
||||
for _, arg := range args {
|
||||
|
@ -162,7 +162,7 @@ func parseArgs(args []string) (util.Answers, error) {
|
|||
case strings.Contains(arg, "."):
|
||||
resp.Answers.Name, err = idna.ToUnicode(arg)
|
||||
if err != nil {
|
||||
return util.Answers{}, err
|
||||
return query.Answers{}, err
|
||||
}
|
||||
case ok:
|
||||
// If it's a DNS request, it's a DNS request (obviously)
|
||||
|
@ -171,7 +171,7 @@ func parseArgs(args []string) (util.Answers, error) {
|
|||
//else, assume it's a name
|
||||
resp.Answers.Name, err = idna.ToUnicode(arg)
|
||||
if err != nil {
|
||||
return util.Answers{}, err
|
||||
return query.Answers{}, err
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -193,9 +193,24 @@ func parseArgs(args []string) (util.Answers, error) {
|
|||
if err != nil { // Query Google by default
|
||||
resp.Answers.Server = "8.8.4.4"
|
||||
} else {
|
||||
resp.Answers.Server = resolv.Servers[rand.Intn(len(resolv.Servers))]
|
||||
for _, srv := range resolv.Servers {
|
||||
if opts.IPv4 {
|
||||
if strings.Contains(srv, ".") {
|
||||
resp.Answers.Server = srv
|
||||
break
|
||||
}
|
||||
} else if opts.IPv6 {
|
||||
if strings.Contains(srv, ":") {
|
||||
resp.Answers.Server = srv
|
||||
break
|
||||
}
|
||||
} else {
|
||||
resp.Answers.Server = resolv.Servers[rand.Intn(len(resolv.Servers))]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return util.Answers{Server: resp.Answers.Server, Request: resp.Answers.Request, Name: resp.Answers.Name}, nil
|
||||
return query.Answers{Server: resp.Answers.Server, Request: resp.Answers.Request, Name: resp.Answers.Name}, nil
|
||||
}
|
||||
|
|
50
cli_test.go
50
cli_test.go
|
@ -6,7 +6,7 @@ import (
|
|||
"os"
|
||||
"testing"
|
||||
|
||||
"git.froth.zone/sam/awl/util"
|
||||
"git.froth.zone/sam/awl/query"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -21,23 +21,23 @@ func TestApp(t *testing.T) {
|
|||
func TestArgParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
in []string
|
||||
want util.Answers
|
||||
want query.Answers
|
||||
}{
|
||||
{
|
||||
[]string{"@::1", "localhost", "AAAA"},
|
||||
util.Answers{Server: "::1", Request: dns.TypeAAAA, Name: "localhost"},
|
||||
query.Answers{Server: "::1", Request: dns.TypeAAAA, Name: "localhost"},
|
||||
},
|
||||
{
|
||||
[]string{"@1.0.0.1", "google.com"},
|
||||
util.Answers{Server: "1.0.0.1", Request: dns.TypeA, Name: "google.com"},
|
||||
query.Answers{Server: "1.0.0.1", Request: dns.TypeA, Name: "google.com"},
|
||||
},
|
||||
{
|
||||
[]string{"@8.8.4.4"},
|
||||
util.Answers{Server: "8.8.4.4", Request: dns.TypeNS, Name: "."},
|
||||
query.Answers{Server: "8.8.4.4", Request: dns.TypeNS, Name: "."},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
act, err := parseArgs(test.in)
|
||||
act, err := parseArgs(test.in, query.Options{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.want, act)
|
||||
}
|
||||
|
@ -50,6 +50,25 @@ func TestQuery(t *testing.T) {
|
|||
err := app.Run(args)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestNoArgs(t *testing.T) {
|
||||
app := prepareCLI()
|
||||
args := os.Args[0:1]
|
||||
args = append(args, "--no-truncate")
|
||||
err := app.Run(args)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestFlags(t *testing.T) {
|
||||
app := prepareCLI()
|
||||
args := os.Args[0:1]
|
||||
args = append(args, "--debug")
|
||||
args = append(args, "--short")
|
||||
args = append(args, "-4")
|
||||
err := app.Run(args)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestHTTPS(t *testing.T) {
|
||||
app := prepareCLI()
|
||||
args := os.Args[0:1]
|
||||
|
@ -60,6 +79,25 @@ func TestHTTPS(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
app := prepareCLI()
|
||||
args := os.Args[0:1]
|
||||
args = append(args, "-j")
|
||||
args = append(args, "git.froth.zone")
|
||||
err := app.Run(args)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestQUIC(t *testing.T) {
|
||||
app := prepareCLI()
|
||||
args := os.Args[0:1]
|
||||
args = append(args, "-Q")
|
||||
args = append(args, "@dns.adguard.com")
|
||||
args = append(args, "git.froth.zone")
|
||||
err := app.Run(args)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func FuzzCli(f *testing.F) {
|
||||
testcases := []string{"git.froth.zone", "", "!12345", "google.com.edu.org.fr"}
|
||||
for _, tc := range testcases {
|
||||
|
|
162
query.go
162
query.go
|
@ -18,142 +18,142 @@ import (
|
|||
|
||||
func doQuery(c *cli.Context) error {
|
||||
var (
|
||||
err error
|
||||
resp util.Response
|
||||
isHTTPS bool
|
||||
err error
|
||||
)
|
||||
resp.Logger = util.InitLogger(c.Bool("debug")) //init logger
|
||||
resp.Answers, err = parseArgs(c.Args().Slice())
|
||||
// load cli flags into options struct
|
||||
Options := query.Options{
|
||||
Logger: util.InitLogger(c.Bool("debug")),
|
||||
Port: c.Int("port"),
|
||||
IPv4: c.Bool("4"),
|
||||
IPv6: c.Bool("6"),
|
||||
DNSSEC: c.Bool("dnssec"),
|
||||
Short: c.Bool("short"),
|
||||
TCP: c.Bool("tcp"),
|
||||
TLS: c.Bool("tls"),
|
||||
HTTPS: c.Bool("https"),
|
||||
QUIC: c.Bool("quic"),
|
||||
Truncate: c.Bool("no-truncate"),
|
||||
AA: c.Bool("aa"),
|
||||
TC: c.Bool("tc"),
|
||||
Z: c.Bool("z"),
|
||||
CD: c.Bool("cd"),
|
||||
NoRD: c.Bool("no-rd"),
|
||||
NoRA: c.Bool("no-ra"),
|
||||
Reverse: c.Bool("reverse"),
|
||||
Debug: c.Bool("debug"),
|
||||
}
|
||||
Options.Answers, err = parseArgs(c.Args().Slice(), Options)
|
||||
if err != nil {
|
||||
resp.Logger.Error("unable to parse args")
|
||||
Options.Logger.Error("Unable to parse args")
|
||||
return err
|
||||
}
|
||||
port := c.Int("port")
|
||||
|
||||
resp.Logger.Debug("starting awl")
|
||||
// If port is not set, set it
|
||||
if port == 0 {
|
||||
if c.Bool("tls") || c.Bool("quic") {
|
||||
resp.Logger.Debug("setting port to 853")
|
||||
port = 853
|
||||
} else {
|
||||
resp.Logger.Debug("setting port to 53")
|
||||
port = 53
|
||||
}
|
||||
}
|
||||
|
||||
if c.Bool("https") || strings.HasPrefix(resp.Answers.Server, "https://") {
|
||||
// add https:// if it doesn't already exist
|
||||
if !strings.HasPrefix(resp.Answers.Server, "https://") {
|
||||
resp.Answers.Server = "https://" + resp.Answers.Server
|
||||
}
|
||||
isHTTPS = true
|
||||
} else {
|
||||
resp.Answers.Server = net.JoinHostPort(resp.Answers.Server, strconv.Itoa(port))
|
||||
}
|
||||
|
||||
// Process the IP/Phone number so a PTR/NAPTR can be done
|
||||
if c.Bool("reverse") {
|
||||
if dns.TypeToString[resp.Answers.Request] == "A" {
|
||||
resp.Answers.Request = dns.StringToType["PTR"]
|
||||
}
|
||||
resp.Answers.Name, err = util.ReverseDNS(resp.Answers.Name, dns.TypeToString[resp.Answers.Request])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// if the domain is not canonical, make it canonical
|
||||
if !strings.HasSuffix(resp.Answers.Name, ".") {
|
||||
resp.Answers.Name = fmt.Sprintf("%s.", resp.Answers.Name)
|
||||
}
|
||||
resp.Logger.Debug("packing DNS message")
|
||||
msg := new(dns.Msg)
|
||||
|
||||
msg.SetQuestion(resp.Answers.Name, resp.Answers.Request)
|
||||
|
||||
// if the domain is not canonical, make it canonical
|
||||
if !strings.HasSuffix(Options.Answers.Name, ".") {
|
||||
Options.Answers.Name = fmt.Sprintf("%s.", Options.Answers.Name)
|
||||
}
|
||||
msg.SetQuestion(Options.Answers.Name, Options.Answers.Request)
|
||||
// If port is not set, set it
|
||||
if Options.Port == 0 {
|
||||
if Options.TLS || Options.QUIC {
|
||||
Options.Port = 853
|
||||
} else {
|
||||
Options.Port = 53
|
||||
}
|
||||
}
|
||||
Options.Logger.Debug("setting any message flags")
|
||||
// Make this authoritative (does this do anything?)
|
||||
if c.Bool("aa") {
|
||||
if Options.AA {
|
||||
Options.Logger.Debug("making message authorative")
|
||||
msg.Authoritative = true
|
||||
}
|
||||
// Set truncated flag (why)
|
||||
if c.Bool("tc") {
|
||||
if Options.TC {
|
||||
msg.Truncated = true
|
||||
}
|
||||
// Set the zero flag if requested (does nothing)
|
||||
if c.Bool("z") {
|
||||
resp.Logger.Debug("setting message to zero")
|
||||
if Options.Z {
|
||||
Options.Logger.Debug("setting to zero")
|
||||
msg.Zero = true
|
||||
}
|
||||
// Disable DNSSEC validation
|
||||
if c.Bool("cd") {
|
||||
if Options.CD {
|
||||
Options.Logger.Debug("disabling DNSSEC validation")
|
||||
msg.CheckingDisabled = true
|
||||
}
|
||||
// Disable wanting recursion
|
||||
if c.Bool("no-rd") {
|
||||
if Options.NoRD {
|
||||
Options.Logger.Debug("disabling recursion")
|
||||
msg.RecursionDesired = false
|
||||
}
|
||||
// Disable recursion being available (I don't think this does anything)
|
||||
if c.Bool("no-ra") {
|
||||
if Options.NoRA {
|
||||
msg.RecursionAvailable = false
|
||||
}
|
||||
// Set DNSSEC if requested
|
||||
if c.Bool("dnssec") {
|
||||
resp.Logger.Debug("using DNSSEC")
|
||||
if Options.DNSSEC {
|
||||
Options.Logger.Debug("using DNSSEC")
|
||||
msg.SetEdns0(1232, true)
|
||||
}
|
||||
|
||||
var in *dns.Msg
|
||||
resolver, err := query.LoadResolver(Options.Answers.Server, Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if Options.Debug {
|
||||
Options.Logger.SetLevel(3)
|
||||
}
|
||||
|
||||
Options.Logger.Debug("Starting awl")
|
||||
|
||||
var in = Options.Answers.DNS
|
||||
|
||||
// Make the DNS request
|
||||
if isHTTPS {
|
||||
resp.Logger.Debug("resolving DoH query")
|
||||
in, resp.Answers.RTT, err = query.ResolveHTTPS(msg, resp.Answers.Server)
|
||||
} else if c.Bool("quic") {
|
||||
resp.Logger.Debug("resolving DoQ query")
|
||||
in, resp.Answers.RTT, err = query.ResolveQUIC(msg, resp.Answers.Server)
|
||||
if Options.HTTPS {
|
||||
in, Options.Answers.RTT, err = resolver.LookUp(msg)
|
||||
} else if Options.QUIC {
|
||||
in, Options.Answers.RTT, err = resolver.LookUp(msg)
|
||||
} else {
|
||||
|
||||
Options.Answers.Server = net.JoinHostPort(Options.Answers.Server, strconv.Itoa(Options.Port))
|
||||
d := new(dns.Client)
|
||||
|
||||
// Set TCP/UDP, depending on flags
|
||||
if c.Bool("tcp") || c.Bool("tls") {
|
||||
resp.Logger.Debug("using tcp")
|
||||
if Options.TCP || Options.TLS {
|
||||
d.Net = "tcp"
|
||||
} else {
|
||||
resp.Logger.Debug("using udp")
|
||||
Options.Logger.Debug("using udp")
|
||||
d.Net = "udp"
|
||||
}
|
||||
|
||||
// Set IPv4 or IPv6, depending on flags
|
||||
switch {
|
||||
case c.Bool("4"):
|
||||
case Options.IPv4:
|
||||
d.Net += "4"
|
||||
case c.Bool("6"):
|
||||
case Options.IPv6:
|
||||
d.Net += "6"
|
||||
}
|
||||
|
||||
// Add TLS, if requested
|
||||
if c.Bool("tls") {
|
||||
if Options.TLS {
|
||||
d.Net += "-tls"
|
||||
}
|
||||
resp.Logger.Debug("exchanging DNS message")
|
||||
in, resp.Answers.RTT, err = d.Exchange(msg, resp.Answers.Server)
|
||||
|
||||
in, Options.Answers.RTT, err = d.Exchange(msg, Options.Answers.Server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If UDP truncates, use TCP instead (unless truncation is to be ignored)
|
||||
if in.MsgHdr.Truncated && !c.Bool("no-truncate") {
|
||||
if in.MsgHdr.Truncated && !Options.Truncate {
|
||||
fmt.Printf(";; Truncated, retrying with TCP\n\n")
|
||||
d.Net = "tcp"
|
||||
switch {
|
||||
case c.Bool("4"):
|
||||
case Options.IPv4:
|
||||
d.Net += "4"
|
||||
case c.Bool("6"):
|
||||
case Options.IPv4:
|
||||
d.Net += "6"
|
||||
}
|
||||
resp.Logger.Debug("exchanging DNS message")
|
||||
in, resp.Answers.RTT, err = d.Exchange(msg, resp.Answers.Server)
|
||||
in, Options.Answers.RTT, err = d.Exchange(msg, Options.Answers.Server)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,8 +171,8 @@ func doQuery(c *cli.Context) error {
|
|||
if !c.Bool("short") {
|
||||
// Print everything
|
||||
fmt.Println(in)
|
||||
fmt.Println(";; Query time:", resp.Answers.RTT)
|
||||
fmt.Println(";; SERVER:", resp.Answers.Server)
|
||||
fmt.Println(";; Query time:", Options.Answers.RTT)
|
||||
fmt.Println(";; SERVER:", Options.Answers.Server)
|
||||
fmt.Println(";; WHEN:", time.Now().Format(time.RFC1123Z))
|
||||
fmt.Println(";; MSG SIZE rcvd:", in.Len())
|
||||
} else {
|
||||
|
|
|
@ -12,17 +12,21 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Resolve a DNS-over-HTTPS query
|
||||
//
|
||||
// Currently only supports POST requests
|
||||
func ResolveHTTPS(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
|
||||
type HTTPSResolver struct {
|
||||
server string
|
||||
opts Options
|
||||
}
|
||||
|
||||
func (r *HTTPSResolver) LookUp(msg *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
var resp Response
|
||||
httpR := &http.Client{}
|
||||
buf, err := msg.Pack()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
r.opts.Logger.Debug("making DoH request")
|
||||
// query := server + "?dns=" + base64.RawURLEncoding.EncodeToString(buf)
|
||||
req, err := http.NewRequest("POST", server, bytes.NewBuffer(buf))
|
||||
req, err := http.NewRequest("POST", r.server, bytes.NewBuffer(buf))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("DoH: %s", err.Error())
|
||||
}
|
||||
|
@ -31,7 +35,7 @@ func ResolveHTTPS(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error)
|
|||
|
||||
now := time.Now()
|
||||
res, err := httpR.Do(req)
|
||||
rtt := time.Since(now)
|
||||
resp.Answers.RTT = time.Since(now)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("DoH HTTP request error: %s", err.Error())
|
||||
|
@ -46,11 +50,12 @@ func ResolveHTTPS(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error)
|
|||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("DoH body read error: %s", err.Error())
|
||||
}
|
||||
response := dns.Msg{}
|
||||
err = response.Unpack(fullRes)
|
||||
resp.DNS = dns.Msg{}
|
||||
r.opts.Logger.Debug("unpacking response")
|
||||
err = resp.DNS.Unpack(fullRes)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("DoH dns message unpack error: %s", err.Error())
|
||||
}
|
||||
|
||||
return &response, rtt, nil
|
||||
return &resp.DNS, resp.Answers.RTT, nil
|
||||
}
|
||||
|
|
|
@ -11,12 +11,18 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Resolve DNS over QUIC, the hip new standard (for privacy I think, IDK)
|
||||
func ResolveQUIC(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
|
||||
type QUICResolver struct {
|
||||
server string
|
||||
opts Options
|
||||
}
|
||||
|
||||
func (r *QUICResolver) LookUp(msg *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
var resp Response
|
||||
tls := &tls.Config{
|
||||
NextProtos: []string{"doq"},
|
||||
}
|
||||
connection, err := quic.DialAddr(server, tls, nil)
|
||||
r.opts.Logger.Debug("making DoQ request")
|
||||
connection, err := quic.DialAddr(r.server, tls, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
@ -40,7 +46,7 @@ func ResolveQUIC(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
|
|||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
rtt := time.Since(t)
|
||||
resp.Answers.RTT = time.Since(t)
|
||||
|
||||
// Close with error: no error
|
||||
err = connection.CloseWithError(0, "")
|
||||
|
@ -53,11 +59,11 @@ func ResolveQUIC(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
|
|||
return nil, 0, err
|
||||
}
|
||||
|
||||
response := dns.Msg{}
|
||||
err = response.Unpack(fullRes)
|
||||
resp.DNS = dns.Msg{}
|
||||
r.opts.Logger.Debug("unpacking DoQ response")
|
||||
err = resp.DNS.Unpack(fullRes)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return &response, rtt, nil
|
||||
return &resp.DNS, resp.Answers.RTT, nil
|
||||
}
|
||||
|
|
|
@ -14,7 +14,13 @@ import (
|
|||
|
||||
func TestResolveHTTPS(t *testing.T) {
|
||||
var err error
|
||||
testCase := util.Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
opts := Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(false),
|
||||
}
|
||||
testCase := Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
resolver, err := LoadResolver(testCase.Server, opts)
|
||||
|
||||
if !strings.HasPrefix(testCase.Server, "https://") {
|
||||
testCase.Server = "https://" + testCase.Server
|
||||
}
|
||||
|
@ -27,28 +33,36 @@ func TestResolveHTTPS(t *testing.T) {
|
|||
msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
msg = msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
var in *dns.Msg
|
||||
in, testCase.RTT, err = ResolveHTTPS(msg, testCase.Server)
|
||||
in, testCase.RTT, err = resolver.LookUp(msg)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, in)
|
||||
|
||||
}
|
||||
|
||||
func Test2ResolveHTTPS(t *testing.T) {
|
||||
opts := Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(false),
|
||||
}
|
||||
var err error
|
||||
testCase := util.Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
|
||||
testCase := Answers{Server: "dns9.quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
resolver, err := LoadResolver(testCase.Server, opts)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
msg = msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
var in *dns.Msg
|
||||
in, testCase.RTT, err = ResolveHTTPS(msg, testCase.Server)
|
||||
in, testCase.RTT, err = resolver.LookUp(msg)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, in)
|
||||
|
||||
}
|
||||
func Test3ResolveHTTPS(t *testing.T) {
|
||||
opts := Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(false),
|
||||
}
|
||||
var err error
|
||||
testCase := util.Answers{Server: "dns9..quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone."}
|
||||
testCase := Answers{Server: "dns9..quad9.net/dns-query", Request: dns.TypeA, Name: "git.froth.zone."}
|
||||
if !strings.HasPrefix(testCase.Server, "https://") {
|
||||
testCase.Server = "https://" + testCase.Server
|
||||
}
|
||||
|
@ -56,30 +70,33 @@ func Test3ResolveHTTPS(t *testing.T) {
|
|||
if !strings.HasSuffix(testCase.Name, ".") {
|
||||
testCase.Name = fmt.Sprintf("%s.", testCase.Name)
|
||||
}
|
||||
resolver, err := LoadResolver(testCase.Server, opts)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
msg = msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
var in *dns.Msg
|
||||
in, testCase.RTT, err = ResolveHTTPS(msg, testCase.Server)
|
||||
in, testCase.RTT, err = resolver.LookUp(msg)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, in)
|
||||
|
||||
}
|
||||
|
||||
func TestQuic(t *testing.T) {
|
||||
var err error
|
||||
testCase := util.Answers{Server: "dns.adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
testCase2 := util.Answers{Server: "dns.adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
var testCases []util.Answers
|
||||
opts := Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(false),
|
||||
Port: 853,
|
||||
Answers: Answers{Server: "dns.adguard.com"},
|
||||
}
|
||||
testCase := Answers{Server: "dns.//./,,adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
testCase2 := Answers{Server: "dns.adguard.com", Request: dns.TypeA, Name: "git.froth.zone"}
|
||||
var testCases []Answers
|
||||
testCases = append(testCases, testCase)
|
||||
testCases = append(testCases, testCase2)
|
||||
|
||||
for i := range testCases {
|
||||
switch i {
|
||||
case 0:
|
||||
port := 853
|
||||
testCases[i].Server = net.JoinHostPort(testCases[i].Server, strconv.Itoa(port))
|
||||
fmt.Println(testCases[i].Server)
|
||||
resolver, err := LoadResolver(testCases[i].Server, opts)
|
||||
// if the domain is not canonical, make it canonical
|
||||
if !strings.HasSuffix(testCase.Name, ".") {
|
||||
testCases[i].Name = fmt.Sprintf("%s.", testCases[i].Name)
|
||||
|
@ -88,21 +105,21 @@ func TestQuic(t *testing.T) {
|
|||
msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
msg = msg.SetQuestion(testCase.Name, testCase.Request)
|
||||
var in *dns.Msg
|
||||
in, testCase.RTT, err = ResolveQUIC(msg, testCase.Server)
|
||||
in, testCase.RTT, err = resolver.LookUp(msg)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, in)
|
||||
case 1:
|
||||
port := 853
|
||||
testCases[i].Server = net.JoinHostPort(testCases[i].Server, strconv.Itoa(port))
|
||||
resolver, err := LoadResolver(testCase2.Server, opts)
|
||||
testCase2.Server = net.JoinHostPort(testCase2.Server, strconv.Itoa(opts.Port))
|
||||
// if the domain is not canonical, make it canonical
|
||||
if !strings.HasSuffix(testCase.Name, ".") {
|
||||
testCases[i].Name = fmt.Sprintf("%s.", testCases[i].Name)
|
||||
if !strings.HasSuffix(testCase2.Name, ".") {
|
||||
testCase2.Name = fmt.Sprintf("%s.", testCase2.Name)
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(testCases[i].Name, testCases[i].Request)
|
||||
msg = msg.SetQuestion(testCases[i].Name, testCases[i].Request)
|
||||
msg.SetQuestion(testCase2.Name, testCase2.Request)
|
||||
msg = msg.SetQuestion(testCase2.Name, testCase2.Request)
|
||||
var in *dns.Msg
|
||||
in, testCase.RTT, err = ResolveQUIC(msg, testCases[i].Server)
|
||||
in, testCase.RTT, err = resolver.LookUp(msg)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, in)
|
||||
}
|
||||
|
|
76
query/resolver.go
Normal file
76
query/resolver.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package query
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.froth.zone/sam/awl/logawl"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// represent all CLI flags
|
||||
type Options struct {
|
||||
Logger *logawl.Logger
|
||||
|
||||
Port int
|
||||
IPv4 bool
|
||||
IPv6 bool
|
||||
DNSSEC bool
|
||||
Short bool
|
||||
TCP bool
|
||||
TLS bool
|
||||
HTTPS bool
|
||||
QUIC bool
|
||||
Truncate bool
|
||||
AA bool
|
||||
TC bool
|
||||
Z bool
|
||||
CD bool
|
||||
NoRD bool
|
||||
NoRA bool
|
||||
Reverse bool
|
||||
Debug bool
|
||||
Answers Answers
|
||||
}
|
||||
type Response struct {
|
||||
Answers Answers `json:"Response"` // These be DNS query answers
|
||||
DNS dns.Msg
|
||||
}
|
||||
|
||||
// The Answers struct is the basic structure of a DNS request
|
||||
// to be returned to the user upon making a request
|
||||
type Answers struct {
|
||||
Server string `json:"Server"` // The server to make the DNS request from
|
||||
DNS *dns.Msg
|
||||
Request uint16 `json:"Request"` // The type of request
|
||||
Name string `json:"Name"` // The domain name to make a DNS request for
|
||||
RTT time.Duration `json:"RTT"` // The time it took to make the DNS query
|
||||
}
|
||||
|
||||
type Resolver interface {
|
||||
LookUp(*dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
func LoadResolver(server string, opts Options) (Resolver, error) {
|
||||
if opts.HTTPS {
|
||||
opts.Logger.Debug("loading DoH resolver")
|
||||
if !strings.HasPrefix(server, "https://") {
|
||||
server = "https://" + server
|
||||
}
|
||||
return &HTTPSResolver{
|
||||
server: server,
|
||||
opts: opts,
|
||||
}, nil
|
||||
} else if opts.QUIC {
|
||||
opts.Logger.Debug("loading DoQ resolver")
|
||||
server = net.JoinHostPort(opts.Answers.Server, strconv.Itoa(opts.Port))
|
||||
return &QUICResolver{
|
||||
server: server,
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
|
@ -6,26 +6,10 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.froth.zone/sam/awl/logawl"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
Answers Answers `json:"Response"` // These be DNS query answers
|
||||
Logger *logawl.Logger
|
||||
}
|
||||
|
||||
// The Answers struct is the basic structure of a DNS request
|
||||
// to be returned to the user upon making a request
|
||||
type Answers struct {
|
||||
Server string `json:"Server"` // The server to make the DNS request from
|
||||
Request uint16 `json:"Request"` // The type of request
|
||||
Name string `json:"Name"` // The domain name to make a DNS request for
|
||||
RTT time.Duration `json:"RTT"` // The time it took to make the DNS query
|
||||
}
|
||||
|
||||
// Given an IP or phone number, return a canonical string to be queried
|
||||
func ReverseDNS(address string, query string) (string, error) {
|
||||
if query == "PTR" {
|
||||
|
|
Loading…
Reference in a new issue