diff --git a/.editorconfig b/.editorconfig index ebe51d3..9f274b9 100644 --- a/.editorconfig +++ b/.editorconfig @@ -4,9 +4,7 @@ root = true [*] -indent_style = space -indent_size = 2 end_of_line = lf charset = utf-8 -trim_trailing_whitespace = false -insert_final_newline = false \ No newline at end of file +trim_trailing_whitespace = true +insert_final_newline = true diff --git a/cli/cli_test.go b/cli/cli_test.go index e830c2f..4730aa0 100644 --- a/cli/cli_test.go +++ b/cli/cli_test.go @@ -11,51 +11,64 @@ import ( ) func TestEmpty(t *testing.T) { + t.Parallel() + args := []string{"awl", "-4"} opts, err := cli.ParseCLI(args, "TEST") - assert.NilError(t, err) - assert.Equal(t, opts.Request.Port, 53) assert.Assert(t, opts.IPv4) + assert.Equal(t, opts.Request.Port, 53) } func TestTLSPort(t *testing.T) { + t.Parallel() + args := []string{"awl", "-T"} opts, err := cli.ParseCLI(args, "TEST") - assert.NilError(t, err) assert.Equal(t, opts.Request.Port, 853) } -func TestSubnet(t *testing.T) { - args := []string{"awl", "--subnet", "127.0.0.1/32"} +func TestValidSubnet(t *testing.T) { + t.Parallel() - opts, err := cli.ParseCLI(args, "TEST") + tests := []struct { + args []string + want uint16 + }{ + {[]string{"awl", "--subnet", "127.0.0.1/32"}, uint16(1)}, + {[]string{"awl", "--subnet", "0"}, uint16(1)}, + {[]string{"awl", "--subnet", "::/0"}, uint16(2)}, + } - assert.NilError(t, err) - assert.Equal(t, opts.EDNS.Subnet.Family, uint16(1)) + for _, test := range tests { + test := test - args = []string{"awl", "--subnet", "0"} + t.Run(test.args[2], func(t *testing.T) { + t.Parallel() - opts, err = cli.ParseCLI(args, "TEST") - assert.NilError(t, err) - assert.Equal(t, opts.EDNS.Subnet.Family, uint16(1)) + opts, err := cli.ParseCLI(test.args, "TEST") - args = []string{"awl", "--subnet", "::/0"} + assert.NilError(t, err) + assert.Equal(t, opts.EDNS.Subnet.Family, test.want) + }) + } +} - opts, err = cli.ParseCLI(args, "TEST") - assert.NilError(t, err) - assert.Equal(t, opts.EDNS.Subnet.Family, uint16(2)) +func TestInvalidSubnet(t *testing.T) { + t.Parallel() - args = []string{"awl", "--subnet", "/"} + args := []string{"awl", "--subnet", "/"} - opts, err = cli.ParseCLI(args, "TEST") + _, err := cli.ParseCLI(args, "TEST") assert.ErrorContains(t, err, "EDNS subnet") } func TestMBZ(t *testing.T) { + t.Parallel() + args := []string{"awl", "--zflag", "G"} _, err := cli.ParseCLI(args, "TEST") @@ -64,6 +77,8 @@ func TestMBZ(t *testing.T) { } func TestInvalidFlag(t *testing.T) { + t.Parallel() + args := []string{"awl", "--treebug"} _, err := cli.ParseCLI(args, "TEST") @@ -72,6 +87,8 @@ func TestInvalidFlag(t *testing.T) { } func TestInvalidDig(t *testing.T) { + t.Parallel() + args := []string{"awl", "+a"} _, err := cli.ParseCLI(args, "TEST") @@ -80,6 +97,8 @@ func TestInvalidDig(t *testing.T) { } func TestVersion(t *testing.T) { + t.Parallel() + args := []string{"awl", "--version"} _, err := cli.ParseCLI(args, "test") @@ -88,6 +107,8 @@ func TestVersion(t *testing.T) { } func TestTimeout(t *testing.T) { + t.Parallel() + args := [][]string{ {"awl", "+timeout=0"}, {"awl", "--timeout", "0"}, @@ -95,14 +116,20 @@ func TestTimeout(t *testing.T) { for _, test := range args { test := test - opt, err := cli.ParseCLI(test, "TEST") + t.Run(test[1], func(t *testing.T) { + t.Parallel() - assert.NilError(t, err) - assert.Equal(t, opt.Request.Timeout, time.Second/2) + opt, err := cli.ParseCLI(test, "TEST") + + assert.NilError(t, err) + assert.Equal(t, opt.Request.Timeout, time.Second/2) + }) } } func TestRetries(t *testing.T) { + t.Parallel() + args := [][]string{ {"awl", "+retry=-2"}, {"awl", "+tries=-2"}, @@ -111,10 +138,14 @@ func TestRetries(t *testing.T) { for _, test := range args { test := test - opt, err := cli.ParseCLI(test, "TEST") + t.Run(test[1], func(t *testing.T) { + t.Parallel() - assert.NilError(t, err) - assert.Equal(t, opt.Request.Retries, 0) + opt, err := cli.ParseCLI(test, "TEST") + + assert.NilError(t, err) + assert.Equal(t, opt.Request.Retries, 0) + }) } } diff --git a/cli/misc.go b/cli/misc.go index 7a14854..b181226 100644 --- a/cli/misc.go +++ b/cli/misc.go @@ -45,7 +45,7 @@ func ParseMiscArgs(args []string, opts *util.Options) error { opts.Logger.Info("DNSCrypt implicitly set") case strings.HasPrefix(arg, "tcp://"): opts.TCP = true - opts.Request.Server = strings.TrimPrefix(arg, "udp://") + opts.Request.Server = strings.TrimPrefix(arg, "tcp://") opts.Logger.Info("TCP implicitly set") case strings.HasPrefix(arg, "udp://"): opts.Request.Server = strings.TrimPrefix(arg, "udp://") diff --git a/cli/misc_test.go b/cli/misc_test.go index c3f60c3..e6ac236 100644 --- a/cli/misc_test.go +++ b/cli/misc_test.go @@ -3,7 +3,6 @@ package cli_test import ( - "strconv" "testing" "git.froth.zone/sam/awl/cli" @@ -87,7 +86,7 @@ func TestDefaultServer(t *testing.T) { in string want string }{ - {"DNSCRYPT", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"}, + {"DNSCrypt", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"}, {"TLS", "dns.google"}, {"HTTPS", "https://dns.cloudflare.com/dns-query"}, {"QUIC", "dns.adguard.com"}, @@ -95,13 +94,14 @@ func TestDefaultServer(t *testing.T) { for _, test := range tests { test := test + t.Run(test.in, func(t *testing.T) { t.Parallel() args := []string{} opts := new(util.Options) opts.Logger = util.InitLogger(0) switch test.in { - case "DNSCRYPT": + case "DNSCrypt": opts.DNSCrypt = true case "TLS": opts.TLS = true @@ -121,38 +121,48 @@ func TestFlagSetting(t *testing.T) { t.Parallel() tests := []struct { - in []string + in string + expected string + over string }{ - {[]string{"@sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"}}, - {[]string{"@tls://dns.google"}}, - {[]string{"@https://dns.cloudflare.com/dns-query"}}, - {[]string{"@quic://dns.adguard.com"}}, - {[]string{"@tcp://dns.froth.zone"}}, - {[]string{"@udp://dns.example.com"}}, + {"@sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "DNSCrypt"}, + {"@tls://dns.google", "dns.google", "TLS"}, + {"@https://dns.cloudflare.com/dns-query", "https://dns.cloudflare.com/dns-query", "HTTPS"}, + {"@quic://dns.adguard.com", "dns.adguard.com", "QUIC"}, + {"@tcp://dns.froth.zone", "dns.froth.zone", "TCP"}, + {"@udp://dns.example.com", "dns.example.com", "UDP"}, } - for i, test := range tests { + for _, test := range tests { test := test - i := i - t.Run(strconv.Itoa(i), func(t *testing.T) { + + t.Run(test.over, func(t *testing.T) { + t.Parallel() + opts := new(util.Options) opts.Logger = util.InitLogger(0) - t.Parallel() - err := cli.ParseMiscArgs(test.in, opts) + + err := cli.ParseMiscArgs([]string{test.in}, opts) assert.NilError(t, err) - switch i { - case 0: + switch test.over { + case "DNSCrypt": assert.Assert(t, opts.DNSCrypt) - case 1: + assert.Equal(t, opts.Request.Server, test.expected) + case "TLS": assert.Assert(t, opts.TLS) - case 2: + assert.Equal(t, opts.Request.Server, test.expected) + case "HTTPS": assert.Assert(t, opts.HTTPS) - case 3: + assert.Equal(t, opts.Request.Server, test.expected) + case "QUIC": assert.Assert(t, opts.QUIC) - case 4: + assert.Equal(t, opts.Request.Server, test.expected) + case "TCP": assert.Assert(t, opts.TCP) - case 5: + assert.Equal(t, opts.Request.Server, test.expected) + case "UDP": assert.Assert(t, true) + assert.Equal(t, opts.Request.Server, test.expected) } }) } diff --git a/conf/plan9_test.go b/conf/plan9_test.go index 4cf2410..dbf01f9 100644 --- a/conf/plan9_test.go +++ b/conf/plan9_test.go @@ -4,53 +4,22 @@ package conf_test import ( + "runtime" "testing" "git.froth.zone/sam/awl/conf" "gotest.tools/v3/assert" ) -func TestGetPlan9Config(t *testing.T) { +func TestPlan9Config(t *testing.T) { t.Parallel() + if runtime.GOOS != "plan9" { t.Skip("Not running Plan 9, skipping") } - ndbs := []struct { - in string - want string - }{ - {`ip=192.168.122.45 ipmask=255.255.255.0 ipgw=192.168.122.1 - sys=chog9 - dns=192.168.122.1`, "192.168.122.1"}, - {`ipnet=murray-hill ip=135.104.0.0 ipmask=255.255.0.0 - dns=135.104.10.1 - ntp=ntp.cs.bell-labs.com - ipnet=plan9 ip=135.104.9.0 ipmask=255.255.255.0 - ntp=oncore.cs.bell-labs.com - smtp=smtp1.cs.bell-labs.com - ip=135.104.9.6 sys=anna dom=anna.cs.bell-labs.com - smtp=smtp2.cs.bell-labs.com`, "135.104.10.1"}, - } + conf, err := conf.GetDNSConfig() - for _, ndb := range ndbs { - // Go is a little quirky - ndb := ndb - t.Run(ndb.want, func(t *testing.T) { - t.Parallel() - act, err := conf.GetPlan9Config(ndb.in) - assert.NilError(t, err) - assert.Equal(t, ndb.want, act.Servers[0]) - }) - } - - invalid := `sys = spindle - dom=spindle.research.bell-labs.com - bootf=/mips/9powerboot - ip=135.104.117.32 ether=080069020677 - proto=il` - - act, err := conf.GetPlan9Config(invalid) - assert.ErrorContains(t, err, "no DNS servers found") - assert.Assert(t, act == nil) + assert.NilError(t, err) + assert.Assert(t, len(conf.Servers) != 0) } diff --git a/conf/unix_test.go b/conf/unix_test.go index 9f50b6d..41643a1 100644 --- a/conf/unix_test.go +++ b/conf/unix_test.go @@ -13,8 +13,10 @@ import ( "gotest.tools/v3/assert" ) -func TestNonWinConfig(t *testing.T) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { +func TestUnixConfig(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "js" || runtime.GOOS == "zos" { t.Skip("Not running Unix-like, skipping") } diff --git a/main_test.go b/main_test.go index 57186c9..58f8e07 100644 --- a/main_test.go +++ b/main_test.go @@ -3,31 +3,35 @@ package main import ( - "os" "testing" "github.com/stefansundin/go-zflag" "gotest.tools/v3/assert" ) -func TestMain(t *testing.T) { //nolint: paralleltest // Race conditions - os.Stdout = os.NewFile(0, os.DevNull) - os.Stderr = os.NewFile(0, os.DevNull) +func TestRun(t *testing.T) { + t.Parallel() - args := []string{"awl", "+yaml", "@1.1.1.1"} + args := [][]string{ + {"awl", "+yaml", "@1.1.1.1"}, + {"awl", "+short", "@1.1.1.1"}, + } - _, code, err := run(args) - assert.NilError(t, err) - assert.Equal(t, code, 0) + for _, test := range args { + test := test - args = []string{"awl", "+short", "@1.1.1.1"} - - _, code, err = run(args) - assert.NilError(t, err) - assert.Equal(t, code, 0) + t.Run("", func(t *testing.T) { + t.Parallel() + _, code, err := run(test) + assert.NilError(t, err) + assert.Equal(t, code, 0) + }) + } } func TestHelp(t *testing.T) { + t.Parallel() + args := []string{"awl", "-h"} _, code, err := run(args) diff --git a/query/DNSCrypt_test.go b/query/DNSCrypt_test.go index 37df392..a1e7294 100644 --- a/query/DNSCrypt_test.go +++ b/query/DNSCrypt_test.go @@ -15,42 +15,49 @@ func TestDNSCrypt(t *testing.T) { t.Parallel() tests := []struct { - opt util.Options + name string + opts util.Options }{ { + "Valid", util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, Request: util.Request{ - Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - Type: dns.TypeA, - Name: "example.com.", + Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, }, }, }, { + "Valid (TCP)", util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, TCP: true, IPv4: true, Request: util.Request{ - Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 3, }, }, }, { + "Invalid", util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, TCP: true, IPv4: true, Request: util.Request{ - Server: "QMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "QMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 0, }, }, }, @@ -59,10 +66,10 @@ func TestDNSCrypt(t *testing.T) { for _, test := range tests { test := test - t.Run("", func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { t.Parallel() - res, err := query.CreateQuery(test.opt) + res, err := query.CreateQuery(test.opts) if err == nil { assert.Assert(t, res != util.Response{}) } else { diff --git a/query/HTTPS.go b/query/HTTPS.go index 0c56002..4b78bee 100644 --- a/query/HTTPS.go +++ b/query/HTTPS.go @@ -52,7 +52,7 @@ func (r *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { } if res.StatusCode != http.StatusOK { - return util.Response{}, &errHTTPStatus{res.StatusCode} + return util.Response{}, &ErrHTTPStatus{res.StatusCode} } r.opts.Logger.Debug("https: reading response") @@ -79,10 +79,11 @@ func (r *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { return resp, nil } -type errHTTPStatus struct { +// ErrHTTPStatus is returned when DoH returns a bad status code. +type ErrHTTPStatus struct { code int } -func (e *errHTTPStatus) Error() string { +func (e *ErrHTTPStatus) Error() string { return fmt.Sprintf("doh server responded with HTTP %d", e.code) } diff --git a/query/HTTPS_test.go b/query/HTTPS_test.go index b844847..ce9473a 100644 --- a/query/HTTPS_test.go +++ b/query/HTTPS_test.go @@ -3,6 +3,7 @@ package query_test import ( + "errors" "testing" "git.froth.zone/sam/awl/query" @@ -11,105 +12,87 @@ import ( "gotest.tools/v3/assert" ) -func TestResolveHTTPS(t *testing.T) { +func TestHTTPS(t *testing.T) { t.Parallel() - var err error - - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "https://dns9.quad9.net/dns-query", - Type: dns.TypeA, - Name: "git.froth.zone.", + tests := []struct { + name string + opts util.Options + }{ + { + "Good", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "https://dns9.quad9.net/dns-query", + Type: dns.TypeA, + Name: "git.froth.zone.", + Retries: 3, + }, + }, + }, + { + "404", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "https://dns9.quad9.net/dns", + Type: dns.TypeA, + Name: "git.froth.zone.", + }, + }, + }, + { + "Bad request domain", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns9.quad9.net/dns-query", + Type: dns.TypeA, + Name: "git.froth.zone", + }, + }, + }, + { + "Bad server domain", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns9..quad9.net/dns-query", + Type: dns.TypeA, + Name: "git.froth.zone.", + }, + }, }, } - // testCase := util.Request{Server: "https://dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."} - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) -} + for _, test := range tests { + test := test -func Test2ResolveHTTPS(t *testing.T) { - t.Parallel() + t.Run(test.name, func(t *testing.T) { + t.Parallel() - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."}, + resolver, err := query.LoadResolver(test.opts) + assert.NilError(t, err) + + msg := new(dns.Msg) + msg.SetQuestion(test.opts.Request.Name, test.opts.Request.Type) + // msg = msg.SetQuestion(testCase.Name, testCase.Type) + res, err := resolver.LookUp(msg) + + if err == nil { + assert.NilError(t, err) + assert.Assert(t, res != util.Response{}) + } else { + if errors.Is(err, &query.ErrHTTPStatus{}) { + assert.ErrorContains(t, err, "404") + } + assert.Equal(t, res, util.Response{}) + } + }) } - - var err error - - testCase := util.Request{Type: dns.TypeA, Name: "git.froth.zone"} - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(testCase.Name, testCase.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "fully qualified") - assert.Equal(t, res, util.Response{}) -} - -func Test3ResolveHTTPS(t *testing.T) { - t.Parallel() - - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "dns9..quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."}, - } - - var err error - - // testCase := - // if the domain is not canonical, make it canonical - // if !strings.HasSuffix(testCase.Name, ".") { - // testCase.Name = fmt.Sprintf("%s.", testCase.Name) - // } - - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - // msg.SetQuestion(testCase.Name, testCase.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "doh: HTTP request") - assert.Equal(t, res, util.Response{}) -} - -func Test404ResolveHTTPS(t *testing.T) { - t.Parallel() - - var err error - - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "https://dns9.quad9.net/dns", - Type: dns.TypeA, - Name: "git.froth.zone.", - }, - } - // testCase := util.Request{Server: "https://dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."} - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "404") - assert.Equal(t, res, util.Response{}) } diff --git a/query/QUIC_test.go b/query/QUIC_test.go index 6d83199..c30ead3 100644 --- a/query/QUIC_test.go +++ b/query/QUIC_test.go @@ -3,10 +3,6 @@ package query_test import ( - "fmt" - "net" - "strconv" - "strings" "testing" "time" @@ -19,72 +15,90 @@ import ( func TestQuic(t *testing.T) { t.Parallel() - opts := util.Options{ - QUIC: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "dns.adguard.com", Port: 853}, + tests := []struct { + name string + opts util.Options + }{ + { + "Valid", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns.adguard.com", + Type: dns.TypeNS, + Port: 853, + Timeout: 750 * time.Millisecond, + Retries: 3, + }, + }, + }, + { + "Bad domain", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns.//./,,adguard\a.com", + Port: 853, + Type: dns.TypeA, + Name: "git.froth.zone", + Timeout: 100 * time.Millisecond, + Retries: 0, + }, + }, + }, + { + "Not canonical", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns.adguard.com", + Port: 853, + Type: dns.TypeA, + Name: "git.froth.zone", + Timeout: 100 * time.Millisecond, + Retries: 0, + }, + }, + }, + { + "Invalid query domain", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "example.com", + Port: 853, + Type: dns.TypeA, + Name: "git.froth.zone", + Timeout: 10 * time.Millisecond, + }, + }, + }, } - testCase := util.Request{Server: "dns.//./,,adguard.com", Type: dns.TypeA, Name: "git.froth.zone"} - testCase2 := util.Request{Server: "dns.adguard.com", Type: dns.TypeA, Name: "git.froth.zone"} - var testCases []util.Request + for _, test := range tests { + test := test - testCases = append(testCases, testCase) - testCases = append(testCases, testCase2) + t.Run(test.name, func(t *testing.T) { + t.Parallel() - for i := range testCases { - switch i { - case 0: - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - // if the domain is not canonical, make it canonical - if !strings.HasSuffix(testCase.Name, ".") { - testCases[i].Name = fmt.Sprintf("%s.", testCases[i].Name) - } - - msg := new(dns.Msg) - msg.SetQuestion(testCase.Name, testCase.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - - assert.ErrorContains(t, err, "fully qualified") - assert.Equal(t, res, util.Response{}) - case 1: - resolver, err := query.LoadResolver(opts) + resolver, err := query.LoadResolver(test.opts) assert.NilError(t, err) - testCase2.Server = net.JoinHostPort(testCase2.Server, strconv.Itoa(opts.Request.Port)) - - // if the domain is not canonical, make it canonical - if !strings.HasSuffix(testCase2.Name, ".") { - testCase2.Name = fmt.Sprintf("%s.", testCase2.Name) - } - msg := new(dns.Msg) - msg.SetQuestion(testCase2.Name, testCase2.Type) + msg.SetQuestion(test.opts.Request.Name, test.opts.Request.Type) res, err := resolver.LookUp(msg) - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) - } + if err == nil { + assert.NilError(t, err) + assert.Assert(t, res != util.Response{}) + } else { + assert.Assert(t, res == util.Response{}) + } + }) } } - -func TestInvalidQuic(t *testing.T) { - t.Parallel() - - opts := util.Options{ - QUIC: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "example.com", Port: 853, Type: dns.TypeA, Name: "git.froth.zone", Timeout: 10 * time.Millisecond}, - } - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "timeout") - assert.Equal(t, res, util.Response{}) -} diff --git a/query/general_test.go b/query/general_test.go index 396e87b..5068be1 100644 --- a/query/general_test.go +++ b/query/general_test.go @@ -3,6 +3,7 @@ package query_test import ( + "os" "testing" "time" @@ -15,90 +16,95 @@ import ( func TestResolve(t *testing.T) { t.Parallel() - opts := util.Options{ - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "8.8.4.1", - Port: 1, - Type: dns.TypeA, - Name: "example.com.", - Timeout: time.Second / 2, - Retries: 0, - }, - } - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - - _, err = resolver.LookUp(msg) - assert.ErrorContains(t, err, "timeout") -} - -func TestTruncate(t *testing.T) { - t.Parallel() - - opts := util.Options{ - Logger: util.InitLogger(0), - IPv4: true, - Request: util.Request{ - Server: "madns.binarystar.systems", - Port: 5301, - Type: dns.TypeTXT, - Name: "limit.txt.example.", - }, - } - resolver, err := query.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - res, err := resolver.LookUp(msg) - - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) -} - -func TestResolveAgain(t *testing.T) { - t.Parallel() - tests := []struct { - opt util.Options + name string + opts util.Options }{ { + "UDP", + util.Options{ + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 3, + }, + }, + }, + { + "UDP (Bad Cookie)", + util.Options{ + Logger: util.InitLogger(0), + BadCookie: false, + Request: util.Request{ + Server: "b.root-servers.net", + Port: 53, + Type: dns.TypeNS, + Name: "example.com.", + Retries: 3, + }, + EDNS: util.EDNS{ + EnableEDNS: true, + Cookie: true, + }, + }, + }, + { + "UDP (Truncated)", + util.Options{ + Logger: util.InitLogger(0), + IPv4: true, + Request: util.Request{ + Server: "madns.binarystar.systems", + Port: 5301, + Type: dns.TypeTXT, + Name: "limit.txt.example.", + Retries: 3, + }, + }, + }, + { + "TCP", util.Options{ Logger: util.InitLogger(0), TCP: true, Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeA, - Name: "example.com.", - }, - }, - }, - { - util.Options{ - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, }, }, }, { + "TLS", util.Options{ Logger: util.InitLogger(0), TLS: true, Request: util.Request{ - Server: "dns.google", - Port: 853, - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "dns.google", + Port: 853, + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 3, + }, + }, + }, + { + "Timeout", + util.Options{ + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "8.8.4.1", + Port: 1, + Type: dns.TypeA, + Name: "example.com.", + Timeout: time.Millisecond * 100, + Retries: 0, }, }, }, @@ -107,11 +113,16 @@ func TestResolveAgain(t *testing.T) { for _, test := range tests { test := test - t.Run("", func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { t.Parallel() - res, err := query.CreateQuery(test.opt) - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) + + res, err := query.CreateQuery(test.opts) + if err == nil { + assert.NilError(t, err) + assert.Assert(t, res != util.Response{}) + } else { + assert.ErrorIs(t, err, os.ErrDeadlineExceeded) + } }) } } diff --git a/query/print_test.go b/query/print_test.go index 3f88e3c..1dd0226 100644 --- a/query/print_test.go +++ b/query/print_test.go @@ -77,8 +77,7 @@ func TestRealPrint(t *testing.T) { Type: dns.StringToType["NS"], Class: 1, Name: "google.com.", - Timeout: 0, - Retries: 0, + Retries: 3, }, EDNS: util.EDNS{ EnableEDNS: false, @@ -99,7 +98,7 @@ func TestRealPrint(t *testing.T) { Authority: true, Additional: true, Statistics: true, - UcodeTranslate: false, + UcodeTranslate: true, TTL: true, HumanTTL: true, ShowQuery: true, @@ -110,8 +109,7 @@ func TestRealPrint(t *testing.T) { Type: dns.StringToType["NS"], Class: 1, Name: "freecumextremist.com.", - Timeout: 0, - Retries: 0, + Retries: 3, }, EDNS: util.EDNS{ EnableEDNS: false, @@ -172,8 +170,7 @@ func TestRealPrint(t *testing.T) { Type: dns.StringToType["A"], Class: 1, Name: "froth.zone.", - Timeout: 0, - Retries: 0, + Retries: 3, }, EDNS: util.EDNS{ EnableEDNS: true, diff --git a/query/query_test.go b/query/query_test.go index 5a5183f..51cf527 100644 --- a/query/query_test.go +++ b/query/query_test.go @@ -14,115 +14,128 @@ import ( func TestCreateQ(t *testing.T) { t.Parallel() - in := []util.Options{ + tests := []struct { + name string + opts util.Options + }{ { - Logger: util.InitLogger(0), - HeaderFlags: util.HeaderFlags{ - Z: true, - }, - - YAML: true, - - Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeA, - Name: "example.com.", - }, - Display: util.Display{ - Comments: true, - Question: true, - Opt: true, - Answer: true, - Authority: true, - Additional: true, - Statistics: true, - ShowQuery: true, - }, - EDNS: util.EDNS{ - ZFlag: 1, - BufSize: 1500, - EnableEDNS: true, - Cookie: true, - DNSSEC: true, - Expire: true, - KeepOpen: true, - Nsid: true, - Padding: true, - Version: 0, + "1", + util.Options{ + Logger: util.InitLogger(0), + HeaderFlags: util.HeaderFlags{ + Z: true, + }, + YAML: true, + Request: util.Request{ + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, + }, + Display: util.Display{ + Comments: true, + Question: true, + Opt: true, + Answer: true, + Authority: true, + Additional: true, + Statistics: true, + ShowQuery: true, + }, + EDNS: util.EDNS{ + ZFlag: 1, + BufSize: 1500, + EnableEDNS: true, + Cookie: true, + DNSSEC: true, + Expire: true, + KeepOpen: true, + Nsid: true, + Padding: true, + Version: 0, + }, }, }, { - Logger: util.InitLogger(0), - HeaderFlags: util.HeaderFlags{ - Z: true, - }, - XML: true, + "2", + util.Options{ + Logger: util.InitLogger(0), + HeaderFlags: util.HeaderFlags{ + Z: true, + }, + XML: true, - Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeA, - Name: "example.com.", - }, - Display: util.Display{ - Comments: true, - Question: true, - Opt: true, - Answer: true, - Authority: true, - Additional: true, - Statistics: true, - UcodeTranslate: true, - ShowQuery: true, + Request: util.Request{ + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, + }, + Display: util.Display{ + Comments: true, + Question: true, + Opt: true, + Answer: true, + Authority: true, + Additional: true, + Statistics: true, + UcodeTranslate: true, + ShowQuery: true, + }, }, }, { - Logger: util.InitLogger(0), - JSON: true, - QUIC: true, + "3", + util.Options{ + Logger: util.InitLogger(0), + JSON: true, + QUIC: true, - Request: util.Request{ - Server: "dns.adguard.com", - Port: 853, - Type: dns.TypeA, - Name: "example.com.", - }, - Display: util.Display{ - Comments: true, - Question: true, - Opt: true, - Answer: true, - Authority: true, - Additional: true, - Statistics: true, - ShowQuery: true, - }, - EDNS: util.EDNS{ - EnableEDNS: true, - DNSSEC: true, - Cookie: true, - Expire: true, - Nsid: true, + Request: util.Request{ + Server: "dns.adguard.com", + Port: 853, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, + }, + Display: util.Display{ + Comments: true, + Question: true, + Opt: true, + Answer: true, + Authority: true, + Additional: true, + Statistics: true, + ShowQuery: true, + }, + EDNS: util.EDNS{ + EnableEDNS: true, + DNSSEC: true, + Cookie: true, + Expire: true, + Nsid: true, + }, }, }, } - for _, opt := range in { - opt := opt + for _, test := range tests { + test := test - t.Run("", func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { t.Parallel() - res, err := query.CreateQuery(opt) + res, err := query.CreateQuery(test.opts) assert.NilError(t, err) assert.Assert(t, res != util.Response{}) - str, err := query.PrintSpecial(res, opt) + str, err := query.PrintSpecial(res, test.opts) assert.NilError(t, err) assert.Assert(t, str != "") - str, err = query.ToString(res, opt) + str, err = query.ToString(res, test.opts) assert.NilError(t, err) assert.Assert(t, str != "") }) diff --git a/util/options_test.go b/util/options_test.go index 741e7c2..6ec6a0c 100644 --- a/util/options_test.go +++ b/util/options_test.go @@ -17,21 +17,20 @@ func TestSubnet(t *testing.T) { "::0/0", "0", "127.0.0.1/32", + "Invalid", } for _, test := range subnet { test := test + t.Run(test, func(t *testing.T) { t.Parallel() err := util.ParseSubnet(test, new(util.Options)) - assert.NilError(t, err) + if err != nil { + assert.ErrorContains(t, err, "invalid CIDR address") + } else { + assert.NilError(t, err) + } }) } } - -func TestInvalidSub(t *testing.T) { - t.Parallel() - - err := util.ParseSubnet("1", new(util.Options)) - assert.ErrorContains(t, err, "invalid CIDR address") -} diff --git a/util/reverseDNS_test.go b/util/reverseDNS_test.go index 5a77748..e4c1a33 100644 --- a/util/reverseDNS_test.go +++ b/util/reverseDNS_test.go @@ -10,25 +10,43 @@ import ( "gotest.tools/v3/assert" ) -var ( - PTR = dns.StringToType["PTR"] - NAPTR = dns.StringToType["NAPTR"] -) - -func TestIPv4(t *testing.T) { +func TestPTR(t *testing.T) { t.Parallel() - act, err := util.ReverseDNS("8.8.4.4", PTR) - assert.NilError(t, err) - assert.Equal(t, act, "4.4.8.8.in-addr.arpa.", "IPv4 reverse") -} + tests := []struct { + name string + in string + expected string + }{ + { + "IPv4", + "8.8.4.4", "4.4.8.8.in-addr.arpa.", + }, + { + "IPv6", + "2606:4700:4700::1111", "1.1.1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.4.0.0.7.4.6.0.6.2.ip6.arpa.", + }, + { + "Inavlid value", + "AAAAA", "", + }, + } -func TestIPv6(t *testing.T) { - t.Parallel() + for _, test := range tests { + test := test - act, err := util.ReverseDNS("2606:4700:4700::1111", PTR) - assert.NilError(t, err) - assert.Equal(t, act, "1.1.1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.4.0.0.7.4.6.0.6.2.ip6.arpa.", "IPv6 reverse") + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + act, err := util.ReverseDNS(test.in, dns.StringToType["PTR"]) + if err == nil { + assert.NilError(t, err) + } else { + assert.ErrorContains(t, err, "unrecognized address") + } + assert.Equal(t, act, test.expected) + }) + } } func TestNAPTR(t *testing.T) { @@ -48,23 +66,14 @@ func TestNAPTR(t *testing.T) { test := test t.Run(test.in, func(t *testing.T) { t.Parallel() - act, err := util.ReverseDNS(test.in, NAPTR) + act, err := util.ReverseDNS(test.in, dns.StringToType["NAPTR"]) assert.NilError(t, err) assert.Equal(t, test.want, act) }) } } -func TestInvalid(t *testing.T) { - t.Parallel() - - _, err := util.ReverseDNS("AAAAA", 1) - assert.ErrorContains(t, err, "invalid value AAAAA given") -} - -func TestInvalid2(t *testing.T) { - t.Parallel() - - _, err := util.ReverseDNS("1.0", PTR) - assert.ErrorContains(t, err, "PTR reverse") +func TestInvalidAll(t *testing.T) { + _, err := util.ReverseDNS("q", 15236) + assert.ErrorContains(t, err, "invalid value") }