refactor: Make all calls to options pointers (#132)
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
Instead of copying the opts struct every time it gets passed around, it should be created once and passed through reference. This should reduce memory utilization, unfortunately I cannot test it since this program runs so fast pprof won't report anything useful. I think I found all of them 🙂 Co-authored-by: Sam Therapy <sam@samtherapy.net> Reviewed-on: #132 Reviewed-by: grumbulon <grumbulon@grumbulon.xyz>
This commit is contained in:
parent
e6a3d6040a
commit
81da49093d
16 changed files with 54 additions and 52 deletions
14
cmd/cli.go
14
cmd/cli.go
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
// ParseCLI parses arguments given from the CLI and passes them into an `Options`
|
||||
// struct.
|
||||
func ParseCLI(args []string, version string) (util.Options, error) {
|
||||
func ParseCLI(args []string, version string) (*util.Options, error) {
|
||||
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
|
||||
|
||||
flagSet.Usage = func() {
|
||||
|
@ -105,13 +105,13 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
|
||||
// Parse the flags
|
||||
if err := flagSet.Parse(args[1:]); err != nil {
|
||||
return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err)
|
||||
return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err)
|
||||
}
|
||||
|
||||
// TODO: DRY, dumb dumb.
|
||||
mbz, err := strconv.ParseInt(*mbzflag, 0, 16)
|
||||
if err != nil {
|
||||
return util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err)
|
||||
return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err)
|
||||
}
|
||||
|
||||
opts := util.Options{
|
||||
|
@ -179,7 +179,7 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
// TODO: DRY
|
||||
if *subnet != "" {
|
||||
if err = util.ParseSubnet(*subnet, &opts); err != nil {
|
||||
return opts, fmt.Errorf("%w", err)
|
||||
return &opts, fmt.Errorf("%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,14 +189,14 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
if *versionFlag {
|
||||
fmt.Printf("awl version %s, built with %s\n", version, runtime.Version())
|
||||
|
||||
return opts, ErrNotError
|
||||
return &opts, ErrNotError
|
||||
}
|
||||
|
||||
// Parse all the arguments that don't start with - or --
|
||||
// This includes the dig-style (+) options
|
||||
err = ParseMiscArgs(flagSet.Args(), &opts)
|
||||
if err != nil {
|
||||
return opts, err
|
||||
return &opts, err
|
||||
}
|
||||
|
||||
opts.Logger.Info("Dig/Drill flags parsed")
|
||||
|
@ -224,7 +224,7 @@ func ParseCLI(args []string, version string) (util.Options, error) {
|
|||
opts.Logger.Info("Options fully populated")
|
||||
opts.Logger.Debug(fmt.Sprintf("%+v", opts))
|
||||
|
||||
return opts, nil
|
||||
return &opts, nil
|
||||
}
|
||||
|
||||
// ErrNotError is for returning not error.
|
||||
|
|
2
main.go
2
main.go
|
@ -28,7 +28,7 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
func run(args []string) (opts util.Options, code int, err error) {
|
||||
func run(args []string) (opts *util.Options, code int, err error) {
|
||||
opts, err = cli.ParseCLI(args, version)
|
||||
if err != nil {
|
||||
return opts, 1, fmt.Errorf("parse: %w", err)
|
||||
|
|
|
@ -10,8 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func TestRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// t.Parallel()
|
||||
args := [][]string{
|
||||
{"awl", "+yaml", "@1.1.1.1"},
|
||||
{"awl", "+short", "@1.1.1.1"},
|
||||
|
@ -21,7 +20,6 @@ func TestRun(t *testing.T) {
|
|||
test := test
|
||||
|
||||
t.Run("", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, code, err := run(test)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, code, 0)
|
||||
|
@ -30,8 +28,7 @@ func TestRun(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHelp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// t.Parallel()
|
||||
args := []string{"awl", "-h"}
|
||||
|
||||
_, code, err := run(args)
|
||||
|
|
|
@ -21,7 +21,7 @@ import (
|
|||
// ToString turns the response into something that looks a lot like dig
|
||||
//
|
||||
// Much of this is taken from https://github.com/miekg/dns/blob/master/msg.go#L900
|
||||
func ToString(res util.Response, opts util.Options) (string, error) {
|
||||
func ToString(res util.Response, opts *util.Options) (string, error) {
|
||||
if res.DNS == nil {
|
||||
return "<nil> MsgHdr", errNoMessage
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func ToString(res util.Response, opts util.Options) (string, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func serverExtra(opts util.Options) string {
|
||||
func serverExtra(opts *util.Options) string {
|
||||
// Add extra information to server string
|
||||
var extra string
|
||||
|
||||
|
@ -167,7 +167,7 @@ func serverExtra(opts util.Options) string {
|
|||
}
|
||||
|
||||
// stringParse edits the raw responses to user requests.
|
||||
func stringParse(str string, isAns bool, opts util.Options) (string, error) {
|
||||
func stringParse(str string, isAns bool, opts *util.Options) (string, error) {
|
||||
split := strings.Split(str, "\t")
|
||||
|
||||
// Make edits if so requested
|
||||
|
@ -220,7 +220,7 @@ func stringParse(str string, isAns bool, opts util.Options) (string, error) {
|
|||
|
||||
// PrintSpecial is for printing as JSON, XML or YAML.
|
||||
// As of now JSON and XML use the stdlib version.
|
||||
func PrintSpecial(res util.Response, opts util.Options) (string, error) {
|
||||
func PrintSpecial(res util.Response, opts *util.Options) (string, error) {
|
||||
formatted, err := MakePrintable(res, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -252,7 +252,7 @@ func PrintSpecial(res util.Response, opts util.Options) (string, error) {
|
|||
|
||||
// MakePrintable takes a DNS message and makes it nicer to be printed as JSON,YAML,
|
||||
// and XML. Little is changed beyond naming.
|
||||
func MakePrintable(res util.Response, opts util.Options) (*Message, error) {
|
||||
func MakePrintable(res util.Response, opts *util.Options) (*Message, error) {
|
||||
var (
|
||||
err error
|
||||
msg = res.DNS
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
func TestRealPrint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
opts := []util.Options{
|
||||
opts := []*util.Options{
|
||||
{
|
||||
Logger: util.InitLogger(0),
|
||||
|
||||
|
@ -216,14 +216,14 @@ func TestRealPrint(t *testing.T) {
|
|||
func TestBadFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, util.Options{})
|
||||
_, err := query.PrintSpecial(util.Response{DNS: new(dns.Msg)}, new(util.Options))
|
||||
assert.ErrorContains(t, err, "never happen")
|
||||
}
|
||||
|
||||
func TestEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
str, err := query.ToString(util.Response{}, util.Options{})
|
||||
str, err := query.ToString(util.Response{}, new(util.Options))
|
||||
|
||||
assert.Error(t, err, "no message")
|
||||
assert.Assert(t, str == "<nil> MsgHdr")
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
// CreateQuery creates a DNS query from the options given.
|
||||
// It sets query flags and EDNS flags from the respective options.
|
||||
func CreateQuery(opts util.Options) (util.Response, error) {
|
||||
func CreateQuery(opts *util.Options) (util.Response, error) {
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion(opts.Request.Name, opts.Request.Type)
|
||||
req.Question[0].Qclass = opts.Request.Class
|
||||
|
|
|
@ -14,13 +14,14 @@ import (
|
|||
func TestCreateQ(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"1",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
HeaderFlags: util.HeaderFlags{
|
||||
Z: true,
|
||||
|
@ -59,7 +60,7 @@ func TestCreateQ(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"2",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
HeaderFlags: util.HeaderFlags{
|
||||
Z: true,
|
||||
|
@ -88,7 +89,7 @@ func TestCreateQ(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"3",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
JSON: true,
|
||||
QUIC: true,
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
// DNSCryptResolver is for making DNSCrypt queries.
|
||||
type DNSCryptResolver struct {
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*DNSCryptResolver)(nil)
|
||||
|
|
|
@ -16,13 +16,14 @@ import (
|
|||
func TestDNSCrypt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"Valid",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
DNSCrypt: true,
|
||||
Request: util.Request{
|
||||
|
@ -35,7 +36,7 @@ func TestDNSCrypt(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Valid (TCP)",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
DNSCrypt: true,
|
||||
TCP: true,
|
||||
|
@ -50,7 +51,7 @@ func TestDNSCrypt(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Invalid",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
DNSCrypt: true,
|
||||
TCP: true,
|
||||
|
|
|
@ -16,8 +16,8 @@ import (
|
|||
|
||||
// HTTPSResolver is for DNS-over-HTTPS queries.
|
||||
type HTTPSResolver struct {
|
||||
opts *util.Options
|
||||
client http.Client
|
||||
opts util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*HTTPSResolver)(nil)
|
||||
|
|
|
@ -16,13 +16,14 @@ import (
|
|||
func TestHTTPS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"Good",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -35,7 +36,7 @@ func TestHTTPS(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"404",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -47,7 +48,7 @@ func TestHTTPS(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Bad request domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -59,7 +60,7 @@ func TestHTTPS(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Bad server domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
HTTPS: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
// QUICResolver is for DNS-over-QUIC queries.
|
||||
type QUICResolver struct {
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*QUICResolver)(nil)
|
||||
|
|
|
@ -15,13 +15,14 @@ import (
|
|||
func TestQuic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"Valid",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -35,7 +36,7 @@ func TestQuic(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Bad domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -50,7 +51,7 @@ func TestQuic(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Not canonical",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
@ -65,7 +66,7 @@ func TestQuic(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Invalid query domain",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
QUIC: true,
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
// StandardResolver is for UDP/TCP resolvers.
|
||||
type StandardResolver struct {
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}
|
||||
|
||||
var _ Resolver = (*StandardResolver)(nil)
|
||||
|
|
|
@ -18,13 +18,14 @@ import (
|
|||
func TestResolve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:govet // I could not be assed to refactor this, and it is only for tests
|
||||
tests := []struct {
|
||||
name string
|
||||
opts util.Options
|
||||
opts *util.Options
|
||||
}{
|
||||
{
|
||||
"UDP",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
Server: "8.8.4.4",
|
||||
|
@ -37,7 +38,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"UDP (Bad Cookie)",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
BadCookie: false,
|
||||
Request: util.Request{
|
||||
|
@ -55,7 +56,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"UDP (Truncated)",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
IPv4: true,
|
||||
Request: util.Request{
|
||||
|
@ -69,7 +70,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"TCP",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
TCP: true,
|
||||
|
||||
|
@ -84,7 +85,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"TLS",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
TLS: true,
|
||||
Request: util.Request{
|
||||
|
@ -98,7 +99,7 @@ func TestResolve(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"Timeout",
|
||||
util.Options{
|
||||
&util.Options{
|
||||
Logger: util.InitLogger(0),
|
||||
Request: util.Request{
|
||||
Server: "8.8.4.1",
|
||||
|
|
|
@ -22,7 +22,7 @@ type Resolver interface {
|
|||
}
|
||||
|
||||
// LoadResolver loads the respective resolver for performing a DNS query.
|
||||
func LoadResolver(opts util.Options) (Resolver, error) {
|
||||
func LoadResolver(opts *util.Options) (Resolver, error) {
|
||||
switch {
|
||||
case opts.HTTPS:
|
||||
opts.Logger.Info("loading DNS-over-HTTPS resolver")
|
||||
|
|
Loading…
Reference in a new issue