common/helpers: add an helper to set default values for inner structs

This commit is contained in:
Vincent Bernat
2023-04-23 13:16:44 +02:00
parent 85561c44f7
commit 4069ecd158
3 changed files with 100 additions and 37 deletions

View File

@@ -63,6 +63,50 @@ func MapStructureMatchName(mapKey, fieldName string) bool {
return key == field
}
// DefaultValuesUnmarshallerHook adds default values from the provided
// configuration. For each missing non-default key, it will add them.
func DefaultValuesUnmarshallerHook[Configuration any](defaultConfiguration Configuration) mapstructure.DecodeHookFunc {
return func(from, to reflect.Value) (interface{}, error) {
from = ElemOrIdentity(from)
to = ElemOrIdentity(to)
if to.Type() != reflect.TypeOf(defaultConfiguration) {
return from.Interface(), nil
}
if from.Kind() != reflect.Map {
return from.Interface(), nil
}
// Which field is not to the default value in the default configuration?
found := map[string]bool{}
defaultV := reflect.ValueOf(defaultConfiguration)
for i := 0; i < defaultV.NumField(); i++ {
if !defaultV.Field(i).IsZero() {
found[defaultV.Type().Field(i).Name] = false
}
}
mapKeys := from.MapKeys()
for _, key := range mapKeys {
var keyStr string
if ElemOrIdentity(key).Kind() == reflect.String {
keyStr = ElemOrIdentity(key).String()
} else {
continue
}
for fieldName := range found {
if MapStructureMatchName(keyStr, fieldName) {
found[fieldName] = true
}
}
}
for fieldName := range found {
if !found[fieldName] {
from.SetMapIndex(reflect.ValueOf(fieldName), defaultV.FieldByName(fieldName))
}
}
return from.Interface(), nil
}
}
// ParametrizedConfigurationUnmarshallerHook will help decode a configuration
// structure parametrized by a type by selecting the appropriate concrete type
// depending on the type contained in the source. We have two configuration

View File

@@ -68,6 +68,56 @@ func TestProtectedDecodeHook(t *testing.T) {
}
}
func TestDefaultValuesConfig(t *testing.T) {
type InnerConfiguration struct {
AA string
BB string
CC int
}
type OuterConfiguration struct {
DD []InnerConfiguration
}
RegisterMapstructureUnmarshallerHook(DefaultValuesUnmarshallerHook(InnerConfiguration{
BB: "hello",
CC: 10,
}))
TestConfigurationDecode(t, ConfigurationDecodeCases{
{
Initial: func() interface{} { return OuterConfiguration{} },
Configuration: func() interface{} {
return gin.H{
"dd": []gin.H{
{
"aa": "hello1",
"bb": "hello2",
"cc": 43,
},
{"cc": 44},
{"aa": "bye"},
},
}
},
Expected: OuterConfiguration{
DD: []InnerConfiguration{
{
AA: "hello1",
BB: "hello2",
CC: 43,
}, {
AA: "",
BB: "hello",
CC: 44,
}, {
AA: "bye",
BB: "hello",
CC: 10,
},
},
},
},
})
}
func TestParametrizedConfig(t *testing.T) {
type InnerConfigurationType1 struct {
CC string

View File

@@ -140,42 +140,11 @@ type NetworkSource struct {
Interval time.Duration `validate:"min=1m"`
}
// NetworkSourceUnmarshallerHook decodes network sources, setting default
// values.
func NetworkSourceUnmarshallerHook() mapstructure.DecodeHookFunc {
return func(from, to reflect.Value) (interface{}, error) {
from = helpers.ElemOrIdentity(from)
to = helpers.ElemOrIdentity(to)
if to.Type() != reflect.TypeOf(NetworkSource{}) {
return from.Interface(), nil
}
if from.Kind() != reflect.Map {
return from.Interface(), nil
}
var methodFound, timeoutFound bool
mapKeys := from.MapKeys()
for _, key := range mapKeys {
var keyStr string
if helpers.ElemOrIdentity(key).Kind() == reflect.String {
keyStr = helpers.ElemOrIdentity(key).String()
} else {
continue
}
if helpers.MapStructureMatchName(keyStr, "Method") {
methodFound = true
}
if helpers.MapStructureMatchName(keyStr, "Timeout") {
timeoutFound = true
}
}
if !methodFound {
from.SetMapIndex(reflect.ValueOf("method"), reflect.ValueOf("GET"))
}
if !timeoutFound {
from.SetMapIndex(reflect.ValueOf("timeout"), reflect.ValueOf("1m"))
}
return from.Interface(), nil
// DefaultNetworkSourceConfiguration is the default configuration for a network source.
func DefaultNetworkSourceConfiguration() NetworkSource {
return NetworkSource{
Method: "GET",
Timeout: time.Minute,
}
}
@@ -210,6 +179,6 @@ func (jq TransformQuery) MarshalText() ([]byte, error) {
func init() {
helpers.RegisterMapstructureUnmarshallerHook(helpers.SubnetMapUnmarshallerHook[NetworkAttributes]())
helpers.RegisterMapstructureUnmarshallerHook(NetworkAttributesUnmarshallerHook())
helpers.RegisterMapstructureUnmarshallerHook(NetworkSourceUnmarshallerHook())
helpers.RegisterMapstructureUnmarshallerHook(helpers.DefaultValuesUnmarshallerHook[NetworkSource](DefaultNetworkSourceConfiguration()))
helpers.RegisterSubnetMapValidation[NetworkAttributes]()
}