diff --git a/cli.go b/cli.go index e668955..f52c383 100644 --- a/cli.go +++ b/cli.go @@ -16,6 +16,8 @@ package config import ( + "encoding" + "errors" "flag" "fmt" "io" @@ -26,6 +28,8 @@ import ( "unsafe" ) +var ErrInternalError = errors.New("internal error") + // anyValue wraps a reflect.Value object and implements flag.Value interface // the reflect.Value could be Bool, String, Int, Uint and Float type anyValue struct { @@ -85,25 +89,25 @@ func (av *anyValue) Set(value string) error { case reflect.String: av.any.SetString(value) case reflect.Float32: - return SetValueWithFloatX(av.any, value, Float32Size) + return setValueWithFloatX(av.any, value, Float32Size) case reflect.Float64: - return SetValueWithFloatX(av.any, value, Float64Size) + return setValueWithFloatX(av.any, value, Float64Size) case reflect.Int8: - return SetValueWithIntX(av.any, value, Int8Size) + return setValueWithIntX(av.any, value, Int8Size) case reflect.Int16: - return SetValueWithIntX(av.any, value, Int16Size) + return setValueWithIntX(av.any, value, Int16Size) case reflect.Int, reflect.Int32: - return SetValueWithIntX(av.any, value, Int32Size) + return setValueWithIntX(av.any, value, Int32Size) case reflect.Int64: - return SetValueWithIntX(av.any, value, Int64Size) + return setValueWithIntX(av.any, value, Int64Size) case reflect.Uint8: - return SetValueWithUintX(av.any, value, Uint8Size) + return setValueWithUintX(av.any, value, Uint8Size) case reflect.Uint16: - return SetValueWithUintX(av.any, value, Uint16Size) + return setValueWithUintX(av.any, value, Uint16Size) case reflect.Uint, reflect.Uint32: - return SetValueWithUintX(av.any, value, Uint32Size) + return setValueWithUintX(av.any, value, Uint32Size) case reflect.Uint64: - return SetValueWithUintX(av.any, value, Uint64Size) + return setValueWithUintX(av.any, value, Uint64Size) default: return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) } @@ -132,7 +136,7 @@ func (sv *sliceValue) Set(value string) error { sp = ":" } - return SetValueWithSlice(sv.value, value, sp) + return setValueWithSlice(sv.value, value, sp) } // errorHandling is a global flag.ErrorHandling @@ -291,6 +295,26 @@ func (c *Command) addFlag(value reflect.Value, field reflect.StructField) error return nil } + unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + valuePtr := reflect.New(value.Type()) + + if valuePtr.Type().Implements(unmarshalerType) { + c.FlagSet.Func(name, usage, func(s string) error { + if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok { + if err := decoder.UnmarshalText([]byte(s)); err != nil { + return fmt.Errorf("unmarshal text: %w", err) + } + + value.Set(valuePtr.Elem()) + + return nil + } + + return ErrInternalError + }, + ) + } + kind := value.Kind() switch kind { case reflect.Bool: diff --git a/config.go b/config.go index 57ce078..a99234a 100644 --- a/config.go +++ b/config.go @@ -102,51 +102,48 @@ func parseValue(value reflect.Value) error { } func setValue(value reflect.Value, defValue string, field reflect.StructField) error { + if setUnmarshalTextValue(value, defValue) { + return nil + } + + if ok, err := setDurationValue(value, defValue); ok { + return err + } + var err error - valueTypePkgPath := value.Type().PkgPath() - valueTypeName := value.Type().Name() - - if valueTypePkgPath == "time" && valueTypeName == "Duration" { - return SetValueWithDuration(value, defValue) - } - - if valueTypePkgPath == "log/slog" && valueTypeName == "Level" { - return SetValueSlogLevel(value, defValue) - } - switch value.Kind() { case reflect.Bool: - err = SetValueWithBool(value, defValue) + err = setValueWithBool(value, defValue) case reflect.String: value.SetString(defValue) case reflect.Int8: - err = SetValueWithIntX(value, defValue, Int8Size) + err = setValueWithIntX(value, defValue, Int8Size) case reflect.Int16: - err = SetValueWithIntX(value, defValue, Int16Size) + err = setValueWithIntX(value, defValue, Int16Size) case reflect.Int, reflect.Int32: - err = SetValueWithIntX(value, defValue, Int32Size) + err = setValueWithIntX(value, defValue, Int32Size) case reflect.Int64: - err = SetValueWithIntX(value, defValue, Int64Size) + err = setValueWithIntX(value, defValue, Int64Size) case reflect.Uint8: - err = SetValueWithUintX(value, defValue, Uint8Size) + err = setValueWithUintX(value, defValue, Uint8Size) case reflect.Uint16: - err = SetValueWithUintX(value, defValue, Uint16Size) + err = setValueWithUintX(value, defValue, Uint16Size) case reflect.Uint, reflect.Uint32: - err = SetValueWithUintX(value, defValue, Uint32Size) + err = setValueWithUintX(value, defValue, Uint32Size) case reflect.Uint64: - err = SetValueWithUintX(value, defValue, Uint64Size) + err = setValueWithUintX(value, defValue, Uint64Size) case reflect.Float32: - err = SetValueWithFloatX(value, defValue, Float32Size) + err = setValueWithFloatX(value, defValue, Float32Size) case reflect.Float64: - err = SetValueWithFloatX(value, defValue, Float64Size) + err = setValueWithFloatX(value, defValue, Float64Size) case reflect.Slice: sp, ok := field.Tag.Lookup("separator") if !ok { sp = ":" } - err = SetValueWithSlice(value, defValue, sp) + err = setValueWithSlice(value, defValue, sp) default: err = fmt.Errorf("%w: %s", errUnsupportedType, value.Kind().String()) } diff --git a/config_test.go b/config_test.go index 2bf96b4..9b7d3e0 100644 --- a/config_test.go +++ b/config_test.go @@ -16,9 +16,11 @@ package config import ( + "encoding" "log/slog" "os" "path/filepath" + "reflect" "runtime" "strconv" "testing" @@ -116,3 +118,18 @@ func TestYamlConfigFile(t *testing.T) { assert.Equal(DB_LOG_PATH, conf.Log.Path) assert.Equal(DB_LOG_LEVEL, conf.Log.Level) } + +func TestTextUnmarshal(t *testing.T) { + var level slog.Level + + assert := assert.New(t) + value := reflect.ValueOf(&level) + unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + assert.True(value.Type().Implements(unmarshalerType)) + + valueDecoder, ok := value.Interface().(encoding.TextUnmarshaler) + assert.True(ok) + assert.NotNil(valueDecoder) + assert.NoError(valueDecoder.UnmarshalText([]byte("INFO"))) + assert.Equal(slog.LevelInfo, level) +} diff --git a/env.go b/env.go index f9e6d04..190df1f 100644 --- a/env.go +++ b/env.go @@ -196,34 +196,46 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str return fmt.Errorf("%w: %s", errValueCannotBeChanged, field.Name) } + if setUnmarshalTextValue(value, envValue) { + return nil + } + + if ok, err := setDurationValue(value, envValue); ok { + return err + } + + return setSimpleEnvValue(value, field, envValue) +} + +func setSimpleEnvValue(value reflect.Value, field reflect.StructField, str string) error { var err error kind := value.Kind() switch kind { case reflect.Bool: - err = SetValueWithBool(value, envValue) + err = setValueWithBool(value, str) case reflect.String: - value.SetString(envValue) + value.SetString(str) case reflect.Int8: - err = SetValueWithIntX(value, envValue, Int8Size) + err = setValueWithIntX(value, str, Int8Size) case reflect.Int16: - err = SetValueWithIntX(value, envValue, Int16Size) + err = setValueWithIntX(value, str, Int16Size) case reflect.Int, reflect.Int32: - err = SetValueWithIntX(value, envValue, Int32Size) + err = setValueWithIntX(value, str, Int32Size) case reflect.Int64: - err = SetValueWithIntX(value, envValue, Int64Size) + err = setValueWithIntX(value, str, Int64Size) case reflect.Uint8: - err = SetValueWithUintX(value, envValue, Uint8Size) + err = setValueWithUintX(value, str, Uint8Size) case reflect.Uint16: - err = SetValueWithUintX(value, envValue, Uint16Size) + err = setValueWithUintX(value, str, Uint16Size) case reflect.Uint, reflect.Uint32: - err = SetValueWithUintX(value, envValue, Uint32Size) + err = setValueWithUintX(value, str, Uint32Size) case reflect.Uint64: - err = SetValueWithUintX(value, envValue, Uint64Size) + err = setValueWithUintX(value, str, Uint64Size) case reflect.Float32: - err = SetValueWithFloatX(value, envValue, Float32Size) + err = setValueWithFloatX(value, str, Float32Size) case reflect.Float64: - err = SetValueWithFloatX(value, envValue, Float64Size) + err = setValueWithFloatX(value, str, Float64Size) case reflect.Slice: sp, ok := field.Tag.Lookup("separator") @@ -231,17 +243,17 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str sp = ":" } - err = SetValueWithSlice(value, envValue, sp) + err = setValueWithSlice(value, str, sp) default: return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) } if err != nil { - return fmt.Errorf("%s: %w", field.Name, err) + err = fmt.Errorf("%s: %w", field.Name, err) } - return nil + return err } func usageFieldValueEnv(out io.Writer, field reflect.StructField, prefix string) error { diff --git a/utils.go b/utils.go index 529f1a2..c07a5bd 100644 --- a/utils.go +++ b/utils.go @@ -18,14 +18,13 @@ package config import ( "encoding" "fmt" - "log/slog" "reflect" "strconv" "strings" "time" ) -func SetValueWithBool(value reflect.Value, boolStr string) error { +func setValueWithBool(value reflect.Value, boolStr string) error { boolValue, err := strconv.ParseBool(boolStr) if err != nil { if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok { @@ -44,7 +43,7 @@ func SetValueWithBool(value reflect.Value, boolStr string) error { return nil } -func SetValueWithDuration(value reflect.Value, str string) error { +func setValueWithDuration(value reflect.Value, str string) error { d, err := time.ParseDuration(str) if err != nil { return fmt.Errorf("parse duration: %w", err) @@ -55,19 +54,7 @@ func SetValueWithDuration(value reflect.Value, str string) error { return nil } -func SetValueSlogLevel(value reflect.Value, str string) error { - var level slog.Level - - if err := level.UnmarshalText([]byte(str)); err != nil { - return fmt.Errorf("parse slog level: %w", err) - } - - value.Set(reflect.ValueOf(level)) - - return nil -} - -func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error { +func setValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error { floatValue, err := strconv.ParseFloat(floatStr, bitSize) if err != nil { return fmt.Errorf("parse float: %w", err) @@ -78,7 +65,7 @@ func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error return nil } -func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error { +func setValueWithIntX(value reflect.Value, intStr string, bitSize int) error { intValue, err := strconv.ParseInt(intStr, 10, bitSize) if err != nil { if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok { @@ -97,7 +84,7 @@ func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error { return nil } -func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error { +func setValueWithUintX(value reflect.Value, uintStr string, bitSize int) error { uintValue, err := strconv.ParseUint(uintStr, 10, bitSize) if err != nil { if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok { @@ -116,7 +103,7 @@ func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error { return nil } -func SetValueWithSlice(value reflect.Value, slice string, sep string) error { +func setValueWithSlice(value reflect.Value, slice string, sep string) error { data := strings.Split(slice, sep) size := len(data) @@ -124,36 +111,45 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error { slice := reflect.MakeSlice(value.Type(), size, size) for index := range size { - var err error - ele := slice.Index(index) + str := data[index] + + if ok, err := setDurationValue(ele, str); ok { + return err + } + + if setUnmarshalTextValue(ele, str) { + continue + } + + var err error kind := ele.Kind() switch kind { case reflect.Bool: - err = SetValueWithBool(ele, data[index]) + err = setValueWithBool(ele, str) case reflect.String: - ele.SetString(data[index]) + ele.SetString(str) case reflect.Uint8: - err = SetValueWithUintX(ele, data[index], Uint8Size) + err = setValueWithUintX(ele, str, Uint8Size) case reflect.Uint16: - err = SetValueWithUintX(ele, data[index], Uint16Size) + err = setValueWithUintX(ele, str, Uint16Size) case reflect.Uint, reflect.Uint32: - err = SetValueWithUintX(ele, data[index], Uint32Size) + err = setValueWithUintX(ele, str, Uint32Size) case reflect.Uint64: - err = SetValueWithUintX(ele, data[index], Uint64Size) + err = setValueWithUintX(ele, str, Uint64Size) case reflect.Int8: - err = SetValueWithIntX(ele, data[index], Int8Size) + err = setValueWithIntX(ele, str, Int8Size) case reflect.Int16: - err = SetValueWithIntX(ele, data[index], Int16Size) + err = setValueWithIntX(ele, str, Int16Size) case reflect.Int, reflect.Int32: - err = SetValueWithIntX(ele, data[index], Int32Size) + err = setValueWithIntX(ele, str, Int32Size) case reflect.Int64: - err = SetValueWithIntX(ele, data[index], Int64Size) + err = setValueWithIntX(ele, str, Int64Size) case reflect.Float32: - err = SetValueWithFloatX(ele, data[index], Float32Size) + err = setValueWithFloatX(ele, str, Float32Size) case reflect.Float64: - err = SetValueWithFloatX(ele, data[index], Float64Size) + err = setValueWithFloatX(ele, str, Float64Size) default: return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) } @@ -168,3 +164,33 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error { return nil } + +func setUnmarshalTextValue(value reflect.Value, str string) bool { + unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + valuePtr := reflect.New(value.Type()) + + if valuePtr.Type().Implements(unmarshalerType) { + if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok { + if err := decoder.UnmarshalText([]byte(str)); err != nil { + return false + } + + value.Set(valuePtr.Elem()) + + return true + } + } + + return false +} + +func setDurationValue(value reflect.Value, str string) (bool, error) { + valueTypePkgPath := value.Type().PkgPath() + valueTypeName := value.Type().Name() + + if valueTypePkgPath == "time" && valueTypeName == "Duration" { + return true, setValueWithDuration(value, str) + } + + return false, nil +} diff --git a/utils_test.go b/utils_test.go index 5395107..c2ef1fa 100644 --- a/utils_test.go +++ b/utils_test.go @@ -37,7 +37,7 @@ func TestSetValueWithBool(t *testing.T) { v := reflect.ValueOf(&d).Elem().FieldByName("BoolValue") assert := assert.New(t) - assert.NoError(SetValueWithBool(v, "true")) + assert.NoError(setValueWithBool(v, "true")) assert.Equal(true, d.BoolValue) } @@ -46,7 +46,7 @@ func TestSetValueWithFloat32(t *testing.T) { v := reflect.ValueOf(&d).Elem().FieldByName("Float32Value") assert := assert.New(t) - assert.NoError(SetValueWithFloatX(v, "123.456", 32)) + assert.NoError(setValueWithFloatX(v, "123.456", 32)) assert.Equal(float32(123.456), d.Float32Value) } @@ -55,7 +55,7 @@ func TestSetValueWithInt8(t *testing.T) { v := reflect.ValueOf(&d).Elem().FieldByName("Int8Value") assert := assert.New(t) - assert.NoError(SetValueWithIntX(v, "10", 8)) + assert.NoError(setValueWithIntX(v, "10", 8)) assert.Equal(int8(10), d.Int8Value) } @@ -64,7 +64,7 @@ func TestSetValueWithInt(t *testing.T) { v := reflect.ValueOf(&d).Elem().FieldByName("IntValue") assert := assert.New(t) - assert.NoError(SetValueWithIntX(v, "10000", 32)) + assert.NoError(setValueWithIntX(v, "10000", 32)) assert.Equal(10000, d.IntValue) } @@ -73,7 +73,7 @@ func TestSetValueWithUint16(t *testing.T) { v := reflect.ValueOf(&d).Elem().FieldByName("Uint16Value") assert := assert.New(t) - assert.NoError(SetValueWithUintX(v, "100", 16)) + assert.NoError(setValueWithUintX(v, "100", 16)) assert.Equal(uint16(100), d.Uint16Value) } @@ -82,7 +82,7 @@ func TestSetValueWithUint(t *testing.T) { v := reflect.ValueOf(&d).Elem().FieldByName("UintValue") assert := assert.New(t) - assert.NoError(SetValueWithUintX(v, "2000", 32)) + assert.NoError(setValueWithUintX(v, "2000", 32)) assert.Equal(uint(2000), d.UintValue) } @@ -91,7 +91,7 @@ func TestSetValueWithSlice(t *testing.T) { v := reflect.ValueOf(&d).Elem().FieldByName("Names") assert := assert.New(t) - assert.NoError(SetValueWithSlice(v, "xx:yy:zz", ":")) + assert.NoError(setValueWithSlice(v, "xx:yy:zz", ":")) assert.Equal(3, len(d.Names)) assert.Equal("xx", d.Names[0]) assert.Equal("yy", d.Names[1])