refactor: Make all calls to options pointers (#132)
All checks were successful
continuous-integration/drone/push Build is passing

Instead of copying the opts struct every time it gets passed around, it should be created once and passed through reference.

This should reduce memory utilization, unfortunately I cannot test it since this program runs so fast pprof won't report anything useful.

I think I found all of them 🙂

Co-authored-by: Sam Therapy <sam@samtherapy.net>
Reviewed-on: #132
Reviewed-by: grumbulon <grumbulon@grumbulon.xyz>
This commit is contained in:
Sam Therapy 2022-10-13 12:49:36 +00:00
parent e6a3d6040a
commit 81da49093d
Signed by: Froth Git
GPG key ID: 5D8CD75CC6B79913
16 changed files with 54 additions and 52 deletions

View file

@ -17,7 +17,7 @@ import (
// ParseCLI parses arguments given from the CLI and passes them into an `Options`
// struct.
func ParseCLI(args []string, version string) (util.Options, error) {
func ParseCLI(args []string, version string) (*util.Options, error) {
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
flagSet.Usage = func() {
@ -105,13 +105,13 @@ func ParseCLI(args []string, version string) (util.Options, error) {
// Parse the flags
if err := flagSet.Parse(args[1:]); err != nil {
return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err)
return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err)
}
// TODO: DRY, dumb dumb.
mbz, err := strconv.ParseInt(*mbzflag, 0, 16)
if err != nil {
return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err)
return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err)
}
opts := util.Options{
@ -179,7 +179,7 @@ func ParseCLI(args []string, version string) (util.Options, error) {
// TODO: DRY
if *subnet != "" {
if err = util.ParseSubnet(*subnet, &opts); err != nil {
return opts, fmt.Errorf("%w", err)
return &opts, fmt.Errorf("%w", err)
}
}
@ -189,14 +189,14 @@ func ParseCLI(args []string, version string) (util.Options, error) {
if *versionFlag {
fmt.Printf("awl version %s, built with %s\n", version, runtime.Version())
return opts, ErrNotError
return &opts, ErrNotError
}
// Parse all the arguments that don't start with - or --
// This includes the dig-style (+) options
err = ParseMiscArgs(flagSet.Args(), &opts)
if err != nil {
return opts, err
return &opts, err
}
opts.Logger.Info("Dig/Drill flags parsed")
@ -224,7 +224,7 @@ func ParseCLI(args []string, version string) (util.Options, error) {
opts.Logger.Info("Options fully populated")
opts.Logger.Debug(fmt.Sprintf("%+v", opts))
return opts, nil
return &opts, nil
}
// ErrNotError is for returning not error.

View file

@ -28,7 +28,7 @@ func main() {
}
}
func run(args []string) (opts util.Options, code int, err error) {
func run(args []string) (opts *util.Options, code int, err error) {
opts, err = cli.ParseCLI(args, version)
if err != nil {
return opts, 1, fmt.Errorf("parse: %w", err)

View file

@ -10,8 +10,7 @@ import (
)
func TestRun(t *testing.T) {
t.Parallel()
// t.Parallel()
args := [][]string{
{"awl", "+yaml", "@1.1.1.1"},
{"awl", "+short", "@1.1.1.1"},
@ -21,7 +20,6 @@ func TestRun(t *testing.T) {
test := test
t.Run("", func(t *testing.T) {
t.Parallel()
_, code, err := run(test)
assert.NilError(t, err)
assert.Equal(t, code, 0)
@ -30,8 +28,7 @@ func TestRun(t *testing.T) {
}
func TestHelp(t *testing.T) {
t.Parallel()
// t.Parallel()
args := []string{"awl", "-h"}
_, code, err := run(args)

View file

@ -21,7 +21,7 @@ import (
// ToString turns the response into something that looks a lot like dig
//
// Much of this is taken from https://github.com/miekg/dns/blob/master/msg.go#L900
func ToString(res util.Response, opts util.Options) (string, error) {
func ToString(res util.Response, opts *util.Options) (string, error) {
if res.DNS == nil {
return "<nil> MsgHdr", errNoMessage
}
@ -146,7 +146,7 @@ func ToString(res util.Response, opts util.Options) (string, error) {
return s, nil
}
func serverExtra(opts util.Options) string {
func serverExtra(opts *util.Options) string {
// Add extra information to server string
var extra string
@ -167,7 +167,7 @@ func serverExtra(opts util.Options) string {
}
// stringParse edits the raw responses to user requests.
func stringParse(str string, isAns bool, opts util.Options) (string, error) {
func stringParse(str string, isAns bool, opts *util.Options) (string, error) {
split := strings.Split(str, "\t")
// Make edits if so requested
@ -220,7 +220,7 @@ func stringParse(str string, isAns bool, opts util.Options) (string, error) {
// PrintSpecial is for printing as JSON, XML or YAML.
// As of now JSON and XML use the stdlib version.
func PrintSpecial(res util.Response, opts util.Options) (string, error) {
func PrintSpecial(res util.Response, opts *util.Options) (string, error) {
formatted, err := MakePrintable(res, opts)
if err != nil {
return "", err
@ -252,7 +252,7 @@ func PrintSpecial(res util.Response, opts util.Options) (string, error) {
// MakePrintable takes a DNS message and makes it nicer to be printed as JSON,YAML,
// and XML. Little is changed beyond naming.
func MakePrintable(res util.Response, opts util.Options) (*Message, error) {
func MakePrintable(res util.Response, opts *util.Options) (*Message, error) {
var (
err error
msg = res.DNS

View file

@ -14,7 +14,7 @@ import (
func TestRealPrint(t *testing.T) {
t.Parallel()
opts := []util.Options{
opts := []*util.Options{
{
Logger: util.InitLogger(0),
@ -216,14 +216,14 @@ func TestRealPrint(t *testing.T) {
func TestBadFormat(t *testing.T) {
t.Parallel()
_, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, util.Options{})
_, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, new(util.Options))
assert.ErrorContains(t, err, "never happen")
}
func TestEmpty(t *testing.T) {
t.Parallel()
str, err := query.ToString(util.Response{}, util.Options{})
str, err := query.ToString(util.Response{}, new(util.Options))
assert.Error(t, err, "no message")
assert.Assert(t, str == "<nil> MsgHdr")

View file

@ -14,7 +14,7 @@ import (
// CreateQuery creates a DNS query from the options given.
// It sets query flags and EDNS flags from the respective options.
func CreateQuery(opts util.Options) (util.Response, error) {
func CreateQuery(opts *util.Options) (util.Response, error) {
req := new(dns.Msg)
req.SetQuestion(opts.Request.Name, opts.Request.Type)
req.Question[0].Qclass = opts.Request.Class

View file

@ -14,13 +14,14 @@ import (
func TestCreateQ(t *testing.T) {
t.Parallel()
//nolint:govet // I could not be assed to refactor this, and it is only for tests
tests := []struct {
name string
opts util.Options
opts *util.Options
}{
{
"1",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
HeaderFlags: util.HeaderFlags{
Z: true,
@ -59,7 +60,7 @@ func TestCreateQ(t *testing.T) {
},
{
"2",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
HeaderFlags: util.HeaderFlags{
Z: true,
@ -88,7 +89,7 @@ func TestCreateQ(t *testing.T) {
},
{
"3",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
JSON: true,
QUIC: true,

View file

@ -13,7 +13,7 @@ import (
// DNSCryptResolver is for making DNSCrypt queries.
type DNSCryptResolver struct {
opts util.Options
opts *util.Options
}
var _ Resolver = (*DNSCryptResolver)(nil)

View file

@ -16,13 +16,14 @@ import (
func TestDNSCrypt(t *testing.T) {
t.Parallel()
//nolint:govet // I could not be assed to refactor this, and it is only for tests
tests := []struct {
name string
opts util.Options
opts *util.Options
}{
{
"Valid",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
DNSCrypt: true,
Request: util.Request{
@ -35,7 +36,7 @@ func TestDNSCrypt(t *testing.T) {
},
{
"Valid (TCP)",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
DNSCrypt: true,
TCP: true,
@ -50,7 +51,7 @@ func TestDNSCrypt(t *testing.T) {
},
{
"Invalid",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
DNSCrypt: true,
TCP: true,

View file

@ -16,8 +16,8 @@ import (
// HTTPSResolver is for DNS-over-HTTPS queries.
type HTTPSResolver struct {
opts *util.Options
client http.Client
opts util.Options
}
var _ Resolver = (*HTTPSResolver)(nil)

View file

@ -16,13 +16,14 @@ import (
func TestHTTPS(t *testing.T) {
t.Parallel()
//nolint:govet // I could not be assed to refactor this, and it is only for tests
tests := []struct {
name string
opts util.Options
opts *util.Options
}{
{
"Good",
util.Options{
&util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
@ -35,7 +36,7 @@ func TestHTTPS(t *testing.T) {
},
{
"404",
util.Options{
&util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
@ -47,7 +48,7 @@ func TestHTTPS(t *testing.T) {
},
{
"Bad request domain",
util.Options{
&util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{
@ -59,7 +60,7 @@ func TestHTTPS(t *testing.T) {
},
{
"Bad server domain",
util.Options{
&util.Options{
HTTPS: true,
Logger: util.InitLogger(0),
Request: util.Request{

View file

@ -15,7 +15,7 @@ import (
// QUICResolver is for DNS-over-QUIC queries.
type QUICResolver struct {
opts util.Options
opts *util.Options
}
var _ Resolver = (*QUICResolver)(nil)

View file

@ -15,13 +15,14 @@ import (
func TestQuic(t *testing.T) {
t.Parallel()
//nolint:govet // I could not be assed to refactor this, and it is only for tests
tests := []struct {
name string
opts util.Options
opts *util.Options
}{
{
"Valid",
util.Options{
&util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{
@ -35,7 +36,7 @@ func TestQuic(t *testing.T) {
},
{
"Bad domain",
util.Options{
&util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{
@ -50,7 +51,7 @@ func TestQuic(t *testing.T) {
},
{
"Not canonical",
util.Options{
&util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{
@ -65,7 +66,7 @@ func TestQuic(t *testing.T) {
},
{
"Invalid query domain",
util.Options{
&util.Options{
QUIC: true,
Logger: util.InitLogger(0),
Request: util.Request{

View file

@ -13,7 +13,7 @@ import (
// StandardResolver is for UDP/TCP resolvers.
type StandardResolver struct {
opts util.Options
opts *util.Options
}
var _ Resolver = (*StandardResolver)(nil)

View file

@ -18,13 +18,14 @@ import (
func TestResolve(t *testing.T) {
t.Parallel()
//nolint:govet // I could not be assed to refactor this, and it is only for tests
tests := []struct {
name string
opts util.Options
opts *util.Options
}{
{
"UDP",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
Request: util.Request{
Server: "8.8.4.4",
@ -37,7 +38,7 @@ func TestResolve(t *testing.T) {
},
{
"UDP (Bad Cookie)",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
BadCookie: false,
Request: util.Request{
@ -55,7 +56,7 @@ func TestResolve(t *testing.T) {
},
{
"UDP (Truncated)",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
IPv4: true,
Request: util.Request{
@ -69,7 +70,7 @@ func TestResolve(t *testing.T) {
},
{
"TCP",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
TCP: true,
@ -84,7 +85,7 @@ func TestResolve(t *testing.T) {
},
{
"TLS",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
TLS: true,
Request: util.Request{
@ -98,7 +99,7 @@ func TestResolve(t *testing.T) {
},
{
"Timeout",
util.Options{
&util.Options{
Logger: util.InitLogger(0),
Request: util.Request{
Server: "8.8.4.1",

View file

@ -22,7 +22,7 @@ type Resolver interface {
}
// LoadResolver loads the respective resolver for performing a DNS query.
func LoadResolver(opts util.Options) (Resolver, error) {
func LoadResolver(opts *util.Options) (Resolver, error) {
switch {
case opts.HTTPS:
opts.Logger.Info("loading DNS-over-HTTPS resolver")