193 lines
4.6 KiB
Go
193 lines
4.6 KiB
Go
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package cli_test
|
|
|
|
import (
|
|
"testing"
|
|
|
|
cli "git.froth.zone/sam/awl/cmd"
|
|
"git.froth.zone/sam/awl/pkg/util"
|
|
"github.com/miekg/dns"
|
|
"gotest.tools/v3/assert"
|
|
)
|
|
|
|
func TestParseArgs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
args := []string{
|
|
"go.dev",
|
|
"AAAA",
|
|
"@1.1.1.1",
|
|
"+ignore",
|
|
}
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
err := cli.ParseMiscArgs(args, opts)
|
|
assert.NilError(t, err)
|
|
assert.Equal(t, opts.Request.Name, "go.dev.")
|
|
assert.Equal(t, opts.Request.Type, dns.StringToType["AAAA"])
|
|
assert.Equal(t, opts.Request.Server, "1.1.1.1")
|
|
assert.Equal(t, opts.Truncate, true)
|
|
}
|
|
|
|
func TestParseNoInput(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
args := []string{}
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
err := cli.ParseMiscArgs(args, opts)
|
|
assert.NilError(t, err)
|
|
assert.Equal(t, opts.Request.Name, ".")
|
|
assert.Equal(t, opts.Request.Type, dns.StringToType["NS"])
|
|
}
|
|
|
|
func TestParseA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
args := []string{
|
|
"golang.org.",
|
|
}
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
err := cli.ParseMiscArgs(args, opts)
|
|
assert.NilError(t, err)
|
|
assert.Equal(t, opts.Request.Name, "golang.org.")
|
|
assert.Equal(t, opts.Request.Type, dns.StringToType["A"])
|
|
}
|
|
|
|
func TestParsePTR(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
args := []string{"8.8.8.8"}
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
opts.Reverse = true
|
|
err := cli.ParseMiscArgs(args, opts)
|
|
assert.NilError(t, err)
|
|
assert.Equal(t, opts.Request.Type, dns.StringToType["PTR"])
|
|
}
|
|
|
|
func TestParseInvalidPTR(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
args := []string{"8.88.8"}
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
opts.Reverse = true
|
|
err := cli.ParseMiscArgs(args, opts)
|
|
assert.ErrorContains(t, err, "unrecognized address")
|
|
}
|
|
|
|
func TestDefaultServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
in string
|
|
want string
|
|
}{
|
|
{"DNSCrypt", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"},
|
|
{"TLS", "dns.google"},
|
|
{"HTTPS", "https://dns.cloudflare.com"},
|
|
{"QUIC", "dns.adguard.com"},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
test := test
|
|
|
|
t.Run(test.in, func(t *testing.T) {
|
|
t.Parallel()
|
|
args := []string{}
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
switch test.in {
|
|
case "DNSCrypt":
|
|
opts.DNSCrypt = true
|
|
case "TLS":
|
|
opts.TLS = true
|
|
case "HTTPS":
|
|
opts.HTTPS = true
|
|
case "QUIC":
|
|
opts.QUIC = true
|
|
}
|
|
err := cli.ParseMiscArgs(args, opts)
|
|
assert.NilError(t, err)
|
|
assert.Equal(t, opts.Request.Server, test.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFlagSetting(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
in string
|
|
expected string
|
|
over string
|
|
}{
|
|
{"@sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", "DNSCrypt"},
|
|
{"@tls://dns.google", "dns.google", "TLS"},
|
|
{"@https://dns.cloudflare.com/dns-query", "https://dns.cloudflare.com/dns-query", "HTTPS"},
|
|
{"@quic://dns.adguard.com", "dns.adguard.com", "QUIC"},
|
|
{"@tcp://dns.froth.zone", "dns.froth.zone", "TCP"},
|
|
{"@udp://dns.example.com", "dns.example.com", "UDP"},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
test := test
|
|
|
|
t.Run(test.over, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
|
|
err := cli.ParseMiscArgs([]string{test.in}, opts)
|
|
assert.NilError(t, err)
|
|
switch test.over {
|
|
case "DNSCrypt":
|
|
assert.Assert(t, opts.DNSCrypt)
|
|
assert.Equal(t, opts.Request.Server, test.expected)
|
|
case "TLS":
|
|
assert.Assert(t, opts.TLS)
|
|
assert.Equal(t, opts.Request.Server, test.expected)
|
|
case "HTTPS":
|
|
assert.Assert(t, opts.HTTPS)
|
|
assert.Equal(t, opts.Request.Server, test.expected)
|
|
case "QUIC":
|
|
assert.Assert(t, opts.QUIC)
|
|
assert.Equal(t, opts.Request.Server, test.expected)
|
|
case "TCP":
|
|
assert.Assert(t, opts.TCP)
|
|
assert.Equal(t, opts.Request.Server, test.expected)
|
|
case "UDP":
|
|
assert.Assert(t, true)
|
|
assert.Equal(t, opts.Request.Server, test.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func FuzzParseArgs(f *testing.F) {
|
|
cases := []string{
|
|
"go.dev",
|
|
"AAAA",
|
|
"@1.1.1.1",
|
|
"+ignore",
|
|
"e",
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
f.Add(tc)
|
|
}
|
|
|
|
f.Fuzz(func(t *testing.T, arg string) {
|
|
// Get rid of outputs
|
|
|
|
args := []string{arg}
|
|
opts := new(util.Options)
|
|
opts.Logger = util.InitLogger(0)
|
|
//nolint:errcheck,gosec // Only make sure the program does not crash
|
|
cli.ParseMiscArgs(args, opts)
|
|
})
|
|
}
|