Добавлена поддержка любых параметров, реализующих интерфейс encoding.TextUnmarshaler
All checks were successful
test / test (push) Successful in 58s
All checks were successful
test / test (push) Successful in 58s
This commit is contained in:
parent
2de4eecb68
commit
98be07131b
46
cli.go
46
cli.go
@ -16,6 +16,8 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -26,6 +28,8 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var ErrInternalError = errors.New("internal error")
|
||||
|
||||
// anyValue wraps a reflect.Value object and implements flag.Value interface
|
||||
// the reflect.Value could be Bool, String, Int, Uint and Float
|
||||
type anyValue struct {
|
||||
@ -85,25 +89,25 @@ func (av *anyValue) Set(value string) error {
|
||||
case reflect.String:
|
||||
av.any.SetString(value)
|
||||
case reflect.Float32:
|
||||
return SetValueWithFloatX(av.any, value, Float32Size)
|
||||
return setValueWithFloatX(av.any, value, Float32Size)
|
||||
case reflect.Float64:
|
||||
return SetValueWithFloatX(av.any, value, Float64Size)
|
||||
return setValueWithFloatX(av.any, value, Float64Size)
|
||||
case reflect.Int8:
|
||||
return SetValueWithIntX(av.any, value, Int8Size)
|
||||
return setValueWithIntX(av.any, value, Int8Size)
|
||||
case reflect.Int16:
|
||||
return SetValueWithIntX(av.any, value, Int16Size)
|
||||
return setValueWithIntX(av.any, value, Int16Size)
|
||||
case reflect.Int, reflect.Int32:
|
||||
return SetValueWithIntX(av.any, value, Int32Size)
|
||||
return setValueWithIntX(av.any, value, Int32Size)
|
||||
case reflect.Int64:
|
||||
return SetValueWithIntX(av.any, value, Int64Size)
|
||||
return setValueWithIntX(av.any, value, Int64Size)
|
||||
case reflect.Uint8:
|
||||
return SetValueWithUintX(av.any, value, Uint8Size)
|
||||
return setValueWithUintX(av.any, value, Uint8Size)
|
||||
case reflect.Uint16:
|
||||
return SetValueWithUintX(av.any, value, Uint16Size)
|
||||
return setValueWithUintX(av.any, value, Uint16Size)
|
||||
case reflect.Uint, reflect.Uint32:
|
||||
return SetValueWithUintX(av.any, value, Uint32Size)
|
||||
return setValueWithUintX(av.any, value, Uint32Size)
|
||||
case reflect.Uint64:
|
||||
return SetValueWithUintX(av.any, value, Uint64Size)
|
||||
return setValueWithUintX(av.any, value, Uint64Size)
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
|
||||
}
|
||||
@ -132,7 +136,7 @@ func (sv *sliceValue) Set(value string) error {
|
||||
sp = ":"
|
||||
}
|
||||
|
||||
return SetValueWithSlice(sv.value, value, sp)
|
||||
return setValueWithSlice(sv.value, value, sp)
|
||||
}
|
||||
|
||||
// errorHandling is a global flag.ErrorHandling
|
||||
@ -291,6 +295,26 @@ func (c *Command) addFlag(value reflect.Value, field reflect.StructField) error
|
||||
return nil
|
||||
}
|
||||
|
||||
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
valuePtr := reflect.New(value.Type())
|
||||
|
||||
if valuePtr.Type().Implements(unmarshalerType) {
|
||||
c.FlagSet.Func(name, usage, func(s string) error {
|
||||
if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok {
|
||||
if err := decoder.UnmarshalText([]byte(s)); err != nil {
|
||||
return fmt.Errorf("unmarshal text: %w", err)
|
||||
}
|
||||
|
||||
value.Set(valuePtr.Elem())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrInternalError
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
kind := value.Kind()
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
|
43
config.go
43
config.go
@ -102,51 +102,48 @@ func parseValue(value reflect.Value) error {
|
||||
}
|
||||
|
||||
func setValue(value reflect.Value, defValue string, field reflect.StructField) error {
|
||||
if setUnmarshalTextValue(value, defValue) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ok, err := setDurationValue(value, defValue); ok {
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
valueTypePkgPath := value.Type().PkgPath()
|
||||
valueTypeName := value.Type().Name()
|
||||
|
||||
if valueTypePkgPath == "time" && valueTypeName == "Duration" {
|
||||
return SetValueWithDuration(value, defValue)
|
||||
}
|
||||
|
||||
if valueTypePkgPath == "log/slog" && valueTypeName == "Level" {
|
||||
return SetValueSlogLevel(value, defValue)
|
||||
}
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
err = SetValueWithBool(value, defValue)
|
||||
err = setValueWithBool(value, defValue)
|
||||
case reflect.String:
|
||||
value.SetString(defValue)
|
||||
case reflect.Int8:
|
||||
err = SetValueWithIntX(value, defValue, Int8Size)
|
||||
err = setValueWithIntX(value, defValue, Int8Size)
|
||||
case reflect.Int16:
|
||||
err = SetValueWithIntX(value, defValue, Int16Size)
|
||||
err = setValueWithIntX(value, defValue, Int16Size)
|
||||
case reflect.Int, reflect.Int32:
|
||||
err = SetValueWithIntX(value, defValue, Int32Size)
|
||||
err = setValueWithIntX(value, defValue, Int32Size)
|
||||
case reflect.Int64:
|
||||
err = SetValueWithIntX(value, defValue, Int64Size)
|
||||
err = setValueWithIntX(value, defValue, Int64Size)
|
||||
case reflect.Uint8:
|
||||
err = SetValueWithUintX(value, defValue, Uint8Size)
|
||||
err = setValueWithUintX(value, defValue, Uint8Size)
|
||||
case reflect.Uint16:
|
||||
err = SetValueWithUintX(value, defValue, Uint16Size)
|
||||
err = setValueWithUintX(value, defValue, Uint16Size)
|
||||
case reflect.Uint, reflect.Uint32:
|
||||
err = SetValueWithUintX(value, defValue, Uint32Size)
|
||||
err = setValueWithUintX(value, defValue, Uint32Size)
|
||||
case reflect.Uint64:
|
||||
err = SetValueWithUintX(value, defValue, Uint64Size)
|
||||
err = setValueWithUintX(value, defValue, Uint64Size)
|
||||
case reflect.Float32:
|
||||
err = SetValueWithFloatX(value, defValue, Float32Size)
|
||||
err = setValueWithFloatX(value, defValue, Float32Size)
|
||||
case reflect.Float64:
|
||||
err = SetValueWithFloatX(value, defValue, Float64Size)
|
||||
err = setValueWithFloatX(value, defValue, Float64Size)
|
||||
case reflect.Slice:
|
||||
sp, ok := field.Tag.Lookup("separator")
|
||||
if !ok {
|
||||
sp = ":"
|
||||
}
|
||||
|
||||
err = SetValueWithSlice(value, defValue, sp)
|
||||
err = setValueWithSlice(value, defValue, sp)
|
||||
default:
|
||||
err = fmt.Errorf("%w: %s", errUnsupportedType, value.Kind().String())
|
||||
}
|
||||
|
@ -16,9 +16,11 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
@ -116,3 +118,18 @@ func TestYamlConfigFile(t *testing.T) {
|
||||
assert.Equal(DB_LOG_PATH, conf.Log.Path)
|
||||
assert.Equal(DB_LOG_LEVEL, conf.Log.Level)
|
||||
}
|
||||
|
||||
func TestTextUnmarshal(t *testing.T) {
|
||||
var level slog.Level
|
||||
|
||||
assert := assert.New(t)
|
||||
value := reflect.ValueOf(&level)
|
||||
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
assert.True(value.Type().Implements(unmarshalerType))
|
||||
|
||||
valueDecoder, ok := value.Interface().(encoding.TextUnmarshaler)
|
||||
assert.True(ok)
|
||||
assert.NotNil(valueDecoder)
|
||||
assert.NoError(valueDecoder.UnmarshalText([]byte("INFO")))
|
||||
assert.Equal(slog.LevelInfo, level)
|
||||
}
|
||||
|
42
env.go
42
env.go
@ -196,34 +196,46 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str
|
||||
return fmt.Errorf("%w: %s", errValueCannotBeChanged, field.Name)
|
||||
}
|
||||
|
||||
if setUnmarshalTextValue(value, envValue) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ok, err := setDurationValue(value, envValue); ok {
|
||||
return err
|
||||
}
|
||||
|
||||
return setSimpleEnvValue(value, field, envValue)
|
||||
}
|
||||
|
||||
func setSimpleEnvValue(value reflect.Value, field reflect.StructField, str string) error {
|
||||
var err error
|
||||
|
||||
kind := value.Kind()
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
err = SetValueWithBool(value, envValue)
|
||||
err = setValueWithBool(value, str)
|
||||
case reflect.String:
|
||||
value.SetString(envValue)
|
||||
value.SetString(str)
|
||||
case reflect.Int8:
|
||||
err = SetValueWithIntX(value, envValue, Int8Size)
|
||||
err = setValueWithIntX(value, str, Int8Size)
|
||||
case reflect.Int16:
|
||||
err = SetValueWithIntX(value, envValue, Int16Size)
|
||||
err = setValueWithIntX(value, str, Int16Size)
|
||||
case reflect.Int, reflect.Int32:
|
||||
err = SetValueWithIntX(value, envValue, Int32Size)
|
||||
err = setValueWithIntX(value, str, Int32Size)
|
||||
case reflect.Int64:
|
||||
err = SetValueWithIntX(value, envValue, Int64Size)
|
||||
err = setValueWithIntX(value, str, Int64Size)
|
||||
case reflect.Uint8:
|
||||
err = SetValueWithUintX(value, envValue, Uint8Size)
|
||||
err = setValueWithUintX(value, str, Uint8Size)
|
||||
case reflect.Uint16:
|
||||
err = SetValueWithUintX(value, envValue, Uint16Size)
|
||||
err = setValueWithUintX(value, str, Uint16Size)
|
||||
case reflect.Uint, reflect.Uint32:
|
||||
err = SetValueWithUintX(value, envValue, Uint32Size)
|
||||
err = setValueWithUintX(value, str, Uint32Size)
|
||||
case reflect.Uint64:
|
||||
err = SetValueWithUintX(value, envValue, Uint64Size)
|
||||
err = setValueWithUintX(value, str, Uint64Size)
|
||||
case reflect.Float32:
|
||||
err = SetValueWithFloatX(value, envValue, Float32Size)
|
||||
err = setValueWithFloatX(value, str, Float32Size)
|
||||
case reflect.Float64:
|
||||
err = SetValueWithFloatX(value, envValue, Float64Size)
|
||||
err = setValueWithFloatX(value, str, Float64Size)
|
||||
|
||||
case reflect.Slice:
|
||||
sp, ok := field.Tag.Lookup("separator")
|
||||
@ -231,17 +243,17 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str
|
||||
sp = ":"
|
||||
}
|
||||
|
||||
err = SetValueWithSlice(value, envValue, sp)
|
||||
err = setValueWithSlice(value, str, sp)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", field.Name, err)
|
||||
err = fmt.Errorf("%s: %w", field.Name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func usageFieldValueEnv(out io.Writer, field reflect.StructField, prefix string) error {
|
||||
|
92
utils.go
92
utils.go
@ -18,14 +18,13 @@ package config
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SetValueWithBool(value reflect.Value, boolStr string) error {
|
||||
func setValueWithBool(value reflect.Value, boolStr string) error {
|
||||
boolValue, err := strconv.ParseBool(boolStr)
|
||||
if err != nil {
|
||||
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
|
||||
@ -44,7 +43,7 @@ func SetValueWithBool(value reflect.Value, boolStr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetValueWithDuration(value reflect.Value, str string) error {
|
||||
func setValueWithDuration(value reflect.Value, str string) error {
|
||||
d, err := time.ParseDuration(str)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse duration: %w", err)
|
||||
@ -55,19 +54,7 @@ func SetValueWithDuration(value reflect.Value, str string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetValueSlogLevel(value reflect.Value, str string) error {
|
||||
var level slog.Level
|
||||
|
||||
if err := level.UnmarshalText([]byte(str)); err != nil {
|
||||
return fmt.Errorf("parse slog level: %w", err)
|
||||
}
|
||||
|
||||
value.Set(reflect.ValueOf(level))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error {
|
||||
func setValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error {
|
||||
floatValue, err := strconv.ParseFloat(floatStr, bitSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse float: %w", err)
|
||||
@ -78,7 +65,7 @@ func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
|
||||
func setValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
|
||||
intValue, err := strconv.ParseInt(intStr, 10, bitSize)
|
||||
if err != nil {
|
||||
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
|
||||
@ -97,7 +84,7 @@ func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
|
||||
func setValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
|
||||
uintValue, err := strconv.ParseUint(uintStr, 10, bitSize)
|
||||
if err != nil {
|
||||
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
|
||||
@ -116,7 +103,7 @@ func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
|
||||
func setValueWithSlice(value reflect.Value, slice string, sep string) error {
|
||||
data := strings.Split(slice, sep)
|
||||
|
||||
size := len(data)
|
||||
@ -124,36 +111,45 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
|
||||
slice := reflect.MakeSlice(value.Type(), size, size)
|
||||
|
||||
for index := range size {
|
||||
var err error
|
||||
|
||||
ele := slice.Index(index)
|
||||
str := data[index]
|
||||
|
||||
if ok, err := setDurationValue(ele, str); ok {
|
||||
return err
|
||||
}
|
||||
|
||||
if setUnmarshalTextValue(ele, str) {
|
||||
continue
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
kind := ele.Kind()
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
err = SetValueWithBool(ele, data[index])
|
||||
err = setValueWithBool(ele, str)
|
||||
case reflect.String:
|
||||
ele.SetString(data[index])
|
||||
ele.SetString(str)
|
||||
case reflect.Uint8:
|
||||
err = SetValueWithUintX(ele, data[index], Uint8Size)
|
||||
err = setValueWithUintX(ele, str, Uint8Size)
|
||||
case reflect.Uint16:
|
||||
err = SetValueWithUintX(ele, data[index], Uint16Size)
|
||||
err = setValueWithUintX(ele, str, Uint16Size)
|
||||
case reflect.Uint, reflect.Uint32:
|
||||
err = SetValueWithUintX(ele, data[index], Uint32Size)
|
||||
err = setValueWithUintX(ele, str, Uint32Size)
|
||||
case reflect.Uint64:
|
||||
err = SetValueWithUintX(ele, data[index], Uint64Size)
|
||||
err = setValueWithUintX(ele, str, Uint64Size)
|
||||
case reflect.Int8:
|
||||
err = SetValueWithIntX(ele, data[index], Int8Size)
|
||||
err = setValueWithIntX(ele, str, Int8Size)
|
||||
case reflect.Int16:
|
||||
err = SetValueWithIntX(ele, data[index], Int16Size)
|
||||
err = setValueWithIntX(ele, str, Int16Size)
|
||||
case reflect.Int, reflect.Int32:
|
||||
err = SetValueWithIntX(ele, data[index], Int32Size)
|
||||
err = setValueWithIntX(ele, str, Int32Size)
|
||||
case reflect.Int64:
|
||||
err = SetValueWithIntX(ele, data[index], Int64Size)
|
||||
err = setValueWithIntX(ele, str, Int64Size)
|
||||
case reflect.Float32:
|
||||
err = SetValueWithFloatX(ele, data[index], Float32Size)
|
||||
err = setValueWithFloatX(ele, str, Float32Size)
|
||||
case reflect.Float64:
|
||||
err = SetValueWithFloatX(ele, data[index], Float64Size)
|
||||
err = setValueWithFloatX(ele, str, Float64Size)
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
|
||||
}
|
||||
@ -168,3 +164,33 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUnmarshalTextValue(value reflect.Value, str string) bool {
|
||||
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
valuePtr := reflect.New(value.Type())
|
||||
|
||||
if valuePtr.Type().Implements(unmarshalerType) {
|
||||
if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok {
|
||||
if err := decoder.UnmarshalText([]byte(str)); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
value.Set(valuePtr.Elem())
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func setDurationValue(value reflect.Value, str string) (bool, error) {
|
||||
valueTypePkgPath := value.Type().PkgPath()
|
||||
valueTypeName := value.Type().Name()
|
||||
|
||||
if valueTypePkgPath == "time" && valueTypeName == "Duration" {
|
||||
return true, setValueWithDuration(value, str)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ func TestSetValueWithBool(t *testing.T) {
|
||||
v := reflect.ValueOf(&d).Elem().FieldByName("BoolValue")
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.NoError(SetValueWithBool(v, "true"))
|
||||
assert.NoError(setValueWithBool(v, "true"))
|
||||
assert.Equal(true, d.BoolValue)
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@ func TestSetValueWithFloat32(t *testing.T) {
|
||||
v := reflect.ValueOf(&d).Elem().FieldByName("Float32Value")
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.NoError(SetValueWithFloatX(v, "123.456", 32))
|
||||
assert.NoError(setValueWithFloatX(v, "123.456", 32))
|
||||
assert.Equal(float32(123.456), d.Float32Value)
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ func TestSetValueWithInt8(t *testing.T) {
|
||||
v := reflect.ValueOf(&d).Elem().FieldByName("Int8Value")
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.NoError(SetValueWithIntX(v, "10", 8))
|
||||
assert.NoError(setValueWithIntX(v, "10", 8))
|
||||
assert.Equal(int8(10), d.Int8Value)
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ func TestSetValueWithInt(t *testing.T) {
|
||||
v := reflect.ValueOf(&d).Elem().FieldByName("IntValue")
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.NoError(SetValueWithIntX(v, "10000", 32))
|
||||
assert.NoError(setValueWithIntX(v, "10000", 32))
|
||||
assert.Equal(10000, d.IntValue)
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ func TestSetValueWithUint16(t *testing.T) {
|
||||
v := reflect.ValueOf(&d).Elem().FieldByName("Uint16Value")
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.NoError(SetValueWithUintX(v, "100", 16))
|
||||
assert.NoError(setValueWithUintX(v, "100", 16))
|
||||
assert.Equal(uint16(100), d.Uint16Value)
|
||||
}
|
||||
|
||||
@ -82,7 +82,7 @@ func TestSetValueWithUint(t *testing.T) {
|
||||
v := reflect.ValueOf(&d).Elem().FieldByName("UintValue")
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.NoError(SetValueWithUintX(v, "2000", 32))
|
||||
assert.NoError(setValueWithUintX(v, "2000", 32))
|
||||
assert.Equal(uint(2000), d.UintValue)
|
||||
}
|
||||
|
||||
@ -91,7 +91,7 @@ func TestSetValueWithSlice(t *testing.T) {
|
||||
v := reflect.ValueOf(&d).Elem().FieldByName("Names")
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.NoError(SetValueWithSlice(v, "xx:yy:zz", ":"))
|
||||
assert.NoError(setValueWithSlice(v, "xx:yy:zz", ":"))
|
||||
assert.Equal(3, len(d.Names))
|
||||
assert.Equal("xx", d.Names[0])
|
||||
assert.Equal("yy", d.Names[1])
|
||||
|
Loading…
Reference in New Issue
Block a user