chore(Refactor) #110

Merged
grumbulon merged 15 commits from refactor into master 2022-09-24 23:11:10 +00:00
19 changed files with 592 additions and 533 deletions
Showing only changes of commit 3eb2955bcf - Show all commits

View file

@ -4,9 +4,7 @@
root = true
[*]
indent_style = space
indent_size = 2
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = false
insert_final_newline = false
trim_trailing_whitespace = true
insert_final_newline = true

View file

@ -18,9 +18,9 @@ import (
// ParseCLI parses arguments given from the CLI and passes them into an `Options`
// struct.
func ParseCLI(args []string, version string) (util.Options, error) {
flag.CommandLine = flag.NewFlagSet(args[0], flag.ContinueOnError)
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
flag.Usage = func() {
flagSet.Usage = func() {
fmt.Println(`awl - drill, writ small
Usage: awl name [@server] [record]
@ -31,80 +31,80 @@ func ParseCLI(args []string, version string) (util.Options, error) {
Dig-like +[no]commands are also supported, see dig(1) or dig -h
Options:`)
flag.PrintDefaults()
flagSet.PrintDefaults()
}
// CLI flags
//
// Remember, when adding a flag edit the manpage and the completions :)
var (
port = flag.Int("port", 0, "`port` to make DNS query (default: 53 for UDP/TCP, 853 for TLS/QUIC)", flag.OptShorthand('p'), flag.OptDisablePrintDefault(true))
query = flag.String("query", "", "domain name to `query` (default: .)", flag.OptShorthand('q'))
class = flag.String("class", "IN", "DNS `class` to query", flag.OptShorthand('c'))
qType = flag.String("qType", "", "`type` to query (default: A)", flag.OptShorthand('t'))
port = flagSet.Int("port", 0, "`port` to make DNS query (default: 53 for UDP/TCP, 853 for TLS/QUIC)", flag.OptShorthand('p'), flag.OptDisablePrintDefault(true))
query = flagSet.String("query", "", "domain name to `query` (default: .)", flag.OptShorthand('q'))
class = flagSet.String("class", "IN", "DNS `class` to query", flag.OptShorthand('c'))
qType = flagSet.String("qType", "", "`type` to query (default: A)", flag.OptShorthand('t'))
ipv4 = flag.Bool("4", false, "force IPv4", flag.OptShorthand('4'))
ipv6 = flag.Bool("6", false, "force IPv6", flag.OptShorthand('6'))
reverse = flag.Bool("reverse", false, "do a reverse lookup", flag.OptShorthand('x'))
ipv4 = flagSet.Bool("4", false, "force IPv4", flag.OptShorthand('4'))
ipv6 = flagSet.Bool("6", false, "force IPv6", flag.OptShorthand('6'))
reverse = flagSet.Bool("reverse", false, "do a reverse lookup", flag.OptShorthand('x'))
timeout = flag.Float32("timeout", 1, "Timeout, in `seconds`")
retry = flag.Int("retries", 2, "number of `times` to retry")
timeout = flagSet.Float32("timeout", 1, "Timeout, in `seconds`")
retry = flagSet.Int("retries", 2, "number of `times` to retry")
edns = flag.Bool("no-edns", false, "disable EDNS entirely")
ednsVer = flag.Uint8("edns-ver", 0, "set EDNS version")
dnssec = flag.Bool("dnssec", false, "enable DNSSEC", flag.OptShorthand('D'))
expire = flag.Bool("expire", false, "set EDNS expire")
nsid = flag.Bool("nsid", false, "set EDNS NSID", flag.OptShorthand('n'))
cookie = flag.Bool("no-cookie", false, "disable sending EDNS cookie (default: cookie sent)")
tcpKeepAlive = flag.Bool("keep-alive", false, "send EDNS TCP keep-alive")
udpBufSize = flag.Uint16("buffer-size", 1232, "set EDNS UDP buffer size", flag.OptShorthand('b'))
mbzflag = flag.String("zflag", "0", "set EDNS z-flag `value`")
subnet = flag.String("subnet", "", "set EDNS client subnet")
padding = flag.Bool("pad", false, "set EDNS padding")
edns = flagSet.Bool("no-edns", false, "disable EDNS entirely")
ednsVer = flagSet.Uint8("edns-ver", 0, "set EDNS version")
dnssec = flagSet.Bool("dnssec", false, "enable DNSSEC", flag.OptShorthand('D'))
expire = flagSet.Bool("expire", false, "set EDNS expire")
nsid = flagSet.Bool("nsid", false, "set EDNS NSID", flag.OptShorthand('n'))
cookie = flagSet.Bool("no-cookie", false, "disable sending EDNS cookie (default: cookie sent)")
tcpKeepAlive = flagSet.Bool("keep-alive", false, "send EDNS TCP keep-alive")
udpBufSize = flagSet.Uint16("buffer-size", 1232, "set EDNS UDP buffer size", flag.OptShorthand('b'))
mbzflag = flagSet.String("zflag", "0", "set EDNS z-flag `value`")
subnet = flagSet.String("subnet", "", "set EDNS client subnet")
padding = flagSet.Bool("pad", false, "set EDNS padding")
badCookie = flag.Bool("no-bad-cookie", false, "ignore BADCOOKIE EDNS responses (default: retry with correct cookie")
truncate = flag.Bool("no-truncate", false, "ignore truncation if a UDP request truncates (default: retry with TCP)")
badCookie = flagSet.Bool("no-bad-cookie", false, "ignore BADCOOKIE EDNS responses (default: retry with correct cookie")
truncate = flagSet.Bool("no-truncate", false, "ignore truncation if a UDP request truncates (default: retry with TCP)")
tcp = flag.Bool("tcp", false, "use TCP")
dnscrypt = flag.Bool("dnscrypt", false, "use DNSCrypt")
tls = flag.Bool("tls", false, "use DNS-over-TLS", flag.OptShorthand('T'))
https = flag.Bool("https", false, "use DNS-over-HTTPS", flag.OptShorthand('H'))
quic = flag.Bool("quic", false, "use DNS-over-QUIC", flag.OptShorthand('Q'))
tcp = flagSet.Bool("tcp", false, "use TCP")
dnscrypt = flagSet.Bool("dnscrypt", false, "use DNSCrypt")
tls = flagSet.Bool("tls", false, "use DNS-over-TLS", flag.OptShorthand('T'))
https = flagSet.Bool("https", false, "use DNS-over-HTTPS", flag.OptShorthand('H'))
quic = flagSet.Bool("quic", false, "use DNS-over-QUIC", flag.OptShorthand('Q'))
tlsHost = flag.String("tls-host", "", "Server name to use for TLS verification")
noVerify = flag.Bool("tls-no-verify", false, "Disable TLS cert verification")
tlsHost = flagSet.String("tls-host", "", "Server name to use for TLS verification")
noVerify = flagSet.Bool("tls-no-verify", false, "Disable TLS cert verification")
aaflag = flag.Bool("aa", false, "set/unset AA (Authoratative Answer) flag (default: not set)")
adflag = flag.Bool("ad", false, "set/unset AD (Authenticated Data) flag (default: not set)")
cdflag = flag.Bool("cd", false, "set/unset CD (Checking Disabled) flag (default: not set)")
qrflag = flag.Bool("qr", false, "set/unset QR (QueRy) flag (default: not set)")
rdflag = flag.Bool("rd", true, "set/unset RD (Recursion Desired) flag (default: set)", flag.OptDisablePrintDefault(true))
raflag = flag.Bool("ra", false, "set/unset RA (Recursion Available) flag (default: not set)")
tcflag = flag.Bool("tc", false, "set/unset TC (TrunCated) flag (default: not set)")
zflag = flag.Bool("z", false, "set/unset Z (Zero) flag (default: not set)", flag.OptShorthand('z'))
aaflag = flagSet.Bool("aa", false, "set/unset AA (Authoratative Answer) flag (default: not set)")
adflag = flagSet.Bool("ad", false, "set/unset AD (Authenticated Data) flag (default: not set)")
cdflag = flagSet.Bool("cd", false, "set/unset CD (Checking Disabled) flag (default: not set)")
qrflag = flagSet.Bool("qr", false, "set/unset QR (QueRy) flag (default: not set)")
rdflag = flagSet.Bool("rd", true, "set/unset RD (Recursion Desired) flag (default: set)", flag.OptDisablePrintDefault(true))
raflag = flagSet.Bool("ra", false, "set/unset RA (Recursion Available) flag (default: not set)")
tcflag = flagSet.Bool("tc", false, "set/unset TC (TrunCated) flag (default: not set)")
zflag = flagSet.Bool("z", false, "set/unset Z (Zero) flag (default: not set)", flag.OptShorthand('z'))
short = flag.Bool("short", false, "print just the results", flag.OptShorthand('s'))
json = flag.Bool("json", false, "print the result(s) as JSON", flag.OptShorthand('j'))
xml = flag.Bool("xml", false, "print the result(s) as XML", flag.OptShorthand('X'))
yaml = flag.Bool("yaml", false, "print the result(s) as yaml", flag.OptShorthand('y'))
short = flagSet.Bool("short", false, "print just the results", flag.OptShorthand('s'))
json = flagSet.Bool("json", false, "print the result(s) as JSON", flag.OptShorthand('j'))
xml = flagSet.Bool("xml", false, "print the result(s) as XML", flag.OptShorthand('X'))
yaml = flagSet.Bool("yaml", false, "print the result(s) as yaml", flag.OptShorthand('y'))
noC = flag.Bool("no-comments", false, "disable printing the comments")
noQ = flag.Bool("no-question", false, "disable printing the question section")
noOpt = flag.Bool("no-opt", false, "disable printing the OPT pseudosection")
noAns = flag.Bool("no-answer", false, "disable printing the answer section")
noAuth = flag.Bool("no-authority", false, "disable printing the authority section")
noAdd = flag.Bool("no-additional", false, "disable printing the additional section")
noStats = flag.Bool("no-statistics", false, "disable printing the statistics section")
noC = flagSet.Bool("no-comments", false, "disable printing the comments")
noQ = flagSet.Bool("no-question", false, "disable printing the question section")
noOpt = flagSet.Bool("no-opt", false, "disable printing the OPT pseudosection")
noAns = flagSet.Bool("no-answer", false, "disable printing the answer section")
noAuth = flagSet.Bool("no-authority", false, "disable printing the authority section")
noAdd = flagSet.Bool("no-additional", false, "disable printing the additional section")
noStats = flagSet.Bool("no-statistics", false, "disable printing the statistics section")
verbosity = flag.Int("verbosity", 1, "sets verbosity `level`", flag.OptShorthand('v'), flag.OptNoOptDefVal("2"))
versionFlag = flag.Bool("version", false, "print version information", flag.OptShorthand('V'))
verbosity = flagSet.Int("verbosity", 1, "sets verbosity `level`", flag.OptShorthand('v'), flag.OptNoOptDefVal("2"))
versionFlag = flagSet.Bool("version", false, "print version information", flag.OptShorthand('V'))
)
// Don't sort the flags when -h is given
flag.CommandLine.SortFlags = false
flagSet.SortFlags = false
// Parse the flags
if err := flag.CommandLine.Parse(args[1:]); err != nil {
if err := flagSet.Parse(args[1:]); err != nil {
return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err)
}
@ -194,7 +194,7 @@ func ParseCLI(args []string, version string) (util.Options, error) {
// Parse all the arguments that don't start with - or --
// This includes the dig-style (+) options
err = ParseMiscArgs(flag.Args(), &opts)
err = ParseMiscArgs(flagSet.Args(), &opts)
if err != nil {
return opts, err
}

View file

@ -11,51 +11,64 @@ import (
)
func TestEmpty(t *testing.T) {
t.Parallel()
args := []string{"awl", "-4"}
opts, err := cli.ParseCLI(args, "TEST")
assert.NilError(t, err)
assert.Equal(t, opts.Request.Port, 53)
assert.Assert(t, opts.IPv4)
assert.Equal(t, opts.Request.Port, 53)
}
func TestTLSPort(t *testing.T) {
t.Parallel()
args := []string{"awl", "-T"}
opts, err := cli.ParseCLI(args, "TEST")
assert.NilError(t, err)
assert.Equal(t, opts.Request.Port, 853)
}
func TestSubnet(t *testing.T) {
args := []string{"awl", "--subnet", "127.0.0.1/32"}
func TestValidSubnet(t *testing.T) {
t.Parallel()
opts, err := cli.ParseCLI(args, "TEST")
tests := []struct {
args []string
want uint16
}{
{[]string{"awl", "--subnet", "127.0.0.1/32"}, uint16(1)},
{[]string{"awl", "--subnet", "0"}, uint16(1)},
{[]string{"awl", "--subnet", "::/0"}, uint16(2)},
}
assert.NilError(t, err)
assert.Equal(t, opts.EDNS.Subnet.Family, uint16(1))
for _, test := range tests {
test := test
args = []string{"awl", "--subnet", "0"}
t.Run(test.args[2], func(t *testing.T) {
t.Parallel()
opts, err = cli.ParseCLI(args, "TEST")
assert.NilError(t, err)
assert.Equal(t, opts.EDNS.Subnet.Family, uint16(1))
opts, err := cli.ParseCLI(test.args, "TEST")
args = []string{"awl", "--subnet", "::/0"}
assert.NilError(t, err)
assert.Equal(t, opts.EDNS.Subnet.Family, test.want)
})
}
}
opts, err = cli.ParseCLI(args, "TEST")
assert.NilError(t, err)
assert.Equal(t, opts.EDNS.Subnet.Family, uint16(2))
func TestInvalidSubnet(t *testing.T) {
t.Parallel()
args = []string{"awl", "--subnet", "/"}
args := []string{"awl", "--subnet", "/"}
opts, err = cli.ParseCLI(args, "TEST")
_, err := cli.ParseCLI(args, "TEST")
assert.ErrorContains(t, err, "EDNS subnet")
}
func TestMBZ(t *testing.T) {
t.Parallel()
args := []string{"awl", "--zflag", "G"}
_, err := cli.ParseCLI(args, "TEST")
@ -64,6 +77,8 @@ func TestMBZ(t *testing.T) {
}
func TestInvalidFlag(t *testing.T) {
t.Parallel()
args := []string{"awl", "--treebug"}
_, err := cli.ParseCLI(args, "TEST")
@ -72,6 +87,8 @@ func TestInvalidFlag(t *testing.T) {
}
func TestInvalidDig(t *testing.T) {
t.Parallel()
args := []string{"awl", "+a"}
_, err := cli.ParseCLI(args, "TEST")
@ -80,6 +97,8 @@ func TestInvalidDig(t *testing.T) {
}
func TestVersion(t *testing.T) {
t.Parallel()
args := []string{"awl", "--version"}
_, err := cli.ParseCLI(args, "test")
@ -88,6 +107,8 @@ func TestVersion(t *testing.T) {
}
func TestTimeout(t *testing.T) {
t.Parallel()
args := [][]string{
{"awl", "+timeout=0"},
{"awl", "--timeout", "0"},
@ -95,14 +116,20 @@ func TestTimeout(t *testing.T) {
for _, test := range args {
test := test
opt, err := cli.ParseCLI(test, "TEST")
t.Run(test[1], func(t *testing.T) {
t.Parallel()
assert.NilError(t, err)
assert.Equal(t, opt.Request.Timeout, time.Second/2)
opt, err := cli.ParseCLI(test, "TEST")
assert.NilError(t, err)
assert.Equal(t, opt.Request.Timeout, time.Second/2)
})
}
}
func TestRetries(t *testing.T) {
t.Parallel()
args := [][]string{
{"awl", "+retry=-2"},
{"awl", "+tries=-2"},
@ -111,10 +138,14 @@ func TestRetries(t *testing.T) {
for _, test := range args {
test := test
opt, err := cli.ParseCLI(test, "TEST")
t.Run(test[1], func(t *testing.T) {
t.Parallel()
assert.NilError(t, err)
assert.Equal(t, opt.Request.Retries, 0)
opt, err := cli.ParseCLI(test, "TEST")
assert.NilError(t, err)
assert.Equal(t, opt.Request.Retries, 0)
})
}
}

View file

@ -45,7 +45,7 @@ func ParseMiscArgs(args []string, opts *util.Options) error {
opts.Logger.Info("DNSCrypt implicitly set")
case strings.HasPrefix(arg, "tcp://"):
opts.TCP = true
opts.Request.Server = strings.TrimPrefix(arg, "udp://")
opts.Request.Server = strings.TrimPrefix(arg, "tcp://")
opts.Logger.Info("TCP implicitly set")
case strings.HasPrefix(arg, "udp://"):
opts.Request.Server = strings.TrimPrefix(arg, "udp://")

View file

@ -3,7 +3,6 @@
package cli_test
import (
"strconv"
"testing"
"git.froth.zone/sam/awl/cli"
@ -87,7 +86,7 @@ func TestDefaultServer(t *testing.T) {
in string
want string
}{
{"DNSCRYPT", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"},
{"DNSCrypt", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"},
{"TLS", "dns.google"},
{"HTTPS", "https://dns.cloudflare.com/dns-query"},
{"QUIC", "dns.adguard.com"},
@ -95,13 +94,14 @@ func TestDefaultServer(t *testing.T) {
for _, test := range tests {
test := test
t.Run(test.in, func(t *testing.T) {
t.Parallel()
args := []string{}
opts := new(util.Options)
opts.Logger = util.InitLogger(0)
switch test.in {
case "DNSCRYPT":
case "DNSCrypt":
opts.DNSCrypt = true
case "TLS":
opts.TLS = true
@ -121,38 +121,48 @@ func TestFlagSetting(t *testing.T) {
t.Parallel()
tests := []struct {
in []string
in string
expected string
over string
}{
{[]string{"@sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"}},
{[]string{"@tls://dns.google"}},
{[]string{"@https://dns.cloudflare.com/dns-query"}},
{[]string{"@quic://dns.adguard.com"}},
{[]string{"@tcp://dns.froth.zone"}},
{[]string{"@udp://dns.example.com"}},
{"@sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "DNSCrypt"},
{"@tls://dns.google", "dns.google", "TLS"},
{"@https://dns.cloudflare.com/dns-query", "https://dns.cloudflare.com/dns-query", "HTTPS"},
{"@quic://dns.adguard.com", "dns.adguard.com", "QUIC"},
{"@tcp://dns.froth.zone", "dns.froth.zone", "TCP"},
{"@udp://dns.example.com", "dns.example.com", "UDP"},
}
for i, test := range tests {
for _, test := range tests {
test := test
i := i
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Run(test.over, func(t *testing.T) {
t.Parallel()
opts := new(util.Options)
opts.Logger = util.InitLogger(0)
t.Parallel()
err := cli.ParseMiscArgs(test.in, opts)
err := cli.ParseMiscArgs([]string{test.in}, opts)
assert.NilError(t, err)
switch i {
case 0:
switch test.over {
case "DNSCrypt":
assert.Assert(t, opts.DNSCrypt)
case 1:
assert.Equal(t, opts.Request.Server, test.expected)
case "TLS":
assert.Assert(t, opts.TLS)
case 2:
assert.Equal(t, opts.Request.Server, test.expected)
case "HTTPS":
assert.Assert(t, opts.HTTPS)
case 3:
assert.Equal(t, opts.Request.Server, test.expected)
case "QUIC":
assert.Assert(t, opts.QUIC)
case 4:
assert.Equal(t, opts.Request.Server, test.expected)
case "TCP":
assert.Assert(t, opts.TCP)
case 5:
assert.Equal(t, opts.Request.Server, test.expected)
case "UDP":
assert.Assert(t, true)
assert.Equal(t, opts.Request.Server, test.expected)
}
})
}

View file

@ -4,53 +4,22 @@
package conf_test
import (
"runtime"
"testing"
"git.froth.zone/sam/awl/conf"
"gotest.tools/v3/assert"
)
func TestGetPlan9Config(t *testing.T) {
func TestPlan9Config(t *testing.T) {
t.Parallel()
if runtime.GOOS != "plan9" {
t.Skip("Not running Plan 9, skipping")
}
ndbs := []struct {
in string
want string
}{
{`ip=192.168.122.45 ipmask=255.255.255.0 ipgw=192.168.122.1
sys=chog9
dns=192.168.122.1`, "192.168.122.1"},
{`ipnet=murray-hill ip=135.104.0.0 ipmask=255.255.0.0
dns=135.104.10.1
ntp=ntp.cs.bell-labs.com
ipnet=plan9 ip=135.104.9.0 ipmask=255.255.255.0
ntp=oncore.cs.bell-labs.com
smtp=smtp1.cs.bell-labs.com
ip=135.104.9.6 sys=anna dom=anna.cs.bell-labs.com
smtp=smtp2.cs.bell-labs.com`, "135.104.10.1"},
}
conf, err := conf.GetDNSConfig()
for _, ndb := range ndbs {
// Go is a little quirky
ndb := ndb
t.Run(ndb.want, func(t *testing.T) {
t.Parallel()
act, err := conf.GetPlan9Config(ndb.in)
assert.NilError(t, err)
assert.Equal(t, ndb.want, act.Servers[0])
})
}
invalid := `sys = spindle
dom=spindle.research.bell-labs.com
bootf=/mips/9powerboot
ip=135.104.117.32 ether=080069020677
proto=il`
act, err := conf.GetPlan9Config(invalid)
assert.ErrorContains(t, err, "no DNS servers found")
assert.Assert(t, act == nil)
assert.NilError(t, err)
assert.Assert(t, len(conf.Servers) != 0)
}

View file

@ -13,8 +13,10 @@ import (
"gotest.tools/v3/assert"
)
func TestNonWinConfig(t *testing.T) {
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
func TestUnixConfig(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "js" || runtime.GOOS == "zos" {
t.Skip("Not running Unix-like, skipping")
}

4
go.mod
View file

@ -8,8 +8,8 @@ require (
github.com/lucas-clemente/quic-go v0.29.0
github.com/miekg/dns v1.1.50
github.com/stefansundin/go-zflag v1.1.1
golang.org/x/net v0.0.0-20220909164309-bea034e7d591
golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41
golang.org/x/net v0.0.0-20220921203646-d300de134e69
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8
gopkg.in/yaml.v3 v3.0.1
gotest.tools/v3 v3.3.0
)

12
go.sum
View file

@ -94,6 +94,16 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220909164309-bea034e7d591 h1:D0B/7al0LLrVC8aWF4+oxpv/m8bc7ViFfVS8/gXGdqI=
golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20220919171627-f8f703f97925 h1:5Vhms/MnzRfom2tX/lHt58o74fzobhcw2FDjPdoGoow=
golang.org/x/net v0.0.0-20220919171627-f8f703f97925/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20220919232410-f2f64ebce3c1 h1:TWZxd/th7FbRSMret2MVQdlI8uT49QEtwZdvJrxjEHU=
golang.org/x/net v0.0.0-20220919232410-f2f64ebce3c1/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20220920203100-d0c6ba3f52d9 h1:asZqf0wXastQr+DudYagQS8uBO8bHKeYD1vbAvGmFL8=
golang.org/x/net v0.0.0-20220920203100-d0c6ba3f52d9/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20220921155015-db77216a4ee9 h1:SdDGdqRuKrF2R4XGcnPzcvZ63c/55GvhoHUus0o+BNI=
golang.org/x/net v0.0.0-20220921155015-db77216a4ee9/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20220921203646-d300de134e69 h1:hUJpGDpnfwdJW8iNypFjmSY0sCBEL+spFTZ2eO+Sfps=
golang.org/x/net v0.0.0-20220921203646-d300de134e69/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -117,6 +127,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41 h1:ohgcoMbSofXygzo6AD2I1kz3BFmW1QArPYTtwEM3UXc=
golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View file

@ -3,31 +3,35 @@
package main
import (
"os"
"testing"
"github.com/stefansundin/go-zflag"
"gotest.tools/v3/assert"
)
func TestMain(t *testing.T) { //nolint: paralleltest // Race conditions
os.Stdout = os.NewFile(0, os.DevNull)
os.Stderr = os.NewFile(0, os.DevNull)
func TestRun(t *testing.T) {
t.Parallel()
args := []string{"awl", "+yaml", "@1.1.1.1"}
args := [][]string{
{"awl", "+yaml", "@1.1.1.1"},
{"awl", "+short", "@1.1.1.1"},
}
_, code, err := run(args)
assert.NilError(t, err)
assert.Equal(t, code, 0)
for _, test := range args {
test := test
args = []string{"awl", "+short", "@1.1.1.1"}
_, code, err = run(args)
assert.NilError(t, err)
assert.Equal(t, code, 0)
t.Run("", func(t *testing.T) {
t.Parallel()
_, code, err := run(test)
assert.NilError(t, err)
assert.Equal(t, code, 0)
})
}
}
func TestHelp(t *testing.T) {
t.Parallel()
args := []string{"awl", "-h"}
_, code, err := run(args)

View file

@ -77,8 +77,7 @@ func TestRealPrint(t *testing.T) {
Type: dns.StringToType["NS"],
Class: 1,
Name: "google.com.",
Timeout: 0,
Retries: 0,
Retries: 3,
},
EDNS: util.EDNS{
EnableEDNS: false,
@ -99,7 +98,7 @@ func TestRealPrint(t *testing.T) {
Authority: true,
Additional: true,
Statistics: true,
UcodeTranslate: false,
UcodeTranslate: true,
TTL: true,
HumanTTL: true,
ShowQuery: true,
@ -110,8 +109,7 @@ func TestRealPrint(t *testing.T) {
Type: dns.StringToType["NS"],
Class: 1,
Name: "freecumextremist.com.",
Timeout: 0,
Retries: 0,
Retries: 3,
},
EDNS: util.EDNS{
EnableEDNS: false,
@ -172,8 +170,7 @@ func TestRealPrint(t *testing.T) {
Type: dns.StringToType["A"],
Class: 1,
Name: "froth.zone.",
Timeout: 0,
Retries: 0,
Retries: 3,
},
EDNS: util.EDNS{
EnableEDNS: true,

View file

@ -14,115 +14,128 @@ import (
func TestCreateQ(t *testing.T) {
t.Parallel()
in := []util.Options{
tests := []struct {
name string
opts util.Options
}{
{
Logger: util.InitLogger(0),
HeaderFlags: util.HeaderFlags{
Z: true,
},
YAML: true,
Request: util.Request{
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeA,
Name: "example.com.",
},
Display: util.Display{
Comments: true,
Question: true,
Opt: true,
Answer: true,
Authority: true,
Additional: true,
Statistics: true,
ShowQuery: true,
},
EDNS: util.EDNS{
ZFlag: 1,
BufSize: 1500,
EnableEDNS: true,
Cookie: true,
DNSSEC: true,
Expire: true,
KeepOpen: true,
Nsid: true,
Padding: true,
Version: 0,
"1",
util.Options{
Logger: util.InitLogger(0),
HeaderFlags: util.HeaderFlags{
Z: true,
},
YAML: true,
Request: util.Request{
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeA,
Name: "example.com.",
Retries: 3,
},
Display: util.Display{
Comments: true,
Question: true,
Opt: true,
Answer: true,
Authority: true,
Additional: true,
Statistics: true,
ShowQuery: true,
},
EDNS: util.EDNS{
ZFlag: 1,
BufSize: 1500,
EnableEDNS: true,
Cookie: true,
DNSSEC: true,
Expire: true,
KeepOpen: true,
Nsid: true,
Padding: true,
Version: 0,
},
},
},
{
Logger: util.InitLogger(0),
HeaderFlags: util.HeaderFlags{
Z: true,
},
XML: true,
"2",
util.Options{
Logger: util.InitLogger(0),
HeaderFlags: util.HeaderFlags{
Z: true,
},
XML: true,
Request: util.Request{
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeA,
Name: "example.com.",
},
Display: util.Display{
Comments: true,
Question: true,
Opt: true,
Answer: true,
Authority: true,
Additional: true,
Statistics: true,
UcodeTranslate: true,
ShowQuery: true,
Request: util.Request{
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeA,
Name: "example.com.",
Retries: 3,
},
Display: util.Display{
Comments: true,
Question: true,
Opt: true,
Answer: true,
Authority: true,
Additional: true,
Statistics: true,
UcodeTranslate: true,
ShowQuery: true,
},
},
},
{
Logger: util.InitLogger(0),
JSON: true,
QUIC: true,
"3",
util.Options{
Logger: util.InitLogger(0),
JSON: true,
QUIC: true,
Request: util.Request{
Server: "dns.adguard.com",
Port: 853,
Type: dns.TypeA,
Name: "example.com.",
},
Display: util.Display{
Comments: true,
Question: true,
Opt: true,
Answer: true,
Authority: true,
Additional: true,
Statistics: true,
ShowQuery: true,
},
EDNS: util.EDNS{
EnableEDNS: true,
DNSSEC: true,
Cookie: true,
Expire: true,
Nsid: true,
Request: util.Request{
Server: "dns.adguard.com",
Port: 853,
Type: dns.TypeA,
Name: "example.com.",
Retries: 3,
},
Display: util.Display{
Comments: true,
Question: true,
Opt: true,
Answer: true,
Authority: true,
Additional: true,
Statistics: true,
ShowQuery: true,
},
EDNS: util.EDNS{
EnableEDNS: true,
DNSSEC: true,
Cookie: true,
Expire: true,
Nsid: true,
},
},
},
}
for _, opt := range in {
opt := opt
for _, test := range tests {
test := test
t.Run("", func(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
res, err := query.CreateQuery(opt)
res, err := query.CreateQuery(test.opts)
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
str, err := query.PrintSpecial(res, opt)
str, err := query.PrintSpecial(res, test.opts)
assert.NilError(t, err)
assert.Assert(t, str != "")
str, err = query.ToString(res, opt)
str, err = query.ToString(res, test.opts)
assert.NilError(t, err)
assert.Assert(t, str != "")
})

View file

@ -15,42 +15,49 @@ func TestDNSCrypt(t *testing.T) {
t.Parallel()
tests := []struct {
opt util.Options
name string
opts util.Options
}{
{
"Valid",
util.Options{
Logger: util.InitLogger(0),
DNSCrypt: true,
Request: util.Request{
Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
Type: dns.TypeA,
Name: "example.com.",
Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
Type: dns.TypeA,
Name: "example.com.",
Retries: 3,
},
},
},
{
"Valid (TCP)",
util.Options{
Logger: util.InitLogger(0),
DNSCrypt: true,
TCP: true,
IPv4: true,
Request: util.Request{
Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
Type: dns.TypeAAAA,
Name: "example.com.",
Server: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
Type: dns.TypeAAAA,
Name: "example.com.",
Retries: 3,
},
},
},
{
"Invalid",
util.Options{
Logger: util.InitLogger(0),
DNSCrypt: true,
TCP: true,
IPv4: true,
Request: util.Request{
Server: "QMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
Type: dns.TypeAAAA,
Name: "example.com.",
Server: "QMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
Type: dns.TypeAAAA,
Name: "example.com.",
Retries: 0,
},
},
},
@ -59,10 +66,10 @@ func TestDNSCrypt(t *testing.T) {
for _, test := range tests {
test := test
t.Run("", func(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
res, err := query.CreateQuery(test.opt)
res, err := query.CreateQuery(test.opts)
if err == nil {
assert.Assert(t, res != util.Response{})
} else {

View file

@ -52,7 +52,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
}
if res.StatusCode != http.StatusOK {
return util.Response{}, &errHTTPStatus{res.StatusCode}
return util.Response{}, &ErrHTTPStatus{res.StatusCode}
}
resolver.opts.Logger.Debug("https: reading response")
@ -79,10 +79,11 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (util.Response, error) {
return resp, nil
}
type errHTTPStatus struct {
// ErrHTTPStatus is returned when DoH returns a bad status code.
type ErrHTTPStatus struct {
code int
}
func (e *errHTTPStatus) Error() string {
func (e *ErrHTTPStatus) Error() string {
return fmt.Sprintf("doh server responded with HTTP %d", e.code)
}

View file

@ -3,6 +3,7 @@
package resolvers_test
import (
"errors"
"testing"
"git.froth.zone/sam/awl/pkg/resolvers"
@ -11,105 +12,87 @@ import (
"gotest.tools/v3/assert"
)
func TestResolveHTTPS(t *testing.T) {
func TestHTTPS(t *testing.T) {
t.Parallel()
var err error
opts := util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "https://dns9.quad9.net/dns-query",
Type: dns.TypeA,
Name: "git.froth.zone.",
tests := []struct {
name string
opts util.Options
}{
{
"Good",
util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "https://dns9.quad9.net/dns-query",
Type: dns.TypeA,
Name: "git.froth.zone.",
Retries: 3,
},
},
},
{
"404",
util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "https://dns9.quad9.net/dns",
Type: dns.TypeA,
Name: "git.froth.zone.",
},
},
},
{
"Bad request domain",
util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "dns9.quad9.net/dns-query",
Type: dns.TypeA,
Name: "git.froth.zone",
},
},
},
{
"Bad server domain",
util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "dns9..quad9.net/dns-query",
Type: dns.TypeA,
Name: "git.froth.zone.",
},
},
},
}
// testCase := util.Request{Server: "https://dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."}
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
msg := new(dns.Msg)
msg.SetQuestion(opts.Request.Name, opts.Request.Type)
// msg = msg.SetQuestion(testCase.Name, testCase.Type)
res, err := resolver.LookUp(msg)
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
}
for _, test := range tests {
test := test
func Test2ResolveHTTPS(t *testing.T) {
t.Parallel()
t.Run(test.name, func(t *testing.T) {
t.Parallel()
opts := util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{Server: "dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."},
resolver, err := resolvers.LoadResolver(test.opts)
assert.NilError(t, err)
msg := new(dns.Msg)
msg.SetQuestion(test.opts.Request.Name, test.opts.Request.Type)
// msg = msg.SetQuestion(testCase.Name, testCase.Type)
res, err := resolver.LookUp(msg)
if err == nil {
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
} else {
if errors.Is(err, &resolvers.ErrHTTPStatus{}) {
assert.ErrorContains(t, err, "404")
}
assert.Equal(t, res, util.Response{})
}
})
}
var err error
testCase := util.Request{Type: dns.TypeA, Name: "git.froth.zone"}
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
msg := new(dns.Msg)
msg.SetQuestion(testCase.Name, testCase.Type)
// msg = msg.SetQuestion(testCase.Name, testCase.Type)
res, err := resolver.LookUp(msg)
assert.ErrorContains(t, err, "fully qualified")
assert.Equal(t, res, util.Response{})
}
func Test3ResolveHTTPS(t *testing.T) {
t.Parallel()
opts := util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{Server: "dns9..quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."},
}
var err error
// testCase :=
// if the domain is not canonical, make it canonical
// if !strings.HasSuffix(testCase.Name, ".") {
// testCase.Name = fmt.Sprintf("%s.", testCase.Name)
// }
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
msg := new(dns.Msg)
// msg.SetQuestion(testCase.Name, testCase.Type)
// msg = msg.SetQuestion(testCase.Name, testCase.Type)
res, err := resolver.LookUp(msg)
assert.ErrorContains(t, err, "doh: HTTP request")
assert.Equal(t, res, util.Response{})
}
func Test404ResolveHTTPS(t *testing.T) {
t.Parallel()
var err error
opts := util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "https://dns9.quad9.net/dns",
Type: dns.TypeA,
Name: "git.froth.zone.",
},
}
// testCase := util.Request{Server: "https://dns9.quad9.net/dns-query", Type: dns.TypeA, Name: "git.froth.zone."}
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
msg := new(dns.Msg)
msg.SetQuestion(opts.Request.Name, opts.Request.Type)
// msg = msg.SetQuestion(testCase.Name, testCase.Type)
res, err := resolver.LookUp(msg)
assert.ErrorContains(t, err, "404")
assert.Equal(t, res, util.Response{})
}

View file

@ -3,10 +3,6 @@
package resolvers_test
import (
"fmt"
"net"
"strconv"
"strings"
"testing"
"time"
@ -19,72 +15,90 @@ import (
func TestQuic(t *testing.T) {
t.Parallel()
opts := util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{Server: "dns.adguard.com", Port: 853},
tests := []struct {
name string
opts util.Options
}{
{
"Valid",
util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "dns.adguard.com",
Type: dns.TypeNS,
Port: 853,
Timeout: 750 * time.Millisecond,
Retries: 3,
},
},
},
{
"Bad domain",
util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "dns.//./,,adguard\a.com",
Port: 853,
Type: dns.TypeA,
Name: "git.froth.zone",
Timeout: 100 * time.Millisecond,
Retries: 0,
},
},
},
{
"Not canonical",
util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "dns.adguard.com",
Port: 853,
Type: dns.TypeA,
Name: "git.froth.zone",
Timeout: 100 * time.Millisecond,
Retries: 0,
},
},
},
{
"Invalid query domain",
util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{
Server: "example.com",
Port: 853,
Type: dns.TypeA,
Name: "git.froth.zone",
Timeout: 10 * time.Millisecond,
},
},
},
}
testCase := util.Request{Server: "dns.//./,,adguard.com", Type: dns.TypeA, Name: "git.froth.zone"}
testCase2 := util.Request{Server: "dns.adguard.com", Type: dns.TypeA, Name: "git.froth.zone"}
var testCases []util.Request
for _, test := range tests {
test := test
testCases = append(testCases, testCase)
testCases = append(testCases, testCase2)
t.Run(test.name, func(t *testing.T) {
t.Parallel()
for i := range testCases {
switch i {
case 0:
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
// if the domain is not canonical, make it canonical
if !strings.HasSuffix(testCase.Name, ".") {
testCases[i].Name = fmt.Sprintf("%s.", testCases[i].Name)
}
msg := new(dns.Msg)
msg.SetQuestion(testCase.Name, testCase.Type)
// msg = msg.SetQuestion(testCase.Name, testCase.Type)
res, err := resolver.LookUp(msg)
assert.ErrorContains(t, err, "fully qualified")
assert.Equal(t, res, util.Response{})
case 1:
resolver, err := resolvers.LoadResolver(opts)
resolver, err := resolvers.LoadResolver(test.opts)
assert.NilError(t, err)
testCase2.Server = net.JoinHostPort(testCase2.Server, strconv.Itoa(opts.Request.Port))
// if the domain is not canonical, make it canonical
if !strings.HasSuffix(testCase2.Name, ".") {
testCase2.Name = fmt.Sprintf("%s.", testCase2.Name)
}
msg := new(dns.Msg)
msg.SetQuestion(testCase2.Name, testCase2.Type)
msg.SetQuestion(test.opts.Request.Name, test.opts.Request.Type)
res, err := resolver.LookUp(msg)
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
}
if err == nil {
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
} else {
assert.Assert(t, res == util.Response{})
}
})
}
}
func TestInvalidQuic(t *testing.T) {
t.Parallel()
opts := util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{Server: "example.com", Port: 853, Type: dns.TypeA, Name: "git.froth.zone", Timeout: 10 * time.Millisecond},
}
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
msg := new(dns.Msg)
msg.SetQuestion(opts.Request.Name, opts.Request.Type)
res, err := resolver.LookUp(msg)
assert.ErrorContains(t, err, "timeout")
assert.Equal(t, res, util.Response{})
}

View file

@ -3,11 +3,11 @@
package resolvers_test
import (
"os"
"testing"
"time"
"git.froth.zone/sam/awl/pkg/query"
"git.froth.zone/sam/awl/pkg/resolvers"
"git.froth.zone/sam/awl/pkg/util"
"github.com/miekg/dns"
"gotest.tools/v3/assert"
@ -16,90 +16,95 @@ import (
func TestResolve(t *testing.T) {
t.Parallel()
opts := util.Options{
Logger: util.InitLogger(0),
Request: util.Request{
Server: "8.8.4.1",
Port: 1,
Type: dns.TypeA,
Name: "example.com.",
Timeout: time.Second / 2,
Retries: 0,
},
}
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
msg := new(dns.Msg)
msg.SetQuestion(opts.Request.Name, opts.Request.Type)
_, err = resolver.LookUp(msg)
assert.ErrorContains(t, err, "timeout")
}
func TestTruncate(t *testing.T) {
t.Parallel()
opts := util.Options{
Logger: util.InitLogger(0),
IPv4: true,
Request: util.Request{
Server: "madns.binarystar.systems",
Port: 5301,
Type: dns.TypeTXT,
Name: "limit.txt.example.",
},
}
resolver, err := resolvers.LoadResolver(opts)
assert.NilError(t, err)
msg := new(dns.Msg)
msg.SetQuestion(opts.Request.Name, opts.Request.Type)
res, err := resolver.LookUp(msg)
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
}
func TestResolveAgain(t *testing.T) {
t.Parallel()
tests := []struct {
opt util.Options
name string
opts util.Options
}{
{
"UDP",
util.Options{
Logger: util.InitLogger(0),
Request: util.Request{
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeAAAA,
Name: "example.com.",
Retries: 3,
},
},
},
{
"UDP (Bad Cookie)",
util.Options{
Logger: util.InitLogger(0),
BadCookie: false,
Request: util.Request{
Server: "b.root-servers.net",
Port: 53,
Type: dns.TypeNS,
Name: "example.com.",
Retries: 3,
},
EDNS: util.EDNS{
EnableEDNS: true,
Cookie: true,
},
},
},
{
"UDP (Truncated)",
util.Options{
Logger: util.InitLogger(0),
IPv4: true,
Request: util.Request{
Server: "madns.binarystar.systems",
Port: 5301,
Type: dns.TypeTXT,
Name: "limit.txt.example.",
Retries: 3,
},
},
},
{
"TCP",
util.Options{
Logger: util.InitLogger(0),
TCP: true,
Request: util.Request{
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeA,
Name: "example.com.",
},
},
},
{
util.Options{
Logger: util.InitLogger(0),
Request: util.Request{
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeAAAA,
Name: "example.com.",
Server: "8.8.4.4",
Port: 53,
Type: dns.TypeA,
Name: "example.com.",
Retries: 3,
},
},
},
{
"TLS",
util.Options{
Logger: util.InitLogger(0),
TLS: true,
Request: util.Request{
Server: "dns.google",
Port: 853,
Type: dns.TypeAAAA,
Name: "example.com.",
Server: "dns.google",
Port: 853,
Type: dns.TypeAAAA,
Name: "example.com.",
Retries: 3,
},
},
},
{
"Timeout",
util.Options{
Logger: util.InitLogger(0),
Request: util.Request{
Server: "8.8.4.1",
Port: 1,
Type: dns.TypeA,
Name: "example.com.",
Timeout: time.Millisecond * 100,
Retries: 0,
},
},
},
@ -108,11 +113,16 @@ func TestResolveAgain(t *testing.T) {
for _, test := range tests {
test := test
t.Run("", func(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
res, err := query.CreateQuery(test.opt)
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
res, err := query.CreateQuery(test.opts)
if err == nil {
assert.NilError(t, err)
assert.Assert(t, res != util.Response{})
} else {
assert.ErrorIs(t, err, os.ErrDeadlineExceeded)
}
})
}
}

View file

@ -17,21 +17,20 @@ func TestSubnet(t *testing.T) {
"::0/0",
"0",
"127.0.0.1/32",
"Invalid",
}
for _, test := range subnet {
test := test
t.Run(test, func(t *testing.T) {
t.Parallel()
err := util.ParseSubnet(test, new(util.Options))
assert.NilError(t, err)
if err != nil {
assert.ErrorContains(t, err, "invalid CIDR address")
} else {
assert.NilError(t, err)
}
})
}
}
func TestInvalidSub(t *testing.T) {
t.Parallel()
err := util.ParseSubnet("1", new(util.Options))
assert.ErrorContains(t, err, "invalid CIDR address")
}

View file

@ -10,25 +10,43 @@ import (
"gotest.tools/v3/assert"
)
var (
PTR = dns.StringToType["PTR"]
NAPTR = dns.StringToType["NAPTR"]
)
func TestIPv4(t *testing.T) {
func TestPTR(t *testing.T) {
t.Parallel()
act, err := util.ReverseDNS("8.8.4.4", PTR)
assert.NilError(t, err)
assert.Equal(t, act, "4.4.8.8.in-addr.arpa.", "IPv4 reverse")
}
tests := []struct {
name string
in string
expected string
}{
{
"IPv4",
"8.8.4.4", "4.4.8.8.in-addr.arpa.",
},
{
"IPv6",
"2606:4700:4700::1111", "1.1.1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.4.0.0.7.4.6.0.6.2.ip6.arpa.",
},
{
"Inavlid value",
"AAAAA", "",
},
}
func TestIPv6(t *testing.T) {
t.Parallel()
for _, test := range tests {
test := test
act, err := util.ReverseDNS("2606:4700:4700::1111", PTR)
assert.NilError(t, err)
assert.Equal(t, act, "1.1.1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.4.0.0.7.4.6.0.6.2.ip6.arpa.", "IPv6 reverse")
t.Run(test.name, func(t *testing.T) {
t.Parallel()
act, err := util.ReverseDNS(test.in, dns.StringToType["PTR"])
if err == nil {
assert.NilError(t, err)
} else {
assert.ErrorContains(t, err, "unrecognized address")
}
assert.Equal(t, act, test.expected)
})
}
}
func TestNAPTR(t *testing.T) {
@ -48,23 +66,14 @@ func TestNAPTR(t *testing.T) {
test := test
t.Run(test.in, func(t *testing.T) {
t.Parallel()
act, err := util.ReverseDNS(test.in, NAPTR)
act, err := util.ReverseDNS(test.in, dns.StringToType["NAPTR"])
assert.NilError(t, err)
assert.Equal(t, test.want, act)
})
}
}
func TestInvalid(t *testing.T) {
t.Parallel()
_, err := util.ReverseDNS("AAAAA", 1)
assert.ErrorContains(t, err, "invalid value AAAAA given")
}
func TestInvalid2(t *testing.T) {
t.Parallel()
_, err := util.ReverseDNS("1.0", PTR)
assert.ErrorContains(t, err, "PTR reverse")
func TestInvalidAll(t *testing.T) {
_, err := util.ReverseDNS("q", 15236)
assert.ErrorContains(t, err, "invalid value")
}