From 81da49093daace008c9076517e2217a50b1d0943 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 13 Oct 2022 12:49:36 +0000 Subject: [PATCH] refactor: Make all calls to options pointers (#132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of copying the opts struct every time it gets passed around, it should be created once and passed through reference. This should reduce memory utilization, unfortunately I cannot test it since this program runs so fast pprof won't report anything useful. I think I found all of them 🙂 Co-authored-by: Sam Therapy Reviewed-on: https://git.froth.zone/sam/awl/pulls/132 Reviewed-by: grumbulon --- cmd/cli.go | 14 +++++++------- main.go | 2 +- main_test.go | 7 ++----- pkg/query/print.go | 10 +++++----- pkg/query/print_test.go | 6 +++--- pkg/query/query.go | 2 +- pkg/query/query_test.go | 9 +++++---- pkg/resolvers/DNSCrypt.go | 2 +- pkg/resolvers/DNSCrypt_test.go | 9 +++++---- pkg/resolvers/HTTPS.go | 2 +- pkg/resolvers/HTTPS_test.go | 11 ++++++----- pkg/resolvers/QUIC.go | 2 +- pkg/resolvers/QUIC_test.go | 11 ++++++----- pkg/resolvers/general.go | 2 +- pkg/resolvers/general_test.go | 15 ++++++++------- pkg/resolvers/resolver.go | 2 +- 16 files changed, 54 insertions(+), 52 deletions(-) diff --git a/cmd/cli.go b/cmd/cli.go index 14a1161..8b084e5 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -17,7 +17,7 @@ import ( // ParseCLI parses arguments given from the CLI and passes them into an `Options` // struct. -func ParseCLI(args []string, version string) (util.Options, error) { +func ParseCLI(args []string, version string) (*util.Options, error) { flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError) flagSet.Usage = func() { @@ -105,13 +105,13 @@ func ParseCLI(args []string, version string) (util.Options, error) { // Parse the flags if err := flagSet.Parse(args[1:]); err != nil { - return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err) + return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err) } // TODO: DRY, dumb dumb. mbz, err := strconv.ParseInt(*mbzflag, 0, 16) if err != nil { - return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err) + return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err) } opts := util.Options{ @@ -179,7 +179,7 @@ func ParseCLI(args []string, version string) (util.Options, error) { // TODO: DRY if *subnet != "" { if err = util.ParseSubnet(*subnet, &opts); err != nil { - return opts, fmt.Errorf("%w", err) + return &opts, fmt.Errorf("%w", err) } } @@ -189,14 +189,14 @@ func ParseCLI(args []string, version string) (util.Options, error) { if *versionFlag { fmt.Printf("awl version %s, built with %s\n", version, runtime.Version()) - return opts, ErrNotError + return &opts, ErrNotError } // Parse all the arguments that don't start with - or -- // This includes the dig-style (+) options err = ParseMiscArgs(flagSet.Args(), &opts) if err != nil { - return opts, err + return &opts, err } opts.Logger.Info("Dig/Drill flags parsed") @@ -224,7 +224,7 @@ func ParseCLI(args []string, version string) (util.Options, error) { opts.Logger.Info("Options fully populated") opts.Logger.Debug(fmt.Sprintf("%+v", opts)) - return opts, nil + return &opts, nil } // ErrNotError is for returning not error. diff --git a/main.go b/main.go index 448b8a1..5874045 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,7 @@ func main() { } } -func run(args []string) (opts util.Options, code int, err error) { +func run(args []string) (opts *util.Options, code int, err error) { opts, err = cli.ParseCLI(args, version) if err != nil { return opts, 1, fmt.Errorf("parse: %w", err) diff --git a/main_test.go b/main_test.go index 58f8e07..949230c 100644 --- a/main_test.go +++ b/main_test.go @@ -10,8 +10,7 @@ import ( ) func TestRun(t *testing.T) { - t.Parallel() - + // t.Parallel() args := [][]string{ {"awl", "+yaml", "@1.1.1.1"}, {"awl", "+short", "@1.1.1.1"}, @@ -21,7 +20,6 @@ func TestRun(t *testing.T) { test := test t.Run("", func(t *testing.T) { - t.Parallel() _, code, err := run(test) assert.NilError(t, err) assert.Equal(t, code, 0) @@ -30,8 +28,7 @@ func TestRun(t *testing.T) { } func TestHelp(t *testing.T) { - t.Parallel() - + // t.Parallel() args := []string{"awl", "-h"} _, code, err := run(args) diff --git a/pkg/query/print.go b/pkg/query/print.go index d46d40d..084ad82 100644 --- a/pkg/query/print.go +++ b/pkg/query/print.go @@ -21,7 +21,7 @@ import ( // ToString turns the response into something that looks a lot like dig // // Much of this is taken from https://github.com/miekg/dns/blob/master/msg.go#L900 -func ToString(res util.Response, opts util.Options) (string, error) { +func ToString(res util.Response, opts *util.Options) (string, error) { if res.DNS == nil { return " MsgHdr", errNoMessage } @@ -146,7 +146,7 @@ func ToString(res util.Response, opts util.Options) (string, error) { return s, nil } -func serverExtra(opts util.Options) string { +func serverExtra(opts *util.Options) string { // Add extra information to server string var extra string @@ -167,7 +167,7 @@ func serverExtra(opts util.Options) string { } // stringParse edits the raw responses to user requests. -func stringParse(str string, isAns bool, opts util.Options) (string, error) { +func stringParse(str string, isAns bool, opts *util.Options) (string, error) { split := strings.Split(str, "\t") // Make edits if so requested @@ -220,7 +220,7 @@ func stringParse(str string, isAns bool, opts util.Options) (string, error) { // PrintSpecial is for printing as JSON, XML or YAML. // As of now JSON and XML use the stdlib version. -func PrintSpecial(res util.Response, opts util.Options) (string, error) { +func PrintSpecial(res util.Response, opts *util.Options) (string, error) { formatted, err := MakePrintable(res, opts) if err != nil { return "", err @@ -252,7 +252,7 @@ func PrintSpecial(res util.Response, opts util.Options) (string, error) { // MakePrintable takes a DNS message and makes it nicer to be printed as JSON,YAML, // and XML. Little is changed beyond naming. -func MakePrintable(res util.Response, opts util.Options) (*Message, error) { +func MakePrintable(res util.Response, opts *util.Options) (*Message, error) { var ( err error msg = res.DNS diff --git a/pkg/query/print_test.go b/pkg/query/print_test.go index 4b3c8b8..386c5ac 100644 --- a/pkg/query/print_test.go +++ b/pkg/query/print_test.go @@ -14,7 +14,7 @@ import ( func TestRealPrint(t *testing.T) { t.Parallel() - opts := []util.Options{ + opts := []*util.Options{ { Logger: util.InitLogger(0), @@ -216,14 +216,14 @@ func TestRealPrint(t *testing.T) { func TestBadFormat(t *testing.T) { t.Parallel() - _, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, util.Options{}) + _, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, new(util.Options)) assert.ErrorContains(t, err, "never happen") } func TestEmpty(t *testing.T) { t.Parallel() - str, err := query.ToString(util.Response{}, util.Options{}) + str, err := query.ToString(util.Response{}, new(util.Options)) assert.Error(t, err, "no message") assert.Assert(t, str == " MsgHdr") diff --git a/pkg/query/query.go b/pkg/query/query.go index 9ab22a7..f2b99ee 100644 --- a/pkg/query/query.go +++ b/pkg/query/query.go @@ -14,7 +14,7 @@ import ( // CreateQuery creates a DNS query from the options given. // It sets query flags and EDNS flags from the respective options. -func CreateQuery(opts util.Options) (util.Response, error) { +func CreateQuery(opts *util.Options) (util.Response, error) { req := new(dns.Msg) req.SetQuestion(opts.Request.Name, opts.Request.Type) req.Question[0].Qclass = opts.Request.Class diff --git a/pkg/query/query_test.go b/pkg/query/query_test.go index 8e6db55..22f5a0e 100644 --- a/pkg/query/query_test.go +++ b/pkg/query/query_test.go @@ -14,13 +14,14 @@ import ( func TestCreateQ(t *testing.T) { t.Parallel() + //nolint:govet // I could not be assed to refactor this, and it is only for tests tests := []struct { name string - opts util.Options + opts *util.Options }{ { "1", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), HeaderFlags: util.HeaderFlags{ Z: true, @@ -59,7 +60,7 @@ func TestCreateQ(t *testing.T) { }, { "2", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), HeaderFlags: util.HeaderFlags{ Z: true, @@ -88,7 +89,7 @@ func TestCreateQ(t *testing.T) { }, { "3", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), JSON: true, QUIC: true, diff --git a/pkg/resolvers/DNSCrypt.go b/pkg/resolvers/DNSCrypt.go index dca3531..9bff6bd 100644 --- a/pkg/resolvers/DNSCrypt.go +++ b/pkg/resolvers/DNSCrypt.go @@ -13,7 +13,7 @@ import ( // DNSCryptResolver is for making DNSCrypt queries. type DNSCryptResolver struct { - opts util.Options + opts *util.Options } var _ Resolver = (*DNSCryptResolver)(nil) diff --git a/pkg/resolvers/DNSCrypt_test.go b/pkg/resolvers/DNSCrypt_test.go index df1a731..85b69e6 100644 --- a/pkg/resolvers/DNSCrypt_test.go +++ b/pkg/resolvers/DNSCrypt_test.go @@ -16,13 +16,14 @@ import ( func TestDNSCrypt(t *testing.T) { t.Parallel() + //nolint:govet // I could not be assed to refactor this, and it is only for tests tests := []struct { name string - opts util.Options + opts *util.Options }{ { "Valid", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, Request: util.Request{ @@ -35,7 +36,7 @@ func TestDNSCrypt(t *testing.T) { }, { "Valid (TCP)", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, TCP: true, @@ -50,7 +51,7 @@ func TestDNSCrypt(t *testing.T) { }, { "Invalid", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), DNSCrypt: true, TCP: true, diff --git a/pkg/resolvers/HTTPS.go b/pkg/resolvers/HTTPS.go index b9994b3..d10ea06 100644 --- a/pkg/resolvers/HTTPS.go +++ b/pkg/resolvers/HTTPS.go @@ -16,8 +16,8 @@ import ( // HTTPSResolver is for DNS-over-HTTPS queries. type HTTPSResolver struct { + opts *util.Options client http.Client - opts util.Options } var _ Resolver = (*HTTPSResolver)(nil) diff --git a/pkg/resolvers/HTTPS_test.go b/pkg/resolvers/HTTPS_test.go index 5505c58..75d5d95 100644 --- a/pkg/resolvers/HTTPS_test.go +++ b/pkg/resolvers/HTTPS_test.go @@ -16,13 +16,14 @@ import ( func TestHTTPS(t *testing.T) { t.Parallel() + //nolint:govet // I could not be assed to refactor this, and it is only for tests tests := []struct { name string - opts util.Options + opts *util.Options }{ { "Good", - util.Options{ + &util.Options{ HTTPS: true, Logger: util.InitLogger(0), Request: util.Request{ @@ -35,7 +36,7 @@ func TestHTTPS(t *testing.T) { }, { "404", - util.Options{ + &util.Options{ HTTPS: true, Logger: util.InitLogger(0), Request: util.Request{ @@ -47,7 +48,7 @@ func TestHTTPS(t *testing.T) { }, { "Bad request domain", - util.Options{ + &util.Options{ HTTPS: true, Logger: util.InitLogger(0), Request: util.Request{ @@ -59,7 +60,7 @@ func TestHTTPS(t *testing.T) { }, { "Bad server domain", - util.Options{ + &util.Options{ HTTPS: true, Logger: util.InitLogger(0), Request: util.Request{ diff --git a/pkg/resolvers/QUIC.go b/pkg/resolvers/QUIC.go index 60999be..6fd059c 100644 --- a/pkg/resolvers/QUIC.go +++ b/pkg/resolvers/QUIC.go @@ -15,7 +15,7 @@ import ( // QUICResolver is for DNS-over-QUIC queries. type QUICResolver struct { - opts util.Options + opts *util.Options } var _ Resolver = (*QUICResolver)(nil) diff --git a/pkg/resolvers/QUIC_test.go b/pkg/resolvers/QUIC_test.go index 029df11..137f979 100644 --- a/pkg/resolvers/QUIC_test.go +++ b/pkg/resolvers/QUIC_test.go @@ -15,13 +15,14 @@ import ( func TestQuic(t *testing.T) { t.Parallel() + //nolint:govet // I could not be assed to refactor this, and it is only for tests tests := []struct { name string - opts util.Options + opts *util.Options }{ { "Valid", - util.Options{ + &util.Options{ QUIC: true, Logger: util.InitLogger(0), Request: util.Request{ @@ -35,7 +36,7 @@ func TestQuic(t *testing.T) { }, { "Bad domain", - util.Options{ + &util.Options{ QUIC: true, Logger: util.InitLogger(0), Request: util.Request{ @@ -50,7 +51,7 @@ func TestQuic(t *testing.T) { }, { "Not canonical", - util.Options{ + &util.Options{ QUIC: true, Logger: util.InitLogger(0), Request: util.Request{ @@ -65,7 +66,7 @@ func TestQuic(t *testing.T) { }, { "Invalid query domain", - util.Options{ + &util.Options{ QUIC: true, Logger: util.InitLogger(0), Request: util.Request{ diff --git a/pkg/resolvers/general.go b/pkg/resolvers/general.go index 1f94c82..e4a9435 100644 --- a/pkg/resolvers/general.go +++ b/pkg/resolvers/general.go @@ -13,7 +13,7 @@ import ( // StandardResolver is for UDP/TCP resolvers. type StandardResolver struct { - opts util.Options + opts *util.Options } var _ Resolver = (*StandardResolver)(nil) diff --git a/pkg/resolvers/general_test.go b/pkg/resolvers/general_test.go index af4b16a..0300404 100644 --- a/pkg/resolvers/general_test.go +++ b/pkg/resolvers/general_test.go @@ -18,13 +18,14 @@ import ( func TestResolve(t *testing.T) { t.Parallel() + //nolint:govet // I could not be assed to refactor this, and it is only for tests tests := []struct { name string - opts util.Options + opts *util.Options }{ { "UDP", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), Request: util.Request{ Server: "8.8.4.4", @@ -37,7 +38,7 @@ func TestResolve(t *testing.T) { }, { "UDP (Bad Cookie)", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), BadCookie: false, Request: util.Request{ @@ -55,7 +56,7 @@ func TestResolve(t *testing.T) { }, { "UDP (Truncated)", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), IPv4: true, Request: util.Request{ @@ -69,7 +70,7 @@ func TestResolve(t *testing.T) { }, { "TCP", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), TCP: true, @@ -84,7 +85,7 @@ func TestResolve(t *testing.T) { }, { "TLS", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), TLS: true, Request: util.Request{ @@ -98,7 +99,7 @@ func TestResolve(t *testing.T) { }, { "Timeout", - util.Options{ + &util.Options{ Logger: util.InitLogger(0), Request: util.Request{ Server: "8.8.4.1", diff --git a/pkg/resolvers/resolver.go b/pkg/resolvers/resolver.go index a600f20..1bc96de 100644 --- a/pkg/resolvers/resolver.go +++ b/pkg/resolvers/resolver.go @@ -22,7 +22,7 @@ type Resolver interface { } // LoadResolver loads the respective resolver for performing a DNS query. -func LoadResolver(opts util.Options) (Resolver, error) { +func LoadResolver(opts *util.Options) (Resolver, error) { switch { case opts.HTTPS: opts.Logger.Info("loading DNS-over-HTTPS resolver")