2
0

Добавлена поддержка любых параметров, реализующих интерфейс encoding.TextUnmarshaler
All checks were successful
test / test (push) Successful in 58s

This commit is contained in:
Алексей Бадяев 2024-10-24 16:56:21 +07:00
parent 2de4eecb68
commit 98be07131b
Signed by: alexey
GPG Key ID: 686FBC1363E4AFAE
6 changed files with 165 additions and 89 deletions

46
cli.go
View File

@ -16,6 +16,8 @@
package config package config
import ( import (
"encoding"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -26,6 +28,8 @@ import (
"unsafe" "unsafe"
) )
var ErrInternalError = errors.New("internal error")
// anyValue wraps a reflect.Value object and implements flag.Value interface // anyValue wraps a reflect.Value object and implements flag.Value interface
// the reflect.Value could be Bool, String, Int, Uint and Float // the reflect.Value could be Bool, String, Int, Uint and Float
type anyValue struct { type anyValue struct {
@ -85,25 +89,25 @@ func (av *anyValue) Set(value string) error {
case reflect.String: case reflect.String:
av.any.SetString(value) av.any.SetString(value)
case reflect.Float32: case reflect.Float32:
return SetValueWithFloatX(av.any, value, Float32Size) return setValueWithFloatX(av.any, value, Float32Size)
case reflect.Float64: case reflect.Float64:
return SetValueWithFloatX(av.any, value, Float64Size) return setValueWithFloatX(av.any, value, Float64Size)
case reflect.Int8: case reflect.Int8:
return SetValueWithIntX(av.any, value, Int8Size) return setValueWithIntX(av.any, value, Int8Size)
case reflect.Int16: case reflect.Int16:
return SetValueWithIntX(av.any, value, Int16Size) return setValueWithIntX(av.any, value, Int16Size)
case reflect.Int, reflect.Int32: case reflect.Int, reflect.Int32:
return SetValueWithIntX(av.any, value, Int32Size) return setValueWithIntX(av.any, value, Int32Size)
case reflect.Int64: case reflect.Int64:
return SetValueWithIntX(av.any, value, Int64Size) return setValueWithIntX(av.any, value, Int64Size)
case reflect.Uint8: case reflect.Uint8:
return SetValueWithUintX(av.any, value, Uint8Size) return setValueWithUintX(av.any, value, Uint8Size)
case reflect.Uint16: case reflect.Uint16:
return SetValueWithUintX(av.any, value, Uint16Size) return setValueWithUintX(av.any, value, Uint16Size)
case reflect.Uint, reflect.Uint32: case reflect.Uint, reflect.Uint32:
return SetValueWithUintX(av.any, value, Uint32Size) return setValueWithUintX(av.any, value, Uint32Size)
case reflect.Uint64: case reflect.Uint64:
return SetValueWithUintX(av.any, value, Uint64Size) return setValueWithUintX(av.any, value, Uint64Size)
default: default:
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
} }
@ -132,7 +136,7 @@ func (sv *sliceValue) Set(value string) error {
sp = ":" sp = ":"
} }
return SetValueWithSlice(sv.value, value, sp) return setValueWithSlice(sv.value, value, sp)
} }
// errorHandling is a global flag.ErrorHandling // errorHandling is a global flag.ErrorHandling
@ -291,6 +295,26 @@ func (c *Command) addFlag(value reflect.Value, field reflect.StructField) error
return nil return nil
} }
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
valuePtr := reflect.New(value.Type())
if valuePtr.Type().Implements(unmarshalerType) {
c.FlagSet.Func(name, usage, func(s string) error {
if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok {
if err := decoder.UnmarshalText([]byte(s)); err != nil {
return fmt.Errorf("unmarshal text: %w", err)
}
value.Set(valuePtr.Elem())
return nil
}
return ErrInternalError
},
)
}
kind := value.Kind() kind := value.Kind()
switch kind { switch kind {
case reflect.Bool: case reflect.Bool:

View File

@ -102,51 +102,48 @@ func parseValue(value reflect.Value) error {
} }
func setValue(value reflect.Value, defValue string, field reflect.StructField) error { func setValue(value reflect.Value, defValue string, field reflect.StructField) error {
if setUnmarshalTextValue(value, defValue) {
return nil
}
if ok, err := setDurationValue(value, defValue); ok {
return err
}
var err error var err error
valueTypePkgPath := value.Type().PkgPath()
valueTypeName := value.Type().Name()
if valueTypePkgPath == "time" && valueTypeName == "Duration" {
return SetValueWithDuration(value, defValue)
}
if valueTypePkgPath == "log/slog" && valueTypeName == "Level" {
return SetValueSlogLevel(value, defValue)
}
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
err = SetValueWithBool(value, defValue) err = setValueWithBool(value, defValue)
case reflect.String: case reflect.String:
value.SetString(defValue) value.SetString(defValue)
case reflect.Int8: case reflect.Int8:
err = SetValueWithIntX(value, defValue, Int8Size) err = setValueWithIntX(value, defValue, Int8Size)
case reflect.Int16: case reflect.Int16:
err = SetValueWithIntX(value, defValue, Int16Size) err = setValueWithIntX(value, defValue, Int16Size)
case reflect.Int, reflect.Int32: case reflect.Int, reflect.Int32:
err = SetValueWithIntX(value, defValue, Int32Size) err = setValueWithIntX(value, defValue, Int32Size)
case reflect.Int64: case reflect.Int64:
err = SetValueWithIntX(value, defValue, Int64Size) err = setValueWithIntX(value, defValue, Int64Size)
case reflect.Uint8: case reflect.Uint8:
err = SetValueWithUintX(value, defValue, Uint8Size) err = setValueWithUintX(value, defValue, Uint8Size)
case reflect.Uint16: case reflect.Uint16:
err = SetValueWithUintX(value, defValue, Uint16Size) err = setValueWithUintX(value, defValue, Uint16Size)
case reflect.Uint, reflect.Uint32: case reflect.Uint, reflect.Uint32:
err = SetValueWithUintX(value, defValue, Uint32Size) err = setValueWithUintX(value, defValue, Uint32Size)
case reflect.Uint64: case reflect.Uint64:
err = SetValueWithUintX(value, defValue, Uint64Size) err = setValueWithUintX(value, defValue, Uint64Size)
case reflect.Float32: case reflect.Float32:
err = SetValueWithFloatX(value, defValue, Float32Size) err = setValueWithFloatX(value, defValue, Float32Size)
case reflect.Float64: case reflect.Float64:
err = SetValueWithFloatX(value, defValue, Float64Size) err = setValueWithFloatX(value, defValue, Float64Size)
case reflect.Slice: case reflect.Slice:
sp, ok := field.Tag.Lookup("separator") sp, ok := field.Tag.Lookup("separator")
if !ok { if !ok {
sp = ":" sp = ":"
} }
err = SetValueWithSlice(value, defValue, sp) err = setValueWithSlice(value, defValue, sp)
default: default:
err = fmt.Errorf("%w: %s", errUnsupportedType, value.Kind().String()) err = fmt.Errorf("%w: %s", errUnsupportedType, value.Kind().String())
} }

View File

@ -16,9 +16,11 @@
package config package config
import ( import (
"encoding"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"testing" "testing"
@ -116,3 +118,18 @@ func TestYamlConfigFile(t *testing.T) {
assert.Equal(DB_LOG_PATH, conf.Log.Path) assert.Equal(DB_LOG_PATH, conf.Log.Path)
assert.Equal(DB_LOG_LEVEL, conf.Log.Level) assert.Equal(DB_LOG_LEVEL, conf.Log.Level)
} }
func TestTextUnmarshal(t *testing.T) {
var level slog.Level
assert := assert.New(t)
value := reflect.ValueOf(&level)
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
assert.True(value.Type().Implements(unmarshalerType))
valueDecoder, ok := value.Interface().(encoding.TextUnmarshaler)
assert.True(ok)
assert.NotNil(valueDecoder)
assert.NoError(valueDecoder.UnmarshalText([]byte("INFO")))
assert.Equal(slog.LevelInfo, level)
}

42
env.go
View File

@ -196,34 +196,46 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str
return fmt.Errorf("%w: %s", errValueCannotBeChanged, field.Name) return fmt.Errorf("%w: %s", errValueCannotBeChanged, field.Name)
} }
if setUnmarshalTextValue(value, envValue) {
return nil
}
if ok, err := setDurationValue(value, envValue); ok {
return err
}
return setSimpleEnvValue(value, field, envValue)
}
func setSimpleEnvValue(value reflect.Value, field reflect.StructField, str string) error {
var err error var err error
kind := value.Kind() kind := value.Kind()
switch kind { switch kind {
case reflect.Bool: case reflect.Bool:
err = SetValueWithBool(value, envValue) err = setValueWithBool(value, str)
case reflect.String: case reflect.String:
value.SetString(envValue) value.SetString(str)
case reflect.Int8: case reflect.Int8:
err = SetValueWithIntX(value, envValue, Int8Size) err = setValueWithIntX(value, str, Int8Size)
case reflect.Int16: case reflect.Int16:
err = SetValueWithIntX(value, envValue, Int16Size) err = setValueWithIntX(value, str, Int16Size)
case reflect.Int, reflect.Int32: case reflect.Int, reflect.Int32:
err = SetValueWithIntX(value, envValue, Int32Size) err = setValueWithIntX(value, str, Int32Size)
case reflect.Int64: case reflect.Int64:
err = SetValueWithIntX(value, envValue, Int64Size) err = setValueWithIntX(value, str, Int64Size)
case reflect.Uint8: case reflect.Uint8:
err = SetValueWithUintX(value, envValue, Uint8Size) err = setValueWithUintX(value, str, Uint8Size)
case reflect.Uint16: case reflect.Uint16:
err = SetValueWithUintX(value, envValue, Uint16Size) err = setValueWithUintX(value, str, Uint16Size)
case reflect.Uint, reflect.Uint32: case reflect.Uint, reflect.Uint32:
err = SetValueWithUintX(value, envValue, Uint32Size) err = setValueWithUintX(value, str, Uint32Size)
case reflect.Uint64: case reflect.Uint64:
err = SetValueWithUintX(value, envValue, Uint64Size) err = setValueWithUintX(value, str, Uint64Size)
case reflect.Float32: case reflect.Float32:
err = SetValueWithFloatX(value, envValue, Float32Size) err = setValueWithFloatX(value, str, Float32Size)
case reflect.Float64: case reflect.Float64:
err = SetValueWithFloatX(value, envValue, Float64Size) err = setValueWithFloatX(value, str, Float64Size)
case reflect.Slice: case reflect.Slice:
sp, ok := field.Tag.Lookup("separator") sp, ok := field.Tag.Lookup("separator")
@ -231,17 +243,17 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str
sp = ":" sp = ":"
} }
err = SetValueWithSlice(value, envValue, sp) err = setValueWithSlice(value, str, sp)
default: default:
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
} }
if err != nil { if err != nil {
return fmt.Errorf("%s: %w", field.Name, err) err = fmt.Errorf("%s: %w", field.Name, err)
} }
return nil return err
} }
func usageFieldValueEnv(out io.Writer, field reflect.StructField, prefix string) error { func usageFieldValueEnv(out io.Writer, field reflect.StructField, prefix string) error {

View File

@ -18,14 +18,13 @@ package config
import ( import (
"encoding" "encoding"
"fmt" "fmt"
"log/slog"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
func SetValueWithBool(value reflect.Value, boolStr string) error { func setValueWithBool(value reflect.Value, boolStr string) error {
boolValue, err := strconv.ParseBool(boolStr) boolValue, err := strconv.ParseBool(boolStr)
if err != nil { if err != nil {
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok { if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
@ -44,7 +43,7 @@ func SetValueWithBool(value reflect.Value, boolStr string) error {
return nil return nil
} }
func SetValueWithDuration(value reflect.Value, str string) error { func setValueWithDuration(value reflect.Value, str string) error {
d, err := time.ParseDuration(str) d, err := time.ParseDuration(str)
if err != nil { if err != nil {
return fmt.Errorf("parse duration: %w", err) return fmt.Errorf("parse duration: %w", err)
@ -55,19 +54,7 @@ func SetValueWithDuration(value reflect.Value, str string) error {
return nil return nil
} }
func SetValueSlogLevel(value reflect.Value, str string) error { func setValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error {
var level slog.Level
if err := level.UnmarshalText([]byte(str)); err != nil {
return fmt.Errorf("parse slog level: %w", err)
}
value.Set(reflect.ValueOf(level))
return nil
}
func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error {
floatValue, err := strconv.ParseFloat(floatStr, bitSize) floatValue, err := strconv.ParseFloat(floatStr, bitSize)
if err != nil { if err != nil {
return fmt.Errorf("parse float: %w", err) return fmt.Errorf("parse float: %w", err)
@ -78,7 +65,7 @@ func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error
return nil return nil
} }
func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error { func setValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
intValue, err := strconv.ParseInt(intStr, 10, bitSize) intValue, err := strconv.ParseInt(intStr, 10, bitSize)
if err != nil { if err != nil {
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok { if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
@ -97,7 +84,7 @@ func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
return nil return nil
} }
func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error { func setValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
uintValue, err := strconv.ParseUint(uintStr, 10, bitSize) uintValue, err := strconv.ParseUint(uintStr, 10, bitSize)
if err != nil { if err != nil {
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok { if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
@ -116,7 +103,7 @@ func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
return nil return nil
} }
func SetValueWithSlice(value reflect.Value, slice string, sep string) error { func setValueWithSlice(value reflect.Value, slice string, sep string) error {
data := strings.Split(slice, sep) data := strings.Split(slice, sep)
size := len(data) size := len(data)
@ -124,36 +111,45 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
slice := reflect.MakeSlice(value.Type(), size, size) slice := reflect.MakeSlice(value.Type(), size, size)
for index := range size { for index := range size {
var err error
ele := slice.Index(index) ele := slice.Index(index)
str := data[index]
if ok, err := setDurationValue(ele, str); ok {
return err
}
if setUnmarshalTextValue(ele, str) {
continue
}
var err error
kind := ele.Kind() kind := ele.Kind()
switch kind { switch kind {
case reflect.Bool: case reflect.Bool:
err = SetValueWithBool(ele, data[index]) err = setValueWithBool(ele, str)
case reflect.String: case reflect.String:
ele.SetString(data[index]) ele.SetString(str)
case reflect.Uint8: case reflect.Uint8:
err = SetValueWithUintX(ele, data[index], Uint8Size) err = setValueWithUintX(ele, str, Uint8Size)
case reflect.Uint16: case reflect.Uint16:
err = SetValueWithUintX(ele, data[index], Uint16Size) err = setValueWithUintX(ele, str, Uint16Size)
case reflect.Uint, reflect.Uint32: case reflect.Uint, reflect.Uint32:
err = SetValueWithUintX(ele, data[index], Uint32Size) err = setValueWithUintX(ele, str, Uint32Size)
case reflect.Uint64: case reflect.Uint64:
err = SetValueWithUintX(ele, data[index], Uint64Size) err = setValueWithUintX(ele, str, Uint64Size)
case reflect.Int8: case reflect.Int8:
err = SetValueWithIntX(ele, data[index], Int8Size) err = setValueWithIntX(ele, str, Int8Size)
case reflect.Int16: case reflect.Int16:
err = SetValueWithIntX(ele, data[index], Int16Size) err = setValueWithIntX(ele, str, Int16Size)
case reflect.Int, reflect.Int32: case reflect.Int, reflect.Int32:
err = SetValueWithIntX(ele, data[index], Int32Size) err = setValueWithIntX(ele, str, Int32Size)
case reflect.Int64: case reflect.Int64:
err = SetValueWithIntX(ele, data[index], Int64Size) err = setValueWithIntX(ele, str, Int64Size)
case reflect.Float32: case reflect.Float32:
err = SetValueWithFloatX(ele, data[index], Float32Size) err = setValueWithFloatX(ele, str, Float32Size)
case reflect.Float64: case reflect.Float64:
err = SetValueWithFloatX(ele, data[index], Float64Size) err = setValueWithFloatX(ele, str, Float64Size)
default: default:
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
} }
@ -168,3 +164,33 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
return nil return nil
} }
func setUnmarshalTextValue(value reflect.Value, str string) bool {
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
valuePtr := reflect.New(value.Type())
if valuePtr.Type().Implements(unmarshalerType) {
if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok {
if err := decoder.UnmarshalText([]byte(str)); err != nil {
return false
}
value.Set(valuePtr.Elem())
return true
}
}
return false
}
func setDurationValue(value reflect.Value, str string) (bool, error) {
valueTypePkgPath := value.Type().PkgPath()
valueTypeName := value.Type().Name()
if valueTypePkgPath == "time" && valueTypeName == "Duration" {
return true, setValueWithDuration(value, str)
}
return false, nil
}

View File

@ -37,7 +37,7 @@ func TestSetValueWithBool(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("BoolValue") v := reflect.ValueOf(&d).Elem().FieldByName("BoolValue")
assert := assert.New(t) assert := assert.New(t)
assert.NoError(SetValueWithBool(v, "true")) assert.NoError(setValueWithBool(v, "true"))
assert.Equal(true, d.BoolValue) assert.Equal(true, d.BoolValue)
} }
@ -46,7 +46,7 @@ func TestSetValueWithFloat32(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Float32Value") v := reflect.ValueOf(&d).Elem().FieldByName("Float32Value")
assert := assert.New(t) assert := assert.New(t)
assert.NoError(SetValueWithFloatX(v, "123.456", 32)) assert.NoError(setValueWithFloatX(v, "123.456", 32))
assert.Equal(float32(123.456), d.Float32Value) assert.Equal(float32(123.456), d.Float32Value)
} }
@ -55,7 +55,7 @@ func TestSetValueWithInt8(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Int8Value") v := reflect.ValueOf(&d).Elem().FieldByName("Int8Value")
assert := assert.New(t) assert := assert.New(t)
assert.NoError(SetValueWithIntX(v, "10", 8)) assert.NoError(setValueWithIntX(v, "10", 8))
assert.Equal(int8(10), d.Int8Value) assert.Equal(int8(10), d.Int8Value)
} }
@ -64,7 +64,7 @@ func TestSetValueWithInt(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("IntValue") v := reflect.ValueOf(&d).Elem().FieldByName("IntValue")
assert := assert.New(t) assert := assert.New(t)
assert.NoError(SetValueWithIntX(v, "10000", 32)) assert.NoError(setValueWithIntX(v, "10000", 32))
assert.Equal(10000, d.IntValue) assert.Equal(10000, d.IntValue)
} }
@ -73,7 +73,7 @@ func TestSetValueWithUint16(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Uint16Value") v := reflect.ValueOf(&d).Elem().FieldByName("Uint16Value")
assert := assert.New(t) assert := assert.New(t)
assert.NoError(SetValueWithUintX(v, "100", 16)) assert.NoError(setValueWithUintX(v, "100", 16))
assert.Equal(uint16(100), d.Uint16Value) assert.Equal(uint16(100), d.Uint16Value)
} }
@ -82,7 +82,7 @@ func TestSetValueWithUint(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("UintValue") v := reflect.ValueOf(&d).Elem().FieldByName("UintValue")
assert := assert.New(t) assert := assert.New(t)
assert.NoError(SetValueWithUintX(v, "2000", 32)) assert.NoError(setValueWithUintX(v, "2000", 32))
assert.Equal(uint(2000), d.UintValue) assert.Equal(uint(2000), d.UintValue)
} }
@ -91,7 +91,7 @@ func TestSetValueWithSlice(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Names") v := reflect.ValueOf(&d).Elem().FieldByName("Names")
assert := assert.New(t) assert := assert.New(t)
assert.NoError(SetValueWithSlice(v, "xx:yy:zz", ":")) assert.NoError(setValueWithSlice(v, "xx:yy:zz", ":"))
assert.Equal(3, len(d.Names)) assert.Equal(3, len(d.Names))
assert.Equal("xx", d.Names[0]) assert.Equal("xx", d.Names[0])
assert.Equal("yy", d.Names[1]) assert.Equal("yy", d.Names[1])