refactor: Make all calls to options pointers
Signed-off-by: Sam Therapy <sam@samtherapy.net>
This commit is contained in:
parent
d2c6ed317e
commit
2a61464e17
16 changed files with 54 additions and 52 deletions
14
cmd/cli.go
14
cmd/cli.go
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
// ParseCLI parses arguments given from the CLI and passes them into an `Options`
|
||||
// struct.
|
||||
func ParseCLI(args []string, version string) (util.Options, error) {
|
||||
func ParseCLI(args []string, version string) (*util.Options, error) {
|
||||
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
|
||||
|
||||
flagSet.Usage = func() {
|
||||
|
@ -105,13 +105,13 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
|
||||
// Parse the flags
|
||||
if err := flagSet.Parse(args[1:]); err != nil {
|
||||
return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err)
|
||||
return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err)
|
||||
}
|
||||
|
||||
// TODO: DRY, dumb dumb.
|
||||
mbz, err := strconv.ParseInt(*mbzflag, 0, 16)
|
||||
if err != nil {
|
||||
return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err)
|
||||
return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err)
|
||||
}
|
||||
|
||||
opts := util.Options{
|
||||
|
@ -179,7 +179,7 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
// TODO: DRY
|
||||
if *subnet != "" {
|
||||
if err = util.ParseSubnet(*subnet, &opts); err != nil {
|
||||
return opts, fmt.Errorf("%w", err)
|
||||
return &opts, fmt.Errorf("%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,14 +189,14 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
if *versionFlag {
|
||||
fmt.Printf("awl version %s, built with %s\n", version, runtime.Version())
|
||||
|
||||
return opts, ErrNotError
|
||||
return &opts, ErrNotError
|
||||
}
|
||||
|
||||
// Parse all the arguments that don't start with - or --
|
||||
// This includes the dig-style (+) options
|
||||
err = ParseMiscArgs(flagSet.Args(), &opts)
|
||||
if err != nil {
|
||||
return opts, err
|
||||
return &opts, err
|
||||
}
|
||||
|
||||
opts.Logger.Info("Dig/Drill flags parsed")
|
||||
|
@ -224,7 +224,7 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
opts.Logger.Info("Options fully populated")
|
||||
opts.Logger.Debug(fmt.Sprintf("%+v", opts))
|
||||
|
||||
return opts, nil
|
||||
return &opts, nil
|
||||
}
|
||||
|
||||
// ErrNotError is for returning not error.
|
||||
|
|
2
main.go
2
main.go
|
@ -28,7 +28,7 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
func run(args []string) (opts util.Options, code int, err error) {
|
||||
func run(args []string) (opts *util.Options, code int, err error) {
|
||||
opts, err = cli.ParseCLI(args, version)
|
||||
if err != nil {
|
||||
return opts, 1, fmt.Errorf("parse: %w", err)
|
||||
|
|
|
@ -10,8 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func TestRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// t.Parallel()
|
||||
args := [][]string{
|
||||
{"awl", "+yaml", "@1.1.1.1"},
|
||||
{"awl", "+short", "@1.1.1.1"},
|
||||
|
@ -21,7 +20,6 @@ func TestRun(t *testing.T) {
|
|||
test := test
|
||||
|
||||
t.Run("", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, code, err := run(test)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, code, 0)
|
||||
|
@ -30,8 +28,7 @@ func TestRun(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHelp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// t.Parallel()
|
||||
args := []string{"awl", "-h"}
|
||||
|
||||
_, code, err := run(args)
|
||||
|
|
|
@ -21,7 +21,7 @@ import (
|
|||
// ToString turns the response into something that looks a lot like dig
|
||||
//
|
||||
// Much of this is taken from https://github.com/miekg/dns/blob/master/msg.go#L900
|
||||
func ToString(res util.Response, opts util.Options) (string, error) {
|
||||
func ToString(res util.Response, opts *util.Options) (string, error) {
|
||||
if res.DNS == nil {
|
||||
return "<nil> MsgHdr", errNoMessage
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func ToString(res util.Response, opts util.Options) (string, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func serverExtra(opts util.Options) string {
|
||||
func serverExtra(opts *util.Options) string {
|
||||
// Add extra information to server string
|
||||
var extra string
|
||||
|
||||
|
@ -167,7 +167,7 @@ func serverExtra(opts util.Options) string {
|
|||
}
|
||||
|
||||
// stringParse edits the raw responses to user requests.
|
||||
func stringParse(str string, isAns bool, opts util.Options) (string, error) {
|
||||
func stringParse(str string, isAns bool, opts *util.Options) (string, error) {
|
||||
split := strings.Split(str, "\t")
|
||||
|
||||
// Make edits if so requested
|
||||
|
@ -220,7 +220,7 @@ func stringParse(str string, isAns bool, opts util.Options) (string, error) {
|
|||
|
||||
// PrintSpecial is for printing as JSON, XML or YAML.
|
||||
// As of now JSON and XML use the stdlib version.
|
||||
func PrintSpecial(res util.Response, opts util.Options) (string, error) {
|
||||
func PrintSpecial(res util.Response, opts *util.Options) (string, error) {
|
||||
formatted, err := MakePrintable(res, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -252,7 +252,7 @@ func PrintSpecial(res util.Response, opts util.Options) (string, error) {
|
|||
|
||||
// MakePrintable takes a DNS message and makes it nicer to be printed as JSON,YAML,
|
||||
// and XML. Little is changed beyond naming.
|
||||
func MakePrintable(res util.Response, opts util.Options) (*Message, error) {
|
||||
func MakePrintable(res util.Response, opts *util.Options) (*Message, error) {
|
||||
var (
|
||||
err error
|
||||
msg = res.DNS
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
func TestRealPrint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
opts := []util.Options{
|
||||
opts := []*util.Options{
|
||||
{
|
||||
Logger: util.InitLogger(0),
|
||||
|
||||
|
@ -216,14 +216,14 @@ func TestRealPrint(t *testing.T) {
|
|||
func TestBadFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, util.Options{})
|
||||
_, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, new(util.Options))
|
||||
assert.ErrorContains(t, err, "never happen")
|
||||
}
|
||||
|
||||
func TestEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
str, err := query.ToString(util.Response{}, util.Options{})
|
||||
str, err := query.ToString(util.Response{}, new(util.Options))
|
||||
|
||||
assert.Error(t, err, "no message")
|
||||
assert.Assert(t, str == "<nil> MsgHdr")
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
// CreateQuery creates a DNS query from the options given.
|
||||
// It sets query flags and EDNS flags from the respective options.
|
||||
func CreateQuery(opts util.Options) (util.Response, error) {
|
||||
func CreateQuery(opts *util.Options) (util.Response, error) {
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion(opts.Request.Name, opts.Request.Type)
|
||||
req.Question[0].Qclass = opts.Request.Class
|
||||
|
|
|
@ -14,13 +14,14 @@ import (
|
|||
func TestCreateQ(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"1",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
HeaderFlags: util.HeaderFlags{
|
||||
Z: true,
|
||||
|
@ -59,7 +60,7 @@ func TestCreateQ(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"2",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
HeaderFlags: util.HeaderFlags{
|
||||
Z: true,
|
||||
|
@ -88,7 +89,7 @@ func TestCreateQ(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"3",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
JSON: true,
|
||||
QUIC: true,
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
// DNSCryptResolver is for making DNSCrypt queries.
|
||||
type DNSCryptResolver struct {
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*DNSCryptResolver)(nil)
|
||||
|
|
|
@ -16,13 +16,14 @@ import (
|
|||
func TestDNSCrypt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"Valid",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
DNSCrypt: true,
|
||||
Request: util.Request{
|
||||
|
@ -35,7 +36,7 @@ func TestDNSCrypt(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Valid (TCP)",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
DNSCrypt: true,
|
||||
TCP: true,
|
||||
|
@ -50,7 +51,7 @@ func TestDNSCrypt(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Invalid",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
DNSCrypt: true,
|
||||
TCP: true,
|
||||
|
|
|
@ -16,8 +16,8 @@ import (
|
|||
|
||||
// HTTPSResolver is for DNS-over-HTTPS queries.
|
||||
type HTTPSResolver struct {
|
||||
opts *util.Options
|
||||
client http.Client
|
||||
opts util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*HTTPSResolver)(nil)
|
||||
|
|
|
@ -16,13 +16,14 @@ import (
|
|||
func TestHTTPS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"Good",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -35,7 +36,7 @@ func TestHTTPS(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"404",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -47,7 +48,7 @@ func TestHTTPS(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Bad request domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -59,7 +60,7 @@ func TestHTTPS(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Bad server domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
// QUICResolver is for DNS-over-QUIC queries.
|
||||
type QUICResolver struct {
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*QUICResolver)(nil)
|
||||
|
|
|
@ -15,13 +15,14 @@ import (
|
|||
func TestQuic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"Valid",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -35,7 +36,7 @@ func TestQuic(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Bad domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -50,7 +51,7 @@ func TestQuic(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Not canonical",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -65,7 +66,7 @@ func TestQuic(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Invalid query domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
// StandardResolver is for UDP/TCP resolvers.
|
||||
type StandardResolver struct {
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*StandardResolver)(nil)
|
||||
|
|
|
@ -18,13 +18,14 @@ import (
|
|||
func TestResolve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"UDP",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
Server: "8.8.4.4",
|
||||
|
@ -37,7 +38,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"UDP (Bad Cookie)",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
BadCookie: false,
|
||||
Request: util.Request{
|
||||
|
@ -55,7 +56,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"UDP (Truncated)",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
IPv4: true,
|
||||
Request: util.Request{
|
||||
|
@ -69,7 +70,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"TCP",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
TCP: true,
|
||||
|
||||
|
@ -84,7 +85,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"TLS",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
TLS: true,
|
||||
Request: util.Request{
|
||||
|
@ -98,7 +99,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Timeout",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
Server: "8.8.4.1",
|
||||
|
|
|
@ -22,7 +22,7 @@ type Resolver interface {
|
|||
}
|
||||
|
||||
// LoadResolver loads the respective resolver for performing a DNS query.
|
||||
func LoadResolver(opts util.Options) (Resolver, error) {
|
||||
func LoadResolver(opts *util.Options) (Resolver, error) {
|
||||
switch {
|
||||
case opts.HTTPS:
|
||||
opts.Logger.Info("loading DNS-over-HTTPS resolver")
|
||||
|
|
Loading…
Reference in a new issue