From cd535c915df4da2ee32414820acfba47f7082251 Mon Sep 17 00:00:00 2001 From: Vincent Bernat Date: Sun, 14 Aug 2022 00:54:33 +0200 Subject: [PATCH] common/helpers: make subnetmap work with struct as values The way it was converted from a mapstruct made it not possible to have struct as values. Fix that by checking if keys look like IP or not. --- common/helpers/subnetmap.go | 27 +++++++++--- common/helpers/subnetmap_test.go | 73 ++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/common/helpers/subnetmap.go b/common/helpers/subnetmap.go index 2bef3e0e..bd8f4fdf 100644 --- a/common/helpers/subnetmap.go +++ b/common/helpers/subnetmap.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "reflect" + "regexp" "strings" "github.com/kentik/patricia" @@ -77,6 +78,8 @@ func MustNewSubnetMap[V any](from map[string]V) *SubnetMap[V] { return trie } +var subnetLookAlikeRegex = regexp.MustCompile("^([a-fA-F:.0-9]+)(/([0-9]+))?$") + // SubnetMapUnmarshallerHook decodes SubnetMap and notably check that // valid networks are provided as key. It also accepts a single value // instead of a map for backward compatibility. @@ -87,7 +90,23 @@ func SubnetMapUnmarshallerHook[V any]() mapstructure.DecodeHookFunc { } output := map[string]interface{}{} var zero V + var plausibleSubnetMap bool if from.Kind() == reflect.Map { + // When we have a map, we check if all keys look like a subnet. + plausibleSubnetMap = true + for _, key := range from.MapKeys() { + key = ElemOrIdentity(key) + if key.Kind() != reflect.String { + plausibleSubnetMap = false + break + } + if !subnetLookAlikeRegex.MatchString(key.String()) { + plausibleSubnetMap = false + break + } + } + } + if plausibleSubnetMap { // First case, we have a map iter := from.MapRange() for i := 0; iter.Next(); i++ { @@ -130,11 +149,9 @@ func SubnetMapUnmarshallerHook[V any]() mapstructure.DecodeHookFunc { } output[key] = v.Interface() } - } else if from.Type() == reflect.TypeOf(zero) || from.Type().ConvertibleTo(reflect.TypeOf(zero)) { - // Second case, we have a single value - output["::/0"] = from.Interface() } else { - return from.Interface(), nil + // Second case, we have a single value and we let mapstructure handles it + output["::/0"] = from.Interface() } // We have to decode output map, then turn it into a SubnetMap[V] @@ -164,5 +181,5 @@ func (sm SubnetMap[V]) MarshalYAML() (interface{}, error) { func (sm SubnetMap[V]) String() string { out := sm.ToMap() - return fmt.Sprintf("%v", out) + return fmt.Sprintf("%+v", out) } diff --git a/common/helpers/subnetmap_test.go b/common/helpers/subnetmap_test.go index 1f8ee7b4..5cb66ca4 100644 --- a/common/helpers/subnetmap_test.go +++ b/common/helpers/subnetmap_test.go @@ -97,6 +97,10 @@ func TestSubnetMapUnmarshalHook(t *testing.T) { Description: "Invalid IP", Input: gin.H{"200.33.300.1": "customer"}, Error: true, + }, { + Description: "Random key", + Input: gin.H{"kfgdjgkfj": "customer"}, + Error: true, }, { Description: "Single value", Input: "customer", @@ -158,3 +162,72 @@ func TestSubnetMapUnmarshalHook(t *testing.T) { }) } } + +func TestSubnetMapUnmarshalHookWithMapValue(t *testing.T) { + type SomeStruct struct { + Blip string + Blop string + } + cases := []struct { + Description string + Input gin.H + Expected gin.H + }{ + { + Description: "single value", + Input: gin.H{ + "blip": "some", + "blop": "thing", + }, + Expected: gin.H{ + "::/0": gin.H{ + "Blip": "some", + "Blop": "thing", + }, + }, + }, { + Description: "proper map", + Input: gin.H{ + "::/0": gin.H{ + "blip": "some", + "blop": "thing", + }, + "203.0.113.14": gin.H{ + "blip": "other", + "blop": "stuff", + }, + }, + Expected: gin.H{ + "::/0": gin.H{ + "Blip": "some", + "Blop": "thing", + }, + "203.0.113.14/32": gin.H{ + "Blip": "other", + "Blop": "stuff", + }, + }, + }, + } + for _, tc := range cases { + t.Run(tc.Description, func(t *testing.T) { + var tree helpers.SubnetMap[SomeStruct] + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: &tree, + ErrorUnused: true, + Metadata: nil, + DecodeHook: helpers.SubnetMapUnmarshallerHook[SomeStruct](), + }) + if err != nil { + t.Fatalf("NewDecoder() error:\n%+v", err) + } + err = decoder.Decode(tc.Input) + if err != nil { + t.Fatalf("Decode() error:\n%+v", err) + } + if diff := helpers.Diff(tree.ToMap(), tc.Expected); diff != "" { + t.Fatalf("Decode() (-got, +want):\n%s", diff) + } + }) + } +}