diff --git a/common/helpers/mapstructure.go b/common/helpers/mapstructure.go index ac836484..3a308e8f 100644 --- a/common/helpers/mapstructure.go +++ b/common/helpers/mapstructure.go @@ -63,6 +63,50 @@ func MapStructureMatchName(mapKey, fieldName string) bool { return key == field } +// DefaultValuesUnmarshallerHook adds default values from the provided +// configuration. For each missing non-default key, it will add them. +func DefaultValuesUnmarshallerHook[Configuration any](defaultConfiguration Configuration) mapstructure.DecodeHookFunc { + return func(from, to reflect.Value) (interface{}, error) { + from = ElemOrIdentity(from) + to = ElemOrIdentity(to) + if to.Type() != reflect.TypeOf(defaultConfiguration) { + return from.Interface(), nil + } + if from.Kind() != reflect.Map { + return from.Interface(), nil + } + + // Which field is not to the default value in the default configuration? + found := map[string]bool{} + defaultV := reflect.ValueOf(defaultConfiguration) + for i := 0; i < defaultV.NumField(); i++ { + if !defaultV.Field(i).IsZero() { + found[defaultV.Type().Field(i).Name] = false + } + } + mapKeys := from.MapKeys() + for _, key := range mapKeys { + var keyStr string + if ElemOrIdentity(key).Kind() == reflect.String { + keyStr = ElemOrIdentity(key).String() + } else { + continue + } + for fieldName := range found { + if MapStructureMatchName(keyStr, fieldName) { + found[fieldName] = true + } + } + } + for fieldName := range found { + if !found[fieldName] { + from.SetMapIndex(reflect.ValueOf(fieldName), defaultV.FieldByName(fieldName)) + } + } + return from.Interface(), nil + } +} + // ParametrizedConfigurationUnmarshallerHook will help decode a configuration // structure parametrized by a type by selecting the appropriate concrete type // depending on the type contained in the source. We have two configuration diff --git a/common/helpers/mapstructure_test.go b/common/helpers/mapstructure_test.go index 3bb402bd..30060191 100644 --- a/common/helpers/mapstructure_test.go +++ b/common/helpers/mapstructure_test.go @@ -68,6 +68,56 @@ func TestProtectedDecodeHook(t *testing.T) { } } +func TestDefaultValuesConfig(t *testing.T) { + type InnerConfiguration struct { + AA string + BB string + CC int + } + type OuterConfiguration struct { + DD []InnerConfiguration + } + RegisterMapstructureUnmarshallerHook(DefaultValuesUnmarshallerHook(InnerConfiguration{ + BB: "hello", + CC: 10, + })) + TestConfigurationDecode(t, ConfigurationDecodeCases{ + { + Initial: func() interface{} { return OuterConfiguration{} }, + Configuration: func() interface{} { + return gin.H{ + "dd": []gin.H{ + { + "aa": "hello1", + "bb": "hello2", + "cc": 43, + }, + {"cc": 44}, + {"aa": "bye"}, + }, + } + }, + Expected: OuterConfiguration{ + DD: []InnerConfiguration{ + { + AA: "hello1", + BB: "hello2", + CC: 43, + }, { + AA: "", + BB: "hello", + CC: 44, + }, { + AA: "bye", + BB: "hello", + CC: 10, + }, + }, + }, + }, + }) +} + func TestParametrizedConfig(t *testing.T) { type InnerConfigurationType1 struct { CC string diff --git a/orchestrator/clickhouse/config.go b/orchestrator/clickhouse/config.go index c66121c2..e7eff8a4 100644 --- a/orchestrator/clickhouse/config.go +++ b/orchestrator/clickhouse/config.go @@ -140,42 +140,11 @@ type NetworkSource struct { Interval time.Duration `validate:"min=1m"` } -// NetworkSourceUnmarshallerHook decodes network sources, setting default -// values. -func NetworkSourceUnmarshallerHook() mapstructure.DecodeHookFunc { - return func(from, to reflect.Value) (interface{}, error) { - from = helpers.ElemOrIdentity(from) - to = helpers.ElemOrIdentity(to) - if to.Type() != reflect.TypeOf(NetworkSource{}) { - return from.Interface(), nil - } - if from.Kind() != reflect.Map { - return from.Interface(), nil - } - - var methodFound, timeoutFound bool - mapKeys := from.MapKeys() - for _, key := range mapKeys { - var keyStr string - if helpers.ElemOrIdentity(key).Kind() == reflect.String { - keyStr = helpers.ElemOrIdentity(key).String() - } else { - continue - } - if helpers.MapStructureMatchName(keyStr, "Method") { - methodFound = true - } - if helpers.MapStructureMatchName(keyStr, "Timeout") { - timeoutFound = true - } - } - if !methodFound { - from.SetMapIndex(reflect.ValueOf("method"), reflect.ValueOf("GET")) - } - if !timeoutFound { - from.SetMapIndex(reflect.ValueOf("timeout"), reflect.ValueOf("1m")) - } - return from.Interface(), nil +// DefaultNetworkSourceConfiguration is the default configuration for a network source. +func DefaultNetworkSourceConfiguration() NetworkSource { + return NetworkSource{ + Method: "GET", + Timeout: time.Minute, } } @@ -210,6 +179,6 @@ func (jq TransformQuery) MarshalText() ([]byte, error) { func init() { helpers.RegisterMapstructureUnmarshallerHook(helpers.SubnetMapUnmarshallerHook[NetworkAttributes]()) helpers.RegisterMapstructureUnmarshallerHook(NetworkAttributesUnmarshallerHook()) - helpers.RegisterMapstructureUnmarshallerHook(NetworkSourceUnmarshallerHook()) + helpers.RegisterMapstructureUnmarshallerHook(helpers.DefaultValuesUnmarshallerHook[NetworkSource](DefaultNetworkSourceConfiguration())) helpers.RegisterSubnetMapValidation[NetworkAttributes]() }