2
0

Добавлена поддержка любых параметров, реализующих интерфейс encoding.TextUnmarshaler
All checks were successful
test / test (push) Successful in 58s

This commit is contained in:
Алексей Бадяев 2024-10-24 16:56:21 +07:00
parent 2de4eecb68
commit 98be07131b
Signed by: alexey
GPG Key ID: 686FBC1363E4AFAE
6 changed files with 165 additions and 89 deletions

46
cli.go
View File

@ -16,6 +16,8 @@
package config
import (
"encoding"
"errors"
"flag"
"fmt"
"io"
@ -26,6 +28,8 @@ import (
"unsafe"
)
var ErrInternalError = errors.New("internal error")
// anyValue wraps a reflect.Value object and implements flag.Value interface
// the reflect.Value could be Bool, String, Int, Uint and Float
type anyValue struct {
@ -85,25 +89,25 @@ func (av *anyValue) Set(value string) error {
case reflect.String:
av.any.SetString(value)
case reflect.Float32:
return SetValueWithFloatX(av.any, value, Float32Size)
return setValueWithFloatX(av.any, value, Float32Size)
case reflect.Float64:
return SetValueWithFloatX(av.any, value, Float64Size)
return setValueWithFloatX(av.any, value, Float64Size)
case reflect.Int8:
return SetValueWithIntX(av.any, value, Int8Size)
return setValueWithIntX(av.any, value, Int8Size)
case reflect.Int16:
return SetValueWithIntX(av.any, value, Int16Size)
return setValueWithIntX(av.any, value, Int16Size)
case reflect.Int, reflect.Int32:
return SetValueWithIntX(av.any, value, Int32Size)
return setValueWithIntX(av.any, value, Int32Size)
case reflect.Int64:
return SetValueWithIntX(av.any, value, Int64Size)
return setValueWithIntX(av.any, value, Int64Size)
case reflect.Uint8:
return SetValueWithUintX(av.any, value, Uint8Size)
return setValueWithUintX(av.any, value, Uint8Size)
case reflect.Uint16:
return SetValueWithUintX(av.any, value, Uint16Size)
return setValueWithUintX(av.any, value, Uint16Size)
case reflect.Uint, reflect.Uint32:
return SetValueWithUintX(av.any, value, Uint32Size)
return setValueWithUintX(av.any, value, Uint32Size)
case reflect.Uint64:
return SetValueWithUintX(av.any, value, Uint64Size)
return setValueWithUintX(av.any, value, Uint64Size)
default:
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
}
@ -132,7 +136,7 @@ func (sv *sliceValue) Set(value string) error {
sp = ":"
}
return SetValueWithSlice(sv.value, value, sp)
return setValueWithSlice(sv.value, value, sp)
}
// errorHandling is a global flag.ErrorHandling
@ -291,6 +295,26 @@ func (c *Command) addFlag(value reflect.Value, field reflect.StructField) error
return nil
}
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
valuePtr := reflect.New(value.Type())
if valuePtr.Type().Implements(unmarshalerType) {
c.FlagSet.Func(name, usage, func(s string) error {
if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok {
if err := decoder.UnmarshalText([]byte(s)); err != nil {
return fmt.Errorf("unmarshal text: %w", err)
}
value.Set(valuePtr.Elem())
return nil
}
return ErrInternalError
},
)
}
kind := value.Kind()
switch kind {
case reflect.Bool:

View File

@ -102,51 +102,48 @@ func parseValue(value reflect.Value) error {
}
func setValue(value reflect.Value, defValue string, field reflect.StructField) error {
if setUnmarshalTextValue(value, defValue) {
return nil
}
if ok, err := setDurationValue(value, defValue); ok {
return err
}
var err error
valueTypePkgPath := value.Type().PkgPath()
valueTypeName := value.Type().Name()
if valueTypePkgPath == "time" && valueTypeName == "Duration" {
return SetValueWithDuration(value, defValue)
}
if valueTypePkgPath == "log/slog" && valueTypeName == "Level" {
return SetValueSlogLevel(value, defValue)
}
switch value.Kind() {
case reflect.Bool:
err = SetValueWithBool(value, defValue)
err = setValueWithBool(value, defValue)
case reflect.String:
value.SetString(defValue)
case reflect.Int8:
err = SetValueWithIntX(value, defValue, Int8Size)
err = setValueWithIntX(value, defValue, Int8Size)
case reflect.Int16:
err = SetValueWithIntX(value, defValue, Int16Size)
err = setValueWithIntX(value, defValue, Int16Size)
case reflect.Int, reflect.Int32:
err = SetValueWithIntX(value, defValue, Int32Size)
err = setValueWithIntX(value, defValue, Int32Size)
case reflect.Int64:
err = SetValueWithIntX(value, defValue, Int64Size)
err = setValueWithIntX(value, defValue, Int64Size)
case reflect.Uint8:
err = SetValueWithUintX(value, defValue, Uint8Size)
err = setValueWithUintX(value, defValue, Uint8Size)
case reflect.Uint16:
err = SetValueWithUintX(value, defValue, Uint16Size)
err = setValueWithUintX(value, defValue, Uint16Size)
case reflect.Uint, reflect.Uint32:
err = SetValueWithUintX(value, defValue, Uint32Size)
err = setValueWithUintX(value, defValue, Uint32Size)
case reflect.Uint64:
err = SetValueWithUintX(value, defValue, Uint64Size)
err = setValueWithUintX(value, defValue, Uint64Size)
case reflect.Float32:
err = SetValueWithFloatX(value, defValue, Float32Size)
err = setValueWithFloatX(value, defValue, Float32Size)
case reflect.Float64:
err = SetValueWithFloatX(value, defValue, Float64Size)
err = setValueWithFloatX(value, defValue, Float64Size)
case reflect.Slice:
sp, ok := field.Tag.Lookup("separator")
if !ok {
sp = ":"
}
err = SetValueWithSlice(value, defValue, sp)
err = setValueWithSlice(value, defValue, sp)
default:
err = fmt.Errorf("%w: %s", errUnsupportedType, value.Kind().String())
}

View File

@ -16,9 +16,11 @@
package config
import (
"encoding"
"log/slog"
"os"
"path/filepath"
"reflect"
"runtime"
"strconv"
"testing"
@ -116,3 +118,18 @@ func TestYamlConfigFile(t *testing.T) {
assert.Equal(DB_LOG_PATH, conf.Log.Path)
assert.Equal(DB_LOG_LEVEL, conf.Log.Level)
}
func TestTextUnmarshal(t *testing.T) {
var level slog.Level
assert := assert.New(t)
value := reflect.ValueOf(&level)
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
assert.True(value.Type().Implements(unmarshalerType))
valueDecoder, ok := value.Interface().(encoding.TextUnmarshaler)
assert.True(ok)
assert.NotNil(valueDecoder)
assert.NoError(valueDecoder.UnmarshalText([]byte("INFO")))
assert.Equal(slog.LevelInfo, level)
}

42
env.go
View File

@ -196,34 +196,46 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str
return fmt.Errorf("%w: %s", errValueCannotBeChanged, field.Name)
}
if setUnmarshalTextValue(value, envValue) {
return nil
}
if ok, err := setDurationValue(value, envValue); ok {
return err
}
return setSimpleEnvValue(value, field, envValue)
}
func setSimpleEnvValue(value reflect.Value, field reflect.StructField, str string) error {
var err error
kind := value.Kind()
switch kind {
case reflect.Bool:
err = SetValueWithBool(value, envValue)
err = setValueWithBool(value, str)
case reflect.String:
value.SetString(envValue)
value.SetString(str)
case reflect.Int8:
err = SetValueWithIntX(value, envValue, Int8Size)
err = setValueWithIntX(value, str, Int8Size)
case reflect.Int16:
err = SetValueWithIntX(value, envValue, Int16Size)
err = setValueWithIntX(value, str, Int16Size)
case reflect.Int, reflect.Int32:
err = SetValueWithIntX(value, envValue, Int32Size)
err = setValueWithIntX(value, str, Int32Size)
case reflect.Int64:
err = SetValueWithIntX(value, envValue, Int64Size)
err = setValueWithIntX(value, str, Int64Size)
case reflect.Uint8:
err = SetValueWithUintX(value, envValue, Uint8Size)
err = setValueWithUintX(value, str, Uint8Size)
case reflect.Uint16:
err = SetValueWithUintX(value, envValue, Uint16Size)
err = setValueWithUintX(value, str, Uint16Size)
case reflect.Uint, reflect.Uint32:
err = SetValueWithUintX(value, envValue, Uint32Size)
err = setValueWithUintX(value, str, Uint32Size)
case reflect.Uint64:
err = SetValueWithUintX(value, envValue, Uint64Size)
err = setValueWithUintX(value, str, Uint64Size)
case reflect.Float32:
err = SetValueWithFloatX(value, envValue, Float32Size)
err = setValueWithFloatX(value, str, Float32Size)
case reflect.Float64:
err = SetValueWithFloatX(value, envValue, Float64Size)
err = setValueWithFloatX(value, str, Float64Size)
case reflect.Slice:
sp, ok := field.Tag.Lookup("separator")
@ -231,17 +243,17 @@ func setFieldValueEnv(value reflect.Value, field reflect.StructField, prefix str
sp = ":"
}
err = SetValueWithSlice(value, envValue, sp)
err = setValueWithSlice(value, str, sp)
default:
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
}
if err != nil {
return fmt.Errorf("%s: %w", field.Name, err)
err = fmt.Errorf("%s: %w", field.Name, err)
}
return nil
return err
}
func usageFieldValueEnv(out io.Writer, field reflect.StructField, prefix string) error {

View File

@ -18,14 +18,13 @@ package config
import (
"encoding"
"fmt"
"log/slog"
"reflect"
"strconv"
"strings"
"time"
)
func SetValueWithBool(value reflect.Value, boolStr string) error {
func setValueWithBool(value reflect.Value, boolStr string) error {
boolValue, err := strconv.ParseBool(boolStr)
if err != nil {
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
@ -44,7 +43,7 @@ func SetValueWithBool(value reflect.Value, boolStr string) error {
return nil
}
func SetValueWithDuration(value reflect.Value, str string) error {
func setValueWithDuration(value reflect.Value, str string) error {
d, err := time.ParseDuration(str)
if err != nil {
return fmt.Errorf("parse duration: %w", err)
@ -55,19 +54,7 @@ func SetValueWithDuration(value reflect.Value, str string) error {
return nil
}
func SetValueSlogLevel(value reflect.Value, str string) error {
var level slog.Level
if err := level.UnmarshalText([]byte(str)); err != nil {
return fmt.Errorf("parse slog level: %w", err)
}
value.Set(reflect.ValueOf(level))
return nil
}
func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error {
func setValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error {
floatValue, err := strconv.ParseFloat(floatStr, bitSize)
if err != nil {
return fmt.Errorf("parse float: %w", err)
@ -78,7 +65,7 @@ func SetValueWithFloatX(value reflect.Value, floatStr string, bitSize int) error
return nil
}
func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
func setValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
intValue, err := strconv.ParseInt(intStr, 10, bitSize)
if err != nil {
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
@ -97,7 +84,7 @@ func SetValueWithIntX(value reflect.Value, intStr string, bitSize int) error {
return nil
}
func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
func setValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
uintValue, err := strconv.ParseUint(uintStr, 10, bitSize)
if err != nil {
if txt, ok := value.Interface().(encoding.TextUnmarshaler); ok {
@ -116,7 +103,7 @@ func SetValueWithUintX(value reflect.Value, uintStr string, bitSize int) error {
return nil
}
func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
func setValueWithSlice(value reflect.Value, slice string, sep string) error {
data := strings.Split(slice, sep)
size := len(data)
@ -124,36 +111,45 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
slice := reflect.MakeSlice(value.Type(), size, size)
for index := range size {
var err error
ele := slice.Index(index)
str := data[index]
if ok, err := setDurationValue(ele, str); ok {
return err
}
if setUnmarshalTextValue(ele, str) {
continue
}
var err error
kind := ele.Kind()
switch kind {
case reflect.Bool:
err = SetValueWithBool(ele, data[index])
err = setValueWithBool(ele, str)
case reflect.String:
ele.SetString(data[index])
ele.SetString(str)
case reflect.Uint8:
err = SetValueWithUintX(ele, data[index], Uint8Size)
err = setValueWithUintX(ele, str, Uint8Size)
case reflect.Uint16:
err = SetValueWithUintX(ele, data[index], Uint16Size)
err = setValueWithUintX(ele, str, Uint16Size)
case reflect.Uint, reflect.Uint32:
err = SetValueWithUintX(ele, data[index], Uint32Size)
err = setValueWithUintX(ele, str, Uint32Size)
case reflect.Uint64:
err = SetValueWithUintX(ele, data[index], Uint64Size)
err = setValueWithUintX(ele, str, Uint64Size)
case reflect.Int8:
err = SetValueWithIntX(ele, data[index], Int8Size)
err = setValueWithIntX(ele, str, Int8Size)
case reflect.Int16:
err = SetValueWithIntX(ele, data[index], Int16Size)
err = setValueWithIntX(ele, str, Int16Size)
case reflect.Int, reflect.Int32:
err = SetValueWithIntX(ele, data[index], Int32Size)
err = setValueWithIntX(ele, str, Int32Size)
case reflect.Int64:
err = SetValueWithIntX(ele, data[index], Int64Size)
err = setValueWithIntX(ele, str, Int64Size)
case reflect.Float32:
err = SetValueWithFloatX(ele, data[index], Float32Size)
err = setValueWithFloatX(ele, str, Float32Size)
case reflect.Float64:
err = SetValueWithFloatX(ele, data[index], Float64Size)
err = setValueWithFloatX(ele, str, Float64Size)
default:
return fmt.Errorf("%w: %s", errUnsupportedType, kind.String())
}
@ -168,3 +164,33 @@ func SetValueWithSlice(value reflect.Value, slice string, sep string) error {
return nil
}
func setUnmarshalTextValue(value reflect.Value, str string) bool {
unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
valuePtr := reflect.New(value.Type())
if valuePtr.Type().Implements(unmarshalerType) {
if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok {
if err := decoder.UnmarshalText([]byte(str)); err != nil {
return false
}
value.Set(valuePtr.Elem())
return true
}
}
return false
}
func setDurationValue(value reflect.Value, str string) (bool, error) {
valueTypePkgPath := value.Type().PkgPath()
valueTypeName := value.Type().Name()
if valueTypePkgPath == "time" && valueTypeName == "Duration" {
return true, setValueWithDuration(value, str)
}
return false, nil
}

View File

@ -37,7 +37,7 @@ func TestSetValueWithBool(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("BoolValue")
assert := assert.New(t)
assert.NoError(SetValueWithBool(v, "true"))
assert.NoError(setValueWithBool(v, "true"))
assert.Equal(true, d.BoolValue)
}
@ -46,7 +46,7 @@ func TestSetValueWithFloat32(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Float32Value")
assert := assert.New(t)
assert.NoError(SetValueWithFloatX(v, "123.456", 32))
assert.NoError(setValueWithFloatX(v, "123.456", 32))
assert.Equal(float32(123.456), d.Float32Value)
}
@ -55,7 +55,7 @@ func TestSetValueWithInt8(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Int8Value")
assert := assert.New(t)
assert.NoError(SetValueWithIntX(v, "10", 8))
assert.NoError(setValueWithIntX(v, "10", 8))
assert.Equal(int8(10), d.Int8Value)
}
@ -64,7 +64,7 @@ func TestSetValueWithInt(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("IntValue")
assert := assert.New(t)
assert.NoError(SetValueWithIntX(v, "10000", 32))
assert.NoError(setValueWithIntX(v, "10000", 32))
assert.Equal(10000, d.IntValue)
}
@ -73,7 +73,7 @@ func TestSetValueWithUint16(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Uint16Value")
assert := assert.New(t)
assert.NoError(SetValueWithUintX(v, "100", 16))
assert.NoError(setValueWithUintX(v, "100", 16))
assert.Equal(uint16(100), d.Uint16Value)
}
@ -82,7 +82,7 @@ func TestSetValueWithUint(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("UintValue")
assert := assert.New(t)
assert.NoError(SetValueWithUintX(v, "2000", 32))
assert.NoError(setValueWithUintX(v, "2000", 32))
assert.Equal(uint(2000), d.UintValue)
}
@ -91,7 +91,7 @@ func TestSetValueWithSlice(t *testing.T) {
v := reflect.ValueOf(&d).Elem().FieldByName("Names")
assert := assert.New(t)
assert.NoError(SetValueWithSlice(v, "xx:yy:zz", ":"))
assert.NoError(setValueWithSlice(v, "xx:yy:zz", ":"))
assert.Equal(3, len(d.Names))
assert.Equal("xx", d.Names[0])
assert.Equal("yy", d.Names[1])