/* * Copyright (C) 2017 eschao * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package config import ( "encoding" "errors" "flag" "fmt" "io" "reflect" "strconv" "strings" "time" "unsafe" ) var ErrInternalError = errors.New("internal error") // anyValue wraps a reflect.Value object and implements flag.Value interface // the reflect.Value could be Bool, String, Int, Uint and Float type anyValue struct { any reflect.Value } // newAnyValue creates an anyValue object func newAnyValue(v reflect.Value) *anyValue { return &anyValue{any: v} } func (av *anyValue) String() string { kind := av.any.Kind() switch kind { case reflect.Bool: return strconv.FormatBool(av.any.Bool()) case reflect.String: return av.any.String() case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: return strconv.FormatInt(av.any.Int(), 10) case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: return strconv.FormatUint(av.any.Uint(), 10) case reflect.Float32: return strconv.FormatFloat(av.any.Float(), 'E', -1, 32) case reflect.Float64: return strconv.FormatFloat(av.any.Float(), 'E', -1, 64) } return "unsupported type: " + kind.String() } const ( Float32Size = 32 Float64Size = 64 Int8Size = 8 Int16Size = 16 Int32Size = 32 Int64Size = 64 Uint8Size = 8 Uint16Size = 16 Uint32Size = 32 UintSize = 32 Uint64Size = 64 ) func (av *anyValue) Set(value string) error { kind := av.any.Kind() switch kind { case reflect.String: av.any.SetString(value) case reflect.Float32: return setValueWithFloatX(av.any, value, Float32Size) case reflect.Float64: return setValueWithFloatX(av.any, value, Float64Size) case reflect.Int8: return setValueWithIntX(av.any, value, Int8Size) case reflect.Int16: return setValueWithIntX(av.any, value, Int16Size) case reflect.Int, reflect.Int32: return setValueWithIntX(av.any, value, Int32Size) case reflect.Int64: return setValueWithIntX(av.any, value, Int64Size) case reflect.Uint8: return setValueWithUintX(av.any, value, Uint8Size) case reflect.Uint16: return setValueWithUintX(av.any, value, Uint16Size) case reflect.Uint, reflect.Uint32: return setValueWithUintX(av.any, value, Uint32Size) case reflect.Uint64: return setValueWithUintX(av.any, value, Uint64Size) default: return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) } return nil } // sliceValue wraps a reflect.Value object and implements flag.Value interface // the reflect.Value could only be a sliceable type type sliceValue struct { value reflect.Value separator string } func newSliceValue(v reflect.Value, separator string) *sliceValue { return &sliceValue{value: v, separator: separator} } func (sv *sliceValue) String() string { return sv.value.String() } func (sv *sliceValue) Set(value string) error { sp := sv.separator if sp == "" { sp = ":" } return setValueWithSlice(sv.value, value, sp) } // errorHandling is a global flag.ErrorHandling var errorHandling = flag.ExitOnError // UsageFunc defines a callback function for printing command usage type UsageFunc func(*Command) func() // usageHandler is a global UsageFunc callback, default is nil which means it // will use default flag.Usage function var usageHandler UsageFunc = nil // Command defines a command line structure type Command struct { Name string // Command name prefix string // Args prefix FlagSet *flag.FlagSet // Command arguments Usage string // Command usage description SubCommands map[string]*Command // Sub-commands Args []string // Rest of args after parsing } // NewCLI creates a command with given name, the command will use default // ErrorHandling: flag.ExitOnError and default usage function: flag.Usage func NewCLI(name string) *Command { cmd := Command{ Name: name, FlagSet: flag.NewFlagSet(name, errorHandling), SubCommands: make(map[string]*Command), } return &cmd } // NewCliWith creates a command with given name, error handling and customized // usage function func NewCliWith( name string, errHandling flag.ErrorHandling, usageHandling UsageFunc, ) *Command { errorHandling = errHandling usageHandler = usageHandling cmd := Command{ Name: name, FlagSet: flag.NewFlagSet(name, errorHandling), SubCommands: make(map[string]*Command), } if usageHandler != nil { cmd.FlagSet.Usage = usageHandler(&cmd) } return &cmd } // Init analyzes the given structure interface, extracts cli definitions from // its tag and installs command flagset by flag APIs. The interface must be a // structure pointer, otherwise will return an error func (c *Command) Init(in interface{}) error { ptrRef := reflect.ValueOf(in) if ptrRef.IsNil() || ptrRef.Kind() != reflect.Ptr { return fmt.Errorf("%w: %s", errExpectStructPointerInsteadOf, ptrRef.Kind().String()) } valueOfStruct := ptrRef.Elem() if valueOfStruct.Kind() != reflect.Struct { return fmt.Errorf("%w: %s", errExpectStructPointerInsteadOf, valueOfStruct.Kind().String()) } return c.parseValue(valueOfStruct) } // Capture filter config args and return rest args func (c *Command) Capture(args []string) []string { var ( result []string expectValue bool ) for _, arg := range args { if !strings.HasPrefix(arg, "-") { if expectValue { c.Args = append(c.Args, arg) expectValue = false } else { result = append(result, arg) } continue } name := strings.TrimPrefix(strings.TrimPrefix(arg, "-"), "-") parts := strings.Split(name, "=") if c.FlagSet.Lookup(parts[0]) != nil { c.Args = append(c.Args, arg) expectValue = len(parts) == 1 } else { result = append(result, arg) } } return result } // parseValue parses a reflect.Value object and extracts cli definitions func (c *Command) parseValue(value reflect.Value) error { var err error typeOfStruct := value.Type() for i := 0; i < value.NumField() && err == nil; i++ { valueOfField := value.Field(i) kindOfField := valueOfField.Kind() structOfField := typeOfStruct.Field(i) switch kindOfField { case reflect.Ptr: if !valueOfField.IsNil() && valueOfField.CanSet() { cmd := c.createSubCommand(structOfField.Tag) err = cmd.Init(valueOfField.Interface()) } case reflect.Struct: cmd := c.createSubCommand(structOfField.Tag) err = cmd.parseValue(valueOfField) default: err = c.addFlag(valueOfField, structOfField) } } return err } // addFlag installs a command flag variable by flag API func (c *Command) addFlag(value reflect.Value, field reflect.StructField) error { name, ok := field.Tag.Lookup("cli") if !ok || name == "" { return nil } if len(c.prefix) > 0 { name = c.prefix + name } usage, _ := field.Tag.Lookup("usage") if value.Type().PkgPath() == "time" && value.Type().Name() == "Duration" { c.FlagSet.DurationVar( (*time.Duration)(unsafe.Pointer(value.UnsafeAddr())), name, time.Duration(0), usage, ) return nil } unmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() valuePtr := reflect.New(value.Type()) if valuePtr.Type().Implements(unmarshalerType) { c.FlagSet.Func(name, usage, func(s string) error { if decoder, ok := valuePtr.Interface().(encoding.TextUnmarshaler); ok { if err := decoder.UnmarshalText([]byte(s)); err != nil { return fmt.Errorf("unmarshal text: %w", err) } value.Set(valuePtr.Elem()) return nil } return ErrInternalError }, ) return nil } kind := value.Kind() switch kind { case reflect.Bool: c.FlagSet.BoolVar( (*bool)(unsafe.Pointer(value.UnsafeAddr())), name, false, usage, ) return nil case reflect.String, reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: anyValue := newAnyValue(value) c.FlagSet.Var(anyValue, name, usage) case reflect.Slice: sliceValue := newSliceValue(value, field.Tag.Get("separator")) c.FlagSet.Var(sliceValue, name, usage) default: return fmt.Errorf("%w: %s", errUnsupportedType, kind.String()) } return nil } // createSubCommand creates sub-commands func (c *Command) createSubCommand(tag reflect.StructTag) *Command { name, ok := tag.Lookup("cli") if !ok || name == "" { return c } cmd := Command{SubCommands: make(map[string]*Command)} cmd.Name = name cmd.prefix = c.prefix + name + "-" cmd.FlagSet = c.FlagSet cmd.Usage, _ = tag.Lookup("usage") if usageHandler != nil { cmd.FlagSet.Usage = usageHandler(&cmd) } c.SubCommands[name] = &cmd return &cmd } // Parse parses values from command line and save values into given structure. // The Init(interface{}) function must be called before parsing func (c *Command) Parse(args []string) error { if err := c.FlagSet.Parse(args); err != nil { return fmt.Errorf("parse flag set: %w", err) } c.Args = c.FlagSet.Args() return nil } // PrintUsage prints command description func (c *Command) PrintUsage(w io.Writer) error { if _, err := w.Write([]byte(c.Usage + "\n")); err != nil { return fmt.Errorf("write usage: %w", err) } c.FlagSet.SetOutput(w) c.FlagSet.Usage() return nil }