diff --git a/.editorconfig b/.editorconfig index ebe51d3..9f274b9 100644 --- a/.editorconfig +++ b/.editorconfig @@ -4,9 +4,7 @@ root = true [*] -indent_style = space -indent_size = 2 end_of_line = lf charset = utf-8 -trim_trailing_whitespace = false -insert_final_newline = false \ No newline at end of file +trim_trailing_whitespace = true +insert_final_newline = true diff --git a/cli/cli.go b/cli/cli.go index 0833b6c..14a1161 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -18,9 +18,9 @@ import ( // ParseCLI parses arguments given from the CLI and passes them into an `Options` // struct. func ParseCLI(args []string, version string) (util.Options, error) { - flag.CommandLine = flag.NewFlagSet(args[0], flag.ContinueOnError) + flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError) - flag.Usage = func() { + flagSet.Usage = func() { fmt.Println(`awl - drill, writ small Usage: awl name [@server] [record] @@ -31,80 +31,80 @@ func ParseCLI(args []string, version string) (util.Options, error) { Dig-like +[no]commands are also supported, see dig(1) or dig -h Options:`) - flag.PrintDefaults() + flagSet.PrintDefaults() } // CLI flags // // Remember, when adding a flag edit the manpage and the completions :) var ( - port = flag.Int("port", 0, "`port` to make DNS query (default: 53 for UDP/TCP, 853 for TLS/QUIC)", flag.OptShorthand('p'), flag.OptDisablePrintDefault(true)) - query = flag.String("query", "", "domain name to `query` (default: .)", flag.OptShorthand('q')) - class = flag.String("class", "IN", "DNS `class` to query", flag.OptShorthand('c')) - qType = flag.String("qType", "", "`type` to query (default: A)", flag.OptShorthand('t')) + port = flagSet.Int("port", 0, "`port` to make DNS query (default: 53 for UDP/TCP, 853 for TLS/QUIC)", flag.OptShorthand('p'), flag.OptDisablePrintDefault(true)) + query = flagSet.String("query", "", "domain name to `query` (default: .)", flag.OptShorthand('q')) + class = flagSet.String("class", "IN", "DNS `class` to query", flag.OptShorthand('c')) + qType = flagSet.String("qType", "", "`type` to query (default: A)", flag.OptShorthand('t')) - ipv4 = flag.Bool("4", false, "force IPv4", flag.OptShorthand('4')) - ipv6 = flag.Bool("6", false, "force IPv6", flag.OptShorthand('6')) - reverse = flag.Bool("reverse", false, "do a reverse lookup", flag.OptShorthand('x')) + ipv4 = flagSet.Bool("4", false, "force IPv4", flag.OptShorthand('4')) + ipv6 = flagSet.Bool("6", false, "force IPv6", flag.OptShorthand('6')) + reverse = flagSet.Bool("reverse", false, "do a reverse lookup", flag.OptShorthand('x')) - timeout = flag.Float32("timeout", 1, "Timeout, in `seconds`") - retry = flag.Int("retries", 2, "number of `times` to retry") + timeout = flagSet.Float32("timeout", 1, "Timeout, in `seconds`") + retry = flagSet.Int("retries", 2, "number of `times` to retry") - edns = flag.Bool("no-edns", false, "disable EDNS entirely") - ednsVer = flag.Uint8("edns-ver", 0, "set EDNS version") - dnssec = flag.Bool("dnssec", false, "enable DNSSEC", flag.OptShorthand('D')) - expire = flag.Bool("expire", false, "set EDNS expire") - nsid = flag.Bool("nsid", false, "set EDNS NSID", flag.OptShorthand('n')) - cookie = flag.Bool("no-cookie", false, "disable sending EDNS cookie (default: cookie sent)") - tcpKeepAlive = flag.Bool("keep-alive", false, "send EDNS TCP keep-alive") - udpBufSize = flag.Uint16("buffer-size", 1232, "set EDNS UDP buffer size", flag.OptShorthand('b')) - mbzflag = flag.String("zflag", "0", "set EDNS z-flag `value`") - subnet = flag.String("subnet", "", "set EDNS client subnet") - padding = flag.Bool("pad", false, "set EDNS padding") + edns = flagSet.Bool("no-edns", false, "disable EDNS entirely") + ednsVer = flagSet.Uint8("edns-ver", 0, "set EDNS version") + dnssec = flagSet.Bool("dnssec", false, "enable DNSSEC", flag.OptShorthand('D')) + expire = flagSet.Bool("expire", false, "set EDNS expire") + nsid = flagSet.Bool("nsid", false, "set EDNS NSID", flag.OptShorthand('n')) + cookie = flagSet.Bool("no-cookie", false, "disable sending EDNS cookie (default: cookie sent)") + tcpKeepAlive = flagSet.Bool("keep-alive", false, "send EDNS TCP keep-alive") + udpBufSize = flagSet.Uint16("buffer-size", 1232, "set EDNS UDP buffer size", flag.OptShorthand('b')) + mbzflag = flagSet.String("zflag", "0", "set EDNS z-flag `value`") + subnet = flagSet.String("subnet", "", "set EDNS client subnet") + padding = flagSet.Bool("pad", false, "set EDNS padding") - badCookie = flag.Bool("no-bad-cookie", false, "ignore BADCOOKIE EDNS responses (default: retry with correct cookie") - truncate = flag.Bool("no-truncate", false, "ignore truncation if a UDP request truncates (default: retry with TCP)") + badCookie = flagSet.Bool("no-bad-cookie", false, "ignore BADCOOKIE EDNS responses (default: retry with correct cookie") + truncate = flagSet.Bool("no-truncate", false, "ignore truncation if a UDP request truncates (default: retry with TCP)") - tcp = flag.Bool("tcp", false, "use TCP") - dnscrypt = flag.Bool("dnscrypt", false, "use DNSCrypt") - tls = flag.Bool("tls", false, "use DNS-over-TLS", flag.OptShorthand('T')) - https = flag.Bool("https", false, "use DNS-over-HTTPS", flag.OptShorthand('H')) - quic = flag.Bool("quic", false, "use DNS-over-QUIC", flag.OptShorthand('Q')) + tcp = flagSet.Bool("tcp", false, "use TCP") + dnscrypt = flagSet.Bool("dnscrypt", false, "use DNSCrypt") + tls = flagSet.Bool("tls", false, "use DNS-over-TLS", flag.OptShorthand('T')) + https = flagSet.Bool("https", false, "use DNS-over-HTTPS", flag.OptShorthand('H')) + quic = flagSet.Bool("quic", false, "use DNS-over-QUIC", flag.OptShorthand('Q')) - tlsHost = flag.String("tls-host", "", "Server name to use for TLS verification") - noVerify = flag.Bool("tls-no-verify", false, "Disable TLS cert verification") + tlsHost = flagSet.String("tls-host", "", "Server name to use for TLS verification") + noVerify = flagSet.Bool("tls-no-verify", false, "Disable TLS cert verification") - aaflag = flag.Bool("aa", false, "set/unset AA (Authoratative Answer) flag (default: not set)") - adflag = flag.Bool("ad", false, "set/unset AD (Authenticated Data) flag (default: not set)") - cdflag = flag.Bool("cd", false, "set/unset CD (Checking Disabled) flag (default: not set)") - qrflag = flag.Bool("qr", false, "set/unset QR (QueRy) flag (default: not set)") - rdflag = flag.Bool("rd", true, "set/unset RD (Recursion Desired) flag (default: set)", flag.OptDisablePrintDefault(true)) - raflag = flag.Bool("ra", false, "set/unset RA (Recursion Available) flag (default: not set)") - tcflag = flag.Bool("tc", false, "set/unset TC (TrunCated) flag (default: not set)") - zflag = flag.Bool("z", false, "set/unset Z (Zero) flag (default: not set)", flag.OptShorthand('z')) + aaflag = flagSet.Bool("aa", false, "set/unset AA (Authoratative Answer) flag (default: not set)") + adflag = flagSet.Bool("ad", false, "set/unset AD (Authenticated Data) flag (default: not set)") + cdflag = flagSet.Bool("cd", false, "set/unset CD (Checking Disabled) flag (default: not set)") + qrflag = flagSet.Bool("qr", false, "set/unset QR (QueRy) flag (default: not set)") + rdflag = flagSet.Bool("rd", true, "set/unset RD (Recursion Desired) flag (default: set)", flag.OptDisablePrintDefault(true)) + raflag = flagSet.Bool("ra", false, "set/unset RA (Recursion Available) flag (default: not set)") + tcflag = flagSet.Bool("tc", false, "set/unset TC (TrunCated) flag (default: not set)") + zflag = flagSet.Bool("z", false, "set/unset Z (Zero) flag (default: not set)", flag.OptShorthand('z')) - short = flag.Bool("short", false, "print just the results", flag.OptShorthand('s')) - json = flag.Bool("json", false, "print the result(s) as JSON", flag.OptShorthand('j')) - xml = flag.Bool("xml", false, "print the result(s) as XML", flag.OptShorthand('X')) - yaml = flag.Bool("yaml", false, "print the result(s) as yaml", flag.OptShorthand('y')) + short = flagSet.Bool("short", false, "print just the results", flag.OptShorthand('s')) + json = flagSet.Bool("json", false, "print the result(s) as JSON", flag.OptShorthand('j')) + xml = flagSet.Bool("xml", false, "print the result(s) as XML", flag.OptShorthand('X')) + yaml = flagSet.Bool("yaml", false, "print the result(s) as yaml", flag.OptShorthand('y')) - noC = flag.Bool("no-comments", false, "disable printing the comments") - noQ = flag.Bool("no-question", false, "disable printing the question section") - noOpt = flag.Bool("no-opt", false, "disable printing the OPT pseudosection") - noAns = flag.Bool("no-answer", false, "disable printing the answer section") - noAuth = flag.Bool("no-authority", false, "disable printing the authority section") - noAdd = flag.Bool("no-additional", false, "disable printing the additional section") - noStats = flag.Bool("no-statistics", false, "disable printing the statistics section") + noC = flagSet.Bool("no-comments", false, "disable printing the comments") + noQ = flagSet.Bool("no-question", false, "disable printing the question section") + noOpt = flagSet.Bool("no-opt", false, "disable printing the OPT pseudosection") + noAns = flagSet.Bool("no-answer", false, "disable printing the answer section") + noAuth = flagSet.Bool("no-authority", false, "disable printing the authority section") + noAdd = flagSet.Bool("no-additional", false, "disable printing the additional section") + noStats = flagSet.Bool("no-statistics", false, "disable printing the statistics section") - verbosity = flag.Int("verbosity", 1, "sets verbosity `level`", flag.OptShorthand('v'), flag.OptNoOptDefVal("2")) - versionFlag = flag.Bool("version", false, "print version information", flag.OptShorthand('V')) + verbosity = flagSet.Int("verbosity", 1, "sets verbosity `level`", flag.OptShorthand('v'), flag.OptNoOptDefVal("2")) + versionFlag = flagSet.Bool("version", false, "print version information", flag.OptShorthand('V')) ) // Don't sort the flags when -h is given - flag.CommandLine.SortFlags = false + flagSet.SortFlags = false // Parse the flags - if err := flag.CommandLine.Parse(args[1:]); err != nil { + if err := flagSet.Parse(args[1:]); err != nil { return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err) } @@ -194,7 +194,7 @@ func ParseCLI(args []string, version string) (util.Options, error) { // Parse all the arguments that don't start with - or -- // This includes the dig-style (+) options - err = ParseMiscArgs(flag.Args(), &opts) + err = ParseMiscArgs(flagSet.Args(), &opts) if err != nil { return opts, err } diff --git a/cli/cli_test.go b/cli/cli_test.go index e830c2f..4730aa0 100644 --- a/cli/cli_test.go +++ b/cli/cli_test.go @@ -11,51 +11,64 @@ import ( ) func TestEmpty(t *testing.T) { + t.Parallel() + args := []string{"awl", "-4"} opts, err := cli.ParseCLI(args, "TEST") - assert.NilError(t, err) - assert.Equal(t, opts.Request.Port, 53) assert.Assert(t, opts.IPv4) + assert.Equal(t, opts.Request.Port, 53) } func TestTLSPort(t *testing.T) { + t.Parallel() + args := []string{"awl", "-T"} opts, err := cli.ParseCLI(args, "TEST") - assert.NilError(t, err) assert.Equal(t, opts.Request.Port, 853) } -func TestSubnet(t *testing.T) { - args := []string{"awl", "--subnet", "127.0.0.1/32"} +func TestValidSubnet(t *testing.T) { + t.Parallel() - opts, err := cli.ParseCLI(args, "TEST") + tests := []struct { + args []string + want uint16 + }{ + {[]string{"awl", "--subnet", "127.0.0.1/32"}, uint16(1)}, + {[]string{"awl", "--subnet", "0"}, uint16(1)}, + {[]string{"awl", "--subnet", "::/0"}, uint16(2)}, + } - assert.NilError(t, err) - assert.Equal(t, opts.EDNS.Subnet.Family, uint16(1)) + for _, test := range tests { + test := test - args = []string{"awl", "--subnet", "0"} + t.Run(test.args[2], func(t *testing.T) { + t.Parallel() - opts, err = cli.ParseCLI(args, "TEST") - assert.NilError(t, err) - assert.Equal(t, opts.EDNS.Subnet.Family, uint16(1)) + opts, err := cli.ParseCLI(test.args, "TEST") - args = []string{"awl", "--subnet", "::/0"} + assert.NilError(t, err) + assert.Equal(t, opts.EDNS.Subnet.Family, test.want) + }) + } +} - opts, err = cli.ParseCLI(args, "TEST") - assert.NilError(t, err) - assert.Equal(t, opts.EDNS.Subnet.Family, uint16(2)) +func TestInvalidSubnet(t *testing.T) { + t.Parallel() - args = []string{"awl", "--subnet", "/"} + args := []string{"awl", "--subnet", "/"} - opts, err = cli.ParseCLI(args, "TEST") + _, err := cli.ParseCLI(args, "TEST") assert.ErrorContains(t, err, "EDNS subnet") } func TestMBZ(t *testing.T) { + t.Parallel() + args := []string{"awl", "--zflag", "G"} _, err := cli.ParseCLI(args, "TEST") @@ -64,6 +77,8 @@ func TestMBZ(t *testing.T) { } func TestInvalidFlag(t *testing.T) { + t.Parallel() + args := []string{"awl", "--treebug"} _, err := cli.ParseCLI(args, "TEST") @@ -72,6 +87,8 @@ func TestInvalidFlag(t *testing.T) { } func TestInvalidDig(t *testing.T) { + t.Parallel() + args := []string{"awl", "+a"} _, err := cli.ParseCLI(args, "TEST") @@ -80,6 +97,8 @@ func TestInvalidDig(t *testing.T) { } func TestVersion(t *testing.T) { + t.Parallel() + args := []string{"awl", "--version"} _, err := cli.ParseCLI(args, "test") @@ -88,6 +107,8 @@ func TestVersion(t *testing.T) { } func TestTimeout(t *testing.T) { + t.Parallel() + args := [][]string{ {"awl", "+timeout=0"}, {"awl", "--timeout", "0"}, @@ -95,14 +116,20 @@ func TestTimeout(t *testing.T) { for _, test := range args { test := test - opt, err := cli.ParseCLI(test, "TEST") + t.Run(test[1], func(t *testing.T) { + t.Parallel() - assert.NilError(t, err) - assert.Equal(t, opt.Request.Timeout, time.Second/2) + opt, err := cli.ParseCLI(test, "TEST") + + assert.NilError(t, err) + assert.Equal(t, opt.Request.Timeout, time.Second/2) + }) } } func TestRetries(t *testing.T) { + t.Parallel() + args := [][]string{ {"awl", "+retry=-2"}, {"awl", "+tries=-2"}, @@ -111,10 +138,14 @@ func TestRetries(t *testing.T) { for _, test := range args { test := test - opt, err := cli.ParseCLI(test, "TEST") + t.Run(test[1], func(t *testing.T) { + t.Parallel() - assert.NilError(t, err) - assert.Equal(t, opt.Request.Retries, 0) + opt, err := cli.ParseCLI(test, "TEST") + + assert.NilError(t, err) + assert.Equal(t, opt.Request.Retries, 0) + }) } } diff --git a/cli/misc.go b/cli/misc.go index 3beb09d..799a377 100644 --- a/cli/misc.go +++ b/cli/misc.go @@ -45,7 +45,7 @@ func ParseMiscArgs(args []string, opts *util.Options) error { opts.Logger.Info("DNSCrypt implicitly set") case strings.HasPrefix(arg, "tcp://"): opts.TCP = true - opts.Request.Server = strings.TrimPrefix(arg, "udp://") + opts.Request.Server = strings.TrimPrefix(arg, "tcp://") opts.Logger.Info("TCP implicitly set") case strings.HasPrefix(arg, "udp://"): opts.Request.Server = strings.TrimPrefix(arg, "udp://") diff --git a/cli/misc_test.go b/cli/misc_test.go index 7218418..d2f9a09 100644 --- a/cli/misc_test.go +++ b/cli/misc_test.go @@ -3,7 +3,6 @@ package cli_test import ( - "strconv" "testing" "git.froth.zone/sam/awl/cli" @@ -87,7 +86,7 @@ func TestDefaultServer(t *testing.T) { in string want string }{ - {"DNSCRYPT", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"}, + {"DNSCrypt", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"}, {"TLS", "dns.google"}, {"HTTPS", "https://dns.cloudflare.com/dns-query"}, {"QUIC", "dns.adguard.com"}, @@ -95,13 +94,14 @@ func TestDefaultServer(t *testing.T) { for _, test := range tests { test := test + t.Run(test.in, func(t *testing.T) { t.Parallel() args := []string{} opts := new(util.Options) opts.Logger = util.InitLogger(0) switch test.in { - case "DNSCRYPT": + case "DNSCrypt": opts.DNSCrypt = true case "TLS": opts.TLS = true @@ -121,38 +121,48 @@ func TestFlagSetting(t *testing.T) { t.Parallel() tests := []struct { - in []string + in string + expected string + over string }{ - {[]string{"@sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"}}, - {[]string{"@tls://dns.google"}}, - {[]string{"@https://dns.cloudflare.com/dns-query"}}, - {[]string{"@quic://dns.adguard.com"}}, - {[]string{"@tcp://dns.froth.zone"}}, - {[]string{"@udp://dns.example.com"}}, + {"@sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "DNSCrypt"}, + {"@tls://dns.google", "dns.google", "TLS"}, + {"@https://dns.cloudflare.com/dns-query", "https://dns.cloudflare.com/dns-query", "HTTPS"}, + {"@quic://dns.adguard.com", "dns.adguard.com", "QUIC"}, + {"@tcp://dns.froth.zone", "dns.froth.zone", "TCP"}, + {"@udp://dns.example.com", "dns.example.com", "UDP"}, } - for i, test := range tests { + for _, test := range tests { test := test - i := i - t.Run(strconv.Itoa(i), func(t *testing.T) { + + t.Run(test.over, func(t *testing.T) { + t.Parallel() + opts := new(util.Options) opts.Logger = util.InitLogger(0) - t.Parallel() - err := cli.ParseMiscArgs(test.in, opts) + + err := cli.ParseMiscArgs([]string{test.in}, opts) assert.NilError(t, err) - switch i { - case 0: + switch test.over { + case "DNSCrypt": assert.Assert(t, opts.DNSCrypt) - case 1: + assert.Equal(t, opts.Request.Server, test.expected) + case "TLS": assert.Assert(t, opts.TLS) - case 2: + assert.Equal(t, opts.Request.Server, test.expected) + case "HTTPS": assert.Assert(t, opts.HTTPS) - case 3: + assert.Equal(t, opts.Request.Server, test.expected) + case "QUIC": assert.Assert(t, opts.QUIC) - case 4: + assert.Equal(t, opts.Request.Server, test.expected) + case "TCP": assert.Assert(t, opts.TCP) - case 5: + assert.Equal(t, opts.Request.Server, test.expected) + case "UDP": assert.Assert(t, true) + assert.Equal(t, opts.Request.Server, test.expected) } }) } diff --git a/conf/plan9_test.go b/conf/plan9_test.go index 4cf2410..dbf01f9 100644 --- a/conf/plan9_test.go +++ b/conf/plan9_test.go @@ -4,53 +4,22 @@ package conf_test import ( + "runtime" "testing" "git.froth.zone/sam/awl/conf" "gotest.tools/v3/assert" ) -func TestGetPlan9Config(t *testing.T) { +func TestPlan9Config(t *testing.T) { t.Parallel() + if runtime.GOOS != "plan9" { t.Skip("Not running Plan 9, skipping") } - ndbs := []struct { - in string - want string - }{ - {`ip=192.168.122.45 ipmask=255.255.255.0 ipgw=192.168.122.1 - sys=chog9 - dns=192.168.122.1`, "192.168.122.1"}, - {`ipnet=murray-hill ip=135.104.0.0 ipmask=255.255.0.0 - dns=135.104.10.1 - ntp=ntp.cs.bell-labs.com - ipnet=plan9 ip=135.104.9.0 ipmask=255.255.255.0 - ntp=oncore.cs.bell-labs.com - smtp=smtp1.cs.bell-labs.com - ip=135.104.9.6 sys=anna dom=anna.cs.bell-labs.com - smtp=smtp2.cs.bell-labs.com`, "135.104.10.1"}, - } + conf, err := conf.GetDNSConfig() - for _, ndb := range ndbs { - // Go is a little quirky - ndb := ndb - t.Run(ndb.want, func(t *testing.T) { - t.Parallel() - act, err := conf.GetPlan9Config(ndb.in) - assert.NilError(t, err) - assert.Equal(t, ndb.want, act.Servers[0]) - }) - } - - invalid := `sys = spindle - dom=spindle.research.bell-labs.com - bootf=/mips/9powerboot - ip=135.104.117.32 ether=080069020677 - proto=il` - - act, err := conf.GetPlan9Config(invalid) - assert.ErrorContains(t, err, "no DNS servers found") - assert.Assert(t, act == nil) + assert.NilError(t, err) + assert.Assert(t, len(conf.Servers) != 0) } diff --git a/conf/unix_test.go b/conf/unix_test.go index 9f50b6d..41643a1 100644 --- a/conf/unix_test.go +++ b/conf/unix_test.go @@ -13,8 +13,10 @@ import ( "gotest.tools/v3/assert" ) -func TestNonWinConfig(t *testing.T) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { +func TestUnixConfig(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "js" || runtime.GOOS == "zos" { t.Skip("Not running Unix-like, skipping") } diff --git a/go.mod b/go.mod index fd8f97d..1ea65d7 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,8 @@ require ( github.com/lucas-clemente/quic-go v0.29.0 github.com/miekg/dns v1.1.50 github.com/stefansundin/go-zflag v1.1.1 - golang.org/x/net v0.0.0-20220909164309-bea034e7d591 - golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41 + golang.org/x/net v0.0.0-20220921203646-d300de134e69 + golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 gopkg.in/yaml.v3 v3.0.1 gotest.tools/v3 v3.3.0 ) diff --git a/go.sum b/go.sum index f40c575..4a47276 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,16 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220909164309-bea034e7d591 h1:D0B/7al0LLrVC8aWF4+oxpv/m8bc7ViFfVS8/gXGdqI= golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220919171627-f8f703f97925 h1:5Vhms/MnzRfom2tX/lHt58o74fzobhcw2FDjPdoGoow= +golang.org/x/net v0.0.0-20220919171627-f8f703f97925/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220919232410-f2f64ebce3c1 h1:TWZxd/th7FbRSMret2MVQdlI8uT49QEtwZdvJrxjEHU= +golang.org/x/net v0.0.0-20220919232410-f2f64ebce3c1/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220920203100-d0c6ba3f52d9 h1:asZqf0wXastQr+DudYagQS8uBO8bHKeYD1vbAvGmFL8= +golang.org/x/net v0.0.0-20220920203100-d0c6ba3f52d9/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220921155015-db77216a4ee9 h1:SdDGdqRuKrF2R4XGcnPzcvZ63c/55GvhoHUus0o+BNI= +golang.org/x/net v0.0.0-20220921155015-db77216a4ee9/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20220921203646-d300de134e69 h1:hUJpGDpnfwdJW8iNypFjmSY0sCBEL+spFTZ2eO+Sfps= +golang.org/x/net v0.0.0-20220921203646-d300de134e69/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -117,6 +127,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41 h1:ohgcoMbSofXygzo6AD2I1kz3BFmW1QArPYTtwEM3UXc= golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc= +golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/main_test.go b/main_test.go index 57186c9..58f8e07 100644 --- a/main_test.go +++ b/main_test.go @@ -3,31 +3,35 @@ package main import ( - "os" "testing" "github.com/stefansundin/go-zflag" "gotest.tools/v3/assert" ) -func TestMain(t *testing.T) { //nolint: paralleltest // Race conditions - os.Stdout = os.NewFile(0, os.DevNull) - os.Stderr = os.NewFile(0, os.DevNull) +func TestRun(t *testing.T) { + t.Parallel() - args := []string{"awl", "+yaml", "@1.1.1.1"} + args := [][]string{ + {"awl", "+yaml", "@1.1.1.1"}, + {"awl", "+short", "@1.1.1.1"}, + } - _, code, err := run(args) - assert.NilError(t, err) - assert.Equal(t, code, 0) + for _, test := range args { + test := test - args = []string{"awl", "+short", "@1.1.1.1"} - - _, code, err = run(args) - assert.NilError(t, err) - assert.Equal(t, code, 0) + t.Run("", func(t *testing.T) { + t.Parallel() + _, code, err := run(test) + assert.NilError(t, err) + assert.Equal(t, code, 0) + }) + } } func TestHelp(t *testing.T) { + t.Parallel() + args := []string{"awl", "-h"} _, code, err := run(args) diff --git a/pkg/query/print_test.go b/pkg/query/print_test.go index c0bb505..3173d7f 100644 --- a/pkg/query/print_test.go +++ b/pkg/query/print_test.go @@ -77,8 +77,7 @@ func TestRealPrint(t *testing.T) { Type: dns.StringToType["NS"], Class: 1, Name: "google.com.", - Timeout: 0, - Retries: 0, + Retries: 3, }, EDNS: util.EDNS{ EnableEDNS: false, @@ -99,7 +98,7 @@ func TestRealPrint(t *testing.T) { Authority: true, Additional: true, Statistics: true, - UcodeTranslate: false, + UcodeTranslate: true, TTL: true, HumanTTL: true, ShowQuery: true, @@ -110,8 +109,7 @@ func TestRealPrint(t *testing.T) { Type: dns.StringToType["NS"], Class: 1, Name: "freecumextremist.com.", - Timeout: 0, - Retries: 0, + Retries: 3, }, EDNS: util.EDNS{ EnableEDNS: false, @@ -172,8 +170,7 @@ func TestRealPrint(t *testing.T) { Type: dns.StringToType["A"], Class: 1, Name: "froth.zone.", - Timeout: 0, - Retries: 0, + Retries: 3, }, EDNS: util.EDNS{ EnableEDNS: true, diff --git a/pkg/query/query_test.go b/pkg/query/query_test.go index 042f25d..4ec40cc 100644 --- a/pkg/query/query_test.go +++ b/pkg/query/query_test.go @@ -14,115 +14,128 @@ import ( func TestCreateQ(t *testing.T) { t.Parallel() - in := []util.Options{ + tests := []struct { + name string + opts util.Options + }{ { - Logger: util.InitLogger(0), - HeaderFlags: util.HeaderFlags{ - Z: true, - }, - - YAML: true, - - Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeA, - Name: "example.com.", - }, - Display: util.Display{ - Comments: true, - Question: true, - Opt: true, - Answer: true, - Authority: true, - Additional: true, - Statistics: true, - ShowQuery: true, - }, - EDNS: util.EDNS{ - ZFlag: 1, - BufSize: 1500, - EnableEDNS: true, - Cookie: true, - DNSSEC: true, - Expire: true, - KeepOpen: true, - Nsid: true, - Padding: true, - Version: 0, + "1", + util.Options{ + Logger: util.InitLogger(0), + HeaderFlags: util.HeaderFlags{ + Z: true, + }, + YAML: true, + Request: util.Request{ + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, + }, + Display: util.Display{ + Comments: true, + Question: true, + Opt: true, + Answer: true, + Authority: true, + Additional: true, + Statistics: true, + ShowQuery: true, + }, + EDNS: util.EDNS{ + ZFlag: 1, + BufSize: 1500, + EnableEDNS: true, + Cookie: true, + DNSSEC: true, + Expire: true, + KeepOpen: true, + Nsid: true, + Padding: true, + Version: 0, + }, }, }, { - Logger: util.InitLogger(0), - HeaderFlags: util.HeaderFlags{ - Z: true, - }, - XML: true, + "2", + util.Options{ + Logger: util.InitLogger(0), + HeaderFlags: util.HeaderFlags{ + Z: true, + }, + XML: true, - Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeA, - Name: "example.com.", - }, - Display: util.Display{ - Comments: true, - Question: true, - Opt: true, - Answer: true, - Authority: true, - Additional: true, - Statistics: true, - UcodeTranslate: true, - ShowQuery: true, + Request: util.Request{ + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, + }, + Display: util.Display{ + Comments: true, + Question: true, + Opt: true, + Answer: true, + Authority: true, + Additional: true, + Statistics: true, + UcodeTranslate: true, + ShowQuery: true, + }, }, }, { - Logger: util.InitLogger(0), - JSON: true, - QUIC: true, + "3", + util.Options{ + Logger: util.InitLogger(0), + JSON: true, + QUIC: true, - Request: util.Request{ - Server: "dns.adguard.com", - Port: 853, - Type: dns.TypeA, - Name: "example.com.", - }, - Display: util.Display{ - Comments: true, - Question: true, - Opt: true, - Answer: true, - Authority: true, - Additional: true, - Statistics: true, - ShowQuery: true, - }, - EDNS: util.EDNS{ - EnableEDNS: true, - DNSSEC: true, - Cookie: true, - Expire: true, - Nsid: true, + Request: util.Request{ + Server: "dns.adguard.com", + Port: 853, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, + }, + Display: util.Display{ + Comments: true, + Question: true, + Opt: true, + Answer: true, + Authority: true, + Additional: true, + Statistics: true, + ShowQuery: true, + }, + EDNS: util.EDNS{ + EnableEDNS: true, + DNSSEC: true, + Cookie: true, + Expire: true, + Nsid: true, + }, }, }, } - for _, opt := range in { - opt := opt + for _, test := range tests { + test := test - t.Run("", func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { t.Parallel() - res, err := query.CreateQuery(opt) + res, err := query.CreateQuery(test.opts) assert.NilError(t, err) assert.Assert(t, res != util.Response{}) - str, err := query.PrintSpecial(res, opt) + str, err := query.PrintSpecial(res, test.opts) assert.NilError(t, err) assert.Assert(t, str != "") - str, err = query.ToString(res, opt) + str, err = query.ToString(res, test.opts) assert.NilError(t, err) assert.Assert(t, str != "") }) diff --git a/pkg/resolvers/DNSCrypt_test.go b/pkg/resolvers/DNSCrypt_test.go index 763dde0..98f5e2f 100644 --- a/pkg/resolvers/DNSCrypt_test.go +++ b/pkg/resolvers/DNSCrypt_test.go @@ -15,42 +15,49 @@ func TestDNSCrypt(t *testing.T) { t.Parallel() tests := []struct { - opt util.Options + name string + opts util.Options }{ { + "Valid", util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, Request: util.Request{ - Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - Type: dns.TypeA, - Name: "example.com.", + Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, }, }, }, { + "Valid (TCP)", util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, TCP: true, IPv4: true, Request: util.Request{ - Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 3, }, }, }, { + "Invalid", util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, TCP: true, IPv4: true, Request: util.Request{ - Server: "QMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "QMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 0, }, }, }, @@ -59,10 +66,10 @@ func TestDNSCrypt(t *testing.T) { for _, test := range tests { test := test - t.Run("", func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { t.Parallel() - res, err := query.CreateQuery(test.opt) + res, err := query.CreateQuery(test.opts) if err == nil { assert.Assert(t, res != util.Response{}) } else { diff --git a/pkg/resolvers/HTTPS.go b/pkg/resolvers/HTTPS.go index e4c4388..79cc2ea 100644 --- a/pkg/resolvers/HTTPS.go +++ b/pkg/resolvers/HTTPS.go @@ -52,7 +52,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { } if res.StatusCode != http.StatusOK { - return util.Response{}, &errHTTPStatus{res.StatusCode} + return util.Response{}, &ErrHTTPStatus{res.StatusCode} } resolver.opts.Logger.Debug("https: reading response") @@ -79,10 +79,11 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) { return resp, nil } -type errHTTPStatus struct { +// ErrHTTPStatus is returned when DoH returns a bad status code. +type ErrHTTPStatus struct { code int } -func (e *errHTTPStatus) Error() string { +func (e *ErrHTTPStatus) Error() string { return fmt.Sprintf("doh server responded with HTTP %d", e.code) } diff --git a/pkg/resolvers/HTTPS_test.go b/pkg/resolvers/HTTPS_test.go index 533c12d..2c3742a 100644 --- a/pkg/resolvers/HTTPS_test.go +++ b/pkg/resolvers/HTTPS_test.go @@ -3,6 +3,7 @@ package resolvers_test import ( + "errors" "testing" "git.froth.zone/sam/awl/pkg/resolvers" @@ -11,105 +12,87 @@ import ( "gotest.tools/v3/assert" ) -func TestResolveHTTPS(t *testing.T) { +func TestHTTPS(t *testing.T) { t.Parallel() - var err error - - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "https://dns9.quad9.net/dns-query", - Type: dns.TypeA, - Name: "git.froth.zone.", + tests := []struct { + name string + opts util.Options + }{ + { + "Good", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "https://dns9.quad9.net/dns-query", + Type: dns.TypeA, + Name: "git.froth.zone.", + Retries: 3, + }, + }, + }, + { + "404", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "https://dns9.quad9.net/dns", + Type: dns.TypeA, + Name: "git.froth.zone.", + }, + }, + }, + { + "Bad request domain", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns9.quad9.net/dns-query", + Type: dns.TypeA, + Name: "git.froth.zone", + }, + }, + }, + { + "Bad server domain", + util.Options{ + HTTPS: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns9..quad9.net/dns-query", + Type: dns.TypeA, + Name: "git.froth.zone.", + }, + }, }, } - // testCase := util.Request{Server: "https://dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."} - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) -} + for _, test := range tests { + test := test -func Test2ResolveHTTPS(t *testing.T) { - t.Parallel() + t.Run(test.name, func(t *testing.T) { + t.Parallel() - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."}, + resolver, err := resolvers.LoadResolver(test.opts) + assert.NilError(t, err) + + msg := new(dns.Msg) + msg.SetQuestion(test.opts.Request.Name, test.opts.Request.Type) + // msg = msg.SetQuestion(testCase.Name, testCase.Type) + res, err := resolver.LookUp(msg) + + if err == nil { + assert.NilError(t, err) + assert.Assert(t, res != util.Response{}) + } else { + if errors.Is(err, &resolvers.ErrHTTPStatus{}) { + assert.ErrorContains(t, err, "404") + } + assert.Equal(t, res, util.Response{}) + } + }) } - - var err error - - testCase := util.Request{Type: dns.TypeA, Name: "git.froth.zone"} - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(testCase.Name, testCase.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "fully qualified") - assert.Equal(t, res, util.Response{}) -} - -func Test3ResolveHTTPS(t *testing.T) { - t.Parallel() - - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "dns9..quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."}, - } - - var err error - - // testCase := - // if the domain is not canonical, make it canonical - // if !strings.HasSuffix(testCase.Name, ".") { - // testCase.Name = fmt.Sprintf("%s.", testCase.Name) - // } - - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - // msg.SetQuestion(testCase.Name, testCase.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "doh: HTTP request") - assert.Equal(t, res, util.Response{}) -} - -func Test404ResolveHTTPS(t *testing.T) { - t.Parallel() - - var err error - - opts := util.Options{ - HTTPS: true, - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "https://dns9.quad9.net/dns", - Type: dns.TypeA, - Name: "git.froth.zone.", - }, - } - // testCase := util.Request{Server: "https://dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."} - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "404") - assert.Equal(t, res, util.Response{}) } diff --git a/pkg/resolvers/QUIC_test.go b/pkg/resolvers/QUIC_test.go index 0bfe025..0691558 100644 --- a/pkg/resolvers/QUIC_test.go +++ b/pkg/resolvers/QUIC_test.go @@ -3,10 +3,6 @@ package resolvers_test import ( - "fmt" - "net" - "strconv" - "strings" "testing" "time" @@ -19,72 +15,90 @@ import ( func TestQuic(t *testing.T) { t.Parallel() - opts := util.Options{ - QUIC: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "dns.adguard.com", Port: 853}, + tests := []struct { + name string + opts util.Options + }{ + { + "Valid", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns.adguard.com", + Type: dns.TypeNS, + Port: 853, + Timeout: 750 * time.Millisecond, + Retries: 3, + }, + }, + }, + { + "Bad domain", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns.//./,,adguard\a.com", + Port: 853, + Type: dns.TypeA, + Name: "git.froth.zone", + Timeout: 100 * time.Millisecond, + Retries: 0, + }, + }, + }, + { + "Not canonical", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "dns.adguard.com", + Port: 853, + Type: dns.TypeA, + Name: "git.froth.zone", + Timeout: 100 * time.Millisecond, + Retries: 0, + }, + }, + }, + { + "Invalid query domain", + util.Options{ + QUIC: true, + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "example.com", + Port: 853, + Type: dns.TypeA, + Name: "git.froth.zone", + Timeout: 10 * time.Millisecond, + }, + }, + }, } - testCase := util.Request{Server: "dns.//./,,adguard.com", Type: dns.TypeA, Name: "git.froth.zone"} - testCase2 := util.Request{Server: "dns.adguard.com", Type: dns.TypeA, Name: "git.froth.zone"} - var testCases []util.Request + for _, test := range tests { + test := test - testCases = append(testCases, testCase) - testCases = append(testCases, testCase2) + t.Run(test.name, func(t *testing.T) { + t.Parallel() - for i := range testCases { - switch i { - case 0: - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - // if the domain is not canonical, make it canonical - if !strings.HasSuffix(testCase.Name, ".") { - testCases[i].Name = fmt.Sprintf("%s.", testCases[i].Name) - } - - msg := new(dns.Msg) - msg.SetQuestion(testCase.Name, testCase.Type) - // msg = msg.SetQuestion(testCase.Name, testCase.Type) - res, err := resolver.LookUp(msg) - - assert.ErrorContains(t, err, "fully qualified") - assert.Equal(t, res, util.Response{}) - case 1: - resolver, err := resolvers.LoadResolver(opts) + resolver, err := resolvers.LoadResolver(test.opts) assert.NilError(t, err) - testCase2.Server = net.JoinHostPort(testCase2.Server, strconv.Itoa(opts.Request.Port)) - - // if the domain is not canonical, make it canonical - if !strings.HasSuffix(testCase2.Name, ".") { - testCase2.Name = fmt.Sprintf("%s.", testCase2.Name) - } - msg := new(dns.Msg) - msg.SetQuestion(testCase2.Name, testCase2.Type) + msg.SetQuestion(test.opts.Request.Name, test.opts.Request.Type) res, err := resolver.LookUp(msg) - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) - } + if err == nil { + assert.NilError(t, err) + assert.Assert(t, res != util.Response{}) + } else { + assert.Assert(t, res == util.Response{}) + } + }) } } - -func TestInvalidQuic(t *testing.T) { - t.Parallel() - - opts := util.Options{ - QUIC: true, - Logger: util.InitLogger(0), - Request: util.Request{Server: "example.com", Port: 853, Type: dns.TypeA, Name: "git.froth.zone", Timeout: 10 * time.Millisecond}, - } - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - res, err := resolver.LookUp(msg) - assert.ErrorContains(t, err, "timeout") - assert.Equal(t, res, util.Response{}) -} diff --git a/pkg/resolvers/general_test.go b/pkg/resolvers/general_test.go index f9441ae..6e7774d 100644 --- a/pkg/resolvers/general_test.go +++ b/pkg/resolvers/general_test.go @@ -3,11 +3,11 @@ package resolvers_test import ( + "os" "testing" "time" "git.froth.zone/sam/awl/pkg/query" - "git.froth.zone/sam/awl/pkg/resolvers" "git.froth.zone/sam/awl/pkg/util" "github.com/miekg/dns" "gotest.tools/v3/assert" @@ -16,90 +16,95 @@ import ( func TestResolve(t *testing.T) { t.Parallel() - opts := util.Options{ - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "8.8.4.1", - Port: 1, - Type: dns.TypeA, - Name: "example.com.", - Timeout: time.Second / 2, - Retries: 0, - }, - } - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - - _, err = resolver.LookUp(msg) - assert.ErrorContains(t, err, "timeout") -} - -func TestTruncate(t *testing.T) { - t.Parallel() - - opts := util.Options{ - Logger: util.InitLogger(0), - IPv4: true, - Request: util.Request{ - Server: "madns.binarystar.systems", - Port: 5301, - Type: dns.TypeTXT, - Name: "limit.txt.example.", - }, - } - resolver, err := resolvers.LoadResolver(opts) - assert.NilError(t, err) - - msg := new(dns.Msg) - msg.SetQuestion(opts.Request.Name, opts.Request.Type) - res, err := resolver.LookUp(msg) - - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) -} - -func TestResolveAgain(t *testing.T) { - t.Parallel() - tests := []struct { - opt util.Options + name string + opts util.Options }{ { + "UDP", + util.Options{ + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 3, + }, + }, + }, + { + "UDP (Bad Cookie)", + util.Options{ + Logger: util.InitLogger(0), + BadCookie: false, + Request: util.Request{ + Server: "b.root-servers.net", + Port: 53, + Type: dns.TypeNS, + Name: "example.com.", + Retries: 3, + }, + EDNS: util.EDNS{ + EnableEDNS: true, + Cookie: true, + }, + }, + }, + { + "UDP (Truncated)", + util.Options{ + Logger: util.InitLogger(0), + IPv4: true, + Request: util.Request{ + Server: "madns.binarystar.systems", + Port: 5301, + Type: dns.TypeTXT, + Name: "limit.txt.example.", + Retries: 3, + }, + }, + }, + { + "TCP", util.Options{ Logger: util.InitLogger(0), TCP: true, Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeA, - Name: "example.com.", - }, - }, - }, - { - util.Options{ - Logger: util.InitLogger(0), - Request: util.Request{ - Server: "8.8.4.4", - Port: 53, - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "8.8.4.4", + Port: 53, + Type: dns.TypeA, + Name: "example.com.", + Retries: 3, }, }, }, { + "TLS", util.Options{ Logger: util.InitLogger(0), TLS: true, Request: util.Request{ - Server: "dns.google", - Port: 853, - Type: dns.TypeAAAA, - Name: "example.com.", + Server: "dns.google", + Port: 853, + Type: dns.TypeAAAA, + Name: "example.com.", + Retries: 3, + }, + }, + }, + { + "Timeout", + util.Options{ + Logger: util.InitLogger(0), + Request: util.Request{ + Server: "8.8.4.1", + Port: 1, + Type: dns.TypeA, + Name: "example.com.", + Timeout: time.Millisecond * 100, + Retries: 0, }, }, }, @@ -108,11 +113,16 @@ func TestResolveAgain(t *testing.T) { for _, test := range tests { test := test - t.Run("", func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { t.Parallel() - res, err := query.CreateQuery(test.opt) - assert.NilError(t, err) - assert.Assert(t, res != util.Response{}) + + res, err := query.CreateQuery(test.opts) + if err == nil { + assert.NilError(t, err) + assert.Assert(t, res != util.Response{}) + } else { + assert.ErrorIs(t, err, os.ErrDeadlineExceeded) + } }) } } diff --git a/pkg/util/options_test.go b/pkg/util/options_test.go index 1e28e95..c04060e 100644 --- a/pkg/util/options_test.go +++ b/pkg/util/options_test.go @@ -17,21 +17,20 @@ func TestSubnet(t *testing.T) { "::0/0", "0", "127.0.0.1/32", + "Invalid", } for _, test := range subnet { test := test + t.Run(test, func(t *testing.T) { t.Parallel() err := util.ParseSubnet(test, new(util.Options)) - assert.NilError(t, err) + if err != nil { + assert.ErrorContains(t, err, "invalid CIDR address") + } else { + assert.NilError(t, err) + } }) } } - -func TestInvalidSub(t *testing.T) { - t.Parallel() - - err := util.ParseSubnet("1", new(util.Options)) - assert.ErrorContains(t, err, "invalid CIDR address") -} diff --git a/pkg/util/reverseDNS_test.go b/pkg/util/reverseDNS_test.go index 4725b4d..42099e6 100644 --- a/pkg/util/reverseDNS_test.go +++ b/pkg/util/reverseDNS_test.go @@ -10,25 +10,43 @@ import ( "gotest.tools/v3/assert" ) -var ( - PTR = dns.StringToType["PTR"] - NAPTR = dns.StringToType["NAPTR"] -) - -func TestIPv4(t *testing.T) { +func TestPTR(t *testing.T) { t.Parallel() - act, err := util.ReverseDNS("8.8.4.4", PTR) - assert.NilError(t, err) - assert.Equal(t, act, "4.4.8.8.in-addr.arpa.", "IPv4 reverse") -} + tests := []struct { + name string + in string + expected string + }{ + { + "IPv4", + "8.8.4.4", "4.4.8.8.in-addr.arpa.", + }, + { + "IPv6", + "2606:4700:4700::1111", "1.1.1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.4.0.0.7.4.6.0.6.2.ip6.arpa.", + }, + { + "Inavlid value", + "AAAAA", "", + }, + } -func TestIPv6(t *testing.T) { - t.Parallel() + for _, test := range tests { + test := test - act, err := util.ReverseDNS("2606:4700:4700::1111", PTR) - assert.NilError(t, err) - assert.Equal(t, act, "1.1.1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.4.0.0.7.4.6.0.6.2.ip6.arpa.", "IPv6 reverse") + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + act, err := util.ReverseDNS(test.in, dns.StringToType["PTR"]) + if err == nil { + assert.NilError(t, err) + } else { + assert.ErrorContains(t, err, "unrecognized address") + } + assert.Equal(t, act, test.expected) + }) + } } func TestNAPTR(t *testing.T) { @@ -48,23 +66,14 @@ func TestNAPTR(t *testing.T) { test := test t.Run(test.in, func(t *testing.T) { t.Parallel() - act, err := util.ReverseDNS(test.in, NAPTR) + act, err := util.ReverseDNS(test.in, dns.StringToType["NAPTR"]) assert.NilError(t, err) assert.Equal(t, test.want, act) }) } } -func TestInvalid(t *testing.T) { - t.Parallel() - - _, err := util.ReverseDNS("AAAAA", 1) - assert.ErrorContains(t, err, "invalid value AAAAA given") -} - -func TestInvalid2(t *testing.T) { - t.Parallel() - - _, err := util.ReverseDNS("1.0", PTR) - assert.ErrorContains(t, err, "PTR reverse") +func TestInvalidAll(t *testing.T) { + _, err := util.ReverseDNS("q", 15236) + assert.ErrorContains(t, err, "invalid value") }