package packer

import (
	"fmt"
	"math"
	"reflect"
	"strconv"
	"strings"

	"github.com/graph-gophers/graphql-go/decode"
	"github.com/graph-gophers/graphql-go/errors"
	"github.com/graph-gophers/graphql-go/types"
)

// packer converts a deserialized GraphQL input value into a reflect.Value of
// the Go type the packer was built for.
type packer interface {
	Pack(value interface{}) (reflect.Value, error)
}

// Builder constructs packers for (GraphQL type, Go type) pairs. Each pair is
// built once and shared; references to it are recorded and only wired up in
// Finish, which must be called after all packers have been requested.
type Builder struct {
	packerMap     map[typePair]*packerMapEntry
	structPackers []*StructPacker
}

// typePair is the cache key identifying a packer.
type typePair struct {
	graphQLType  types.Type
	resolverType reflect.Type
}

// packerMapEntry holds a built packer and every slot that Finish must point at it.
type packerMapEntry struct {
	packer  packer
	targets []*packer
}

// NewBuilder returns an empty Builder.
func NewBuilder() *Builder {
	return &Builder{
		packerMap: make(map[typePair]*packerMapEntry),
	}
}

// Finish assigns every requested packer to the slots registered for it, then
// precomputes each StructPacker's default struct by packing the schema
// default value of every field that declares one. It returns the first error
// produced while packing a default value.
func (b *Builder) Finish() error {
	for _, entry := range b.packerMap {
		for _, target := range entry.targets {
			*target = entry.packer
		}
	}

	for _, p := range b.structPackers {
		p.defaultStruct = reflect.New(p.structType).Elem()
		for _, f := range p.fields {
			defaultVal := f.field.Default
			if defaultVal == nil {
				continue
			}
			v, err := f.fieldPacker.Pack(defaultVal.Deserialize(nil))
			if err != nil {
				return err
			}
			p.defaultStruct.FieldByIndex(f.fieldIndex).Set(v)
		}
	}

	return nil
}

// assignPacker records that *target must be set to the packer for
// (schemaType, reflectType) when Finish runs, building that packer if this is
// the first request for the pair. The entry is registered before the packer
// is built so that a nested request for the same pair (e.g. a self-referencing
// input object) reuses it instead of recursing.
func (b *Builder) assignPacker(target *packer, schemaType types.Type, reflectType reflect.Type) error {
	k := typePair{schemaType, reflectType}
	ref, ok := b.packerMap[k]
	if !ok {
		ref = &packerMapEntry{}
		b.packerMap[k] = ref
		var err error
		ref.packer, err = b.makePacker(schemaType, reflectType)
		if err != nil {
			return err
		}
	}
	ref.targets = append(ref.targets, target)
	return nil
}

// makePacker builds a packer for schemaType, which may be nullable. A nullable
// schema type needs a Go type that can represent null: a pointer, or a type
// whose pointer implements NullUnmarshaller.
func (b *Builder) makePacker(schemaType types.Type, reflectType reflect.Type) (packer, error) {
	t, nonNull := unwrapNonNull(schemaType)
	if nonNull {
		return b.makeNonNullPacker(t, reflectType)
	}

	switch {
	case reflectType.Kind() == reflect.Ptr:
		elemType := reflectType.Elem()
		addPtr := true
		if _, ok := t.(*types.InputObject); ok {
			// MakeStructPacker accepts pointer types and allocates the
			// pointer itself, so input objects keep the pointer type.
			elemType = reflectType
			addPtr = false
		}
		elem, err := b.makeNonNullPacker(t, elemType)
		if err != nil {
			return nil, err
		}
		return &nullPacker{
			elemPacker: elem,
			valueType:  reflectType,
			addPtr:     addPtr,
		}, nil

	case isNullable(reflectType):
		elem, err := b.makeNonNullPacker(t, reflectType)
		if err != nil {
			return nil, err
		}
		return &nullPacker{
			elemPacker: elem,
			valueType:  reflectType,
			addPtr:     false,
		}, nil

	default:
		return nil, fmt.Errorf("%s is not a pointer or a nullable type", reflectType)
	}
}

// makeNonNullPacker builds a packer for a non-null schemaType. A Go type whose
// pointer implements decode.Unmarshaler takes precedence over the packers
// chosen by schema kind, provided it claims to implement schemaType.
func (b *Builder) makeNonNullPacker(schemaType types.Type, reflectType reflect.Type) (packer, error) {
	if u, ok := reflect.New(reflectType).Interface().(decode.Unmarshaler); ok {
		if !u.ImplementsGraphQLType(schemaType.String()) {
			return nil, fmt.Errorf("can not unmarshal %s into %s", schemaType, reflectType)
		}
		return &unmarshalerPacker{ValueType: reflectType}, nil
	}

	switch t := schemaType.(type) {
	case *types.ScalarTypeDefinition:
		return &ValuePacker{ValueType: reflectType}, nil

	case *types.EnumTypeDefinition:
		if reflectType.Kind() != reflect.String {
			return nil, fmt.Errorf("wrong type, expected %s", reflect.String)
		}
		return &ValuePacker{ValueType: reflectType}, nil

	case *types.InputObject:
		sp, err := b.MakeStructPacker(t.Values, reflectType)
		if err != nil {
			// Return a literal nil: a nil *StructPacker would be a non-nil packer.
			return nil, err
		}
		return sp, nil

	case *types.List:
		if reflectType.Kind() != reflect.Slice {
			return nil, fmt.Errorf("expected slice, got %s", reflectType)
		}
		p := &listPacker{sliceType: reflectType}
		if err := b.assignPacker(&p.elem, t.OfType, reflectType.Elem()); err != nil {
			return nil, err
		}
		return p, nil

	case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union:
		return nil, fmt.Errorf("type of kind %s can not be used as input", t.Kind())

	default:
		panic("unreachable")
	}
}

// MakeStructPacker builds a packer that fills a struct, or a pointer to a
// struct, of type typ from a GraphQL input object with the given fields. Each
// GraphQL field is matched to a Go struct field by name, case-insensitively
// and ignoring underscores; the matched Go field must be exported.
func (b *Builder) MakeStructPacker(values []*types.InputValueDefinition, typ reflect.Type) (*StructPacker, error) {
	structType := typ
	usePtr := false
	if typ.Kind() == reflect.Ptr {
		structType = typ.Elem()
		usePtr = true
	}
	if structType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("expected struct or pointer to struct, got %s (hint: missing `args struct { ... }` wrapper for field arguments?)", typ)
	}

	fields := make([]*structPackerField, 0, len(values))
	for _, v := range values {
		want := stripUnderscore(v.Name.Name)
		sf, ok := structType.FieldByNameFunc(func(n string) bool {
			return strings.EqualFold(stripUnderscore(n), want)
		})
		if !ok {
			return nil, fmt.Errorf("%s does not define field %q (hint: missing `args struct { ... }` wrapper for field arguments, or missing field on input struct)", typ, v.Name.Name)
		}
		if sf.PkgPath != "" {
			return nil, fmt.Errorf("field %q must be exported", sf.Name)
		}

		fe := &structPackerField{field: v, fieldIndex: sf.Index}

		// A field with a default is packed using the non-null form of its
		// type; when the field is absent from the input, Pack leaves the
		// default precomputed by Finish in place.
		ft := v.Type
		if v.Default != nil {
			ft, _ = unwrapNonNull(ft)
			ft = &types.NonNull{OfType: ft}
		}
		if err := b.assignPacker(&fe.fieldPacker, ft, sf.Type); err != nil {
			return nil, fmt.Errorf("field %q: %w", sf.Name, err)
		}
		fields = append(fields, fe)
	}

	p := &StructPacker{
		structType: structType,
		usePtr:     usePtr,
		fields:     fields,
	}
	b.structPackers = append(b.structPackers, p)
	return p, nil
}

// StructPacker packs a GraphQL input object into a Go struct, or into a
// pointer to one when it was built for a pointer type.
type StructPacker struct {
	structType    reflect.Type
	usePtr        bool
	defaultStruct reflect.Value // populated by Builder.Finish
	fields        []*structPackerField
}

// structPackerField links one GraphQL input field to its Go struct field.
type structPackerField struct {
	field       *types.InputValueDefinition
	fieldIndex  []int
	fieldPacker packer
}

// Pack returns a new struct initialized from the precomputed defaults, with
// every field present in value packed over it. value must be a
// map[string]interface{} keyed by GraphQL field name; nil is rejected as a
// null for a non-null type.
func (p *StructPacker) Pack(value interface{}) (reflect.Value, error) {
	if value == nil {
		return reflect.Value{}, errors.Errorf("got null for non-null")
	}

	values, ok := value.(map[string]interface{})
	if !ok {
		return reflect.Value{}, errors.Errorf("expected input object, got %T", value)
	}

	v := reflect.New(p.structType)
	v.Elem().Set(p.defaultStruct)
	for _, f := range p.fields {
		fieldValue, present := values[f.field.Name.Name]
		if !present {
			continue
		}
		packed, err := f.fieldPacker.Pack(fieldValue)
		if err != nil {
			return reflect.Value{}, err
		}
		v.Elem().FieldByIndex(f.fieldIndex).Set(packed)
	}

	if p.usePtr {
		return v, nil
	}
	return v.Elem(), nil
}

// listPacker packs a GraphQL list into a slice of sliceType.
type listPacker struct {
	sliceType reflect.Type
	elem      packer
}

// Pack packs each list element with the element packer. A value that is not
// a []interface{} is treated as a one-element list.
func (p *listPacker) Pack(value interface{}) (reflect.Value, error) {
	list, ok := value.([]interface{})
	if !ok {
		list = []interface{}{value}
	}

	v := reflect.MakeSlice(p.sliceType, len(list), len(list))
	for i, item := range list {
		packed, err := p.elem.Pack(item)
		if err != nil {
			return reflect.Value{}, err
		}
		v.Index(i).Set(packed)
	}
	return v, nil
}

// nullPacker wraps the packer of a nullable schema type.
type nullPacker struct {
	elemPacker packer
	valueType  reflect.Type
	// addPtr means elemPacker yields the pointee and Pack allocates the pointer.
	addPtr bool
}

// Pack returns the zero value of valueType for a nil input, unless valueType
// implements NullUnmarshaller, in which case nil is forwarded to elemPacker.
func (p *nullPacker) Pack(value interface{}) (reflect.Value, error) {
	if value == nil && !isNullable(p.valueType) {
		return reflect.Zero(p.valueType), nil
	}

	v, err := p.elemPacker.Pack(value)
	if err != nil {
		return reflect.Value{}, err
	}

	if !p.addPtr {
		return v, nil
	}
	ptr := reflect.New(p.valueType.Elem())
	ptr.Elem().Set(v)
	return ptr, nil
}

// ValuePacker packs scalar and enum input values into ValueType, coercing the
// deserialized input where a lossless conversion exists.
type ValuePacker struct {
	ValueType reflect.Type
}

// Pack coerces value to ValueType; nil is rejected as a null for a non-null type.
func (p *ValuePacker) Pack(value interface{}) (reflect.Value, error) {
	if value == nil {
		return reflect.Value{}, errors.Errorf("got null for non-null")
	}

	coerced, err := unmarshalInput(p.ValueType, value)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("could not unmarshal %#v (%T) into %s: %w", value, value, p.ValueType, err)
	}
	return reflect.ValueOf(coerced), nil
}

// unmarshalerPacker packs values by calling UnmarshalGraphQL on a new
// *ValueType.
type unmarshalerPacker struct {
	ValueType reflect.Type
}

// Pack rejects nil unless ValueType implements NullUnmarshaller.
func (p *unmarshalerPacker) Pack(value interface{}) (reflect.Value, error) {
	if value == nil && !isNullable(p.ValueType) {
		return reflect.Value{}, errors.Errorf("got null for non-null")
	}

	v := reflect.New(p.ValueType)
	if err := v.Interface().(decode.Unmarshaler).UnmarshalGraphQL(value); err != nil {
		return reflect.Value{}, err
	}
	return v.Elem(), nil
}

// unmarshalInput coerces a deserialized input value to typ. The result always
// has dynamic type typ, so it is assignable to fields of named types such as
// `type Count int32`.
//
// Supported coercions by kind of typ:
//   - Int32:   int32; int and integral float64 within the int32 range
//   - Float64: float64, int32, int
//   - String:  any string-kind value; int32/int rendered in decimal
//   - Bool:    bool
func unmarshalInput(typ reflect.Type, input interface{}) (interface{}, error) {
	if reflect.TypeOf(input) == typ {
		return input, nil
	}

	// convert re-types an already coerced primitive as typ.
	convert := func(v interface{}) interface{} {
		return reflect.ValueOf(v).Convert(typ).Interface()
	}

	switch typ.Kind() {
	case reflect.Int32:
		switch input := input.(type) {
		case int32:
			return convert(input), nil
		case int:
			if input < math.MinInt32 || input > math.MaxInt32 {
				return nil, fmt.Errorf("not a 32-bit integer")
			}
			return convert(int32(input)), nil
		case float64:
			coerced := int32(input)
			if input < math.MinInt32 || input > math.MaxInt32 || float64(coerced) != input {
				return nil, fmt.Errorf("not a 32-bit integer")
			}
			return convert(coerced), nil
		}

	case reflect.Float64:
		switch input := input.(type) {
		case float64:
			return convert(input), nil
		case int32:
			return convert(float64(input)), nil
		case int:
			return convert(float64(input)), nil
		}

	case reflect.String:
		switch input := input.(type) {
		case int32:
			// A Go integer->string conversion would yield the rune with that
			// code point; render the number in decimal instead, as
			// graphql.ID does for integer input.
			return convert(strconv.FormatInt(int64(input), 10)), nil
		case int:
			return convert(strconv.Itoa(input)), nil
		}
		if v := reflect.ValueOf(input); v.Kind() == reflect.String {
			return v.Convert(typ).Interface(), nil
		}

	case reflect.Bool:
		if b, ok := input.(bool); ok {
			return convert(b), nil
		}
	}

	return nil, fmt.Errorf("incompatible type")
}

// unwrapNonNull strips one NonNull wrapper from t and reports whether it was present.
func unwrapNonNull(t types.Type) (types.Type, bool) {
	if nn, ok := t.(*types.NonNull); ok {
		return nn.OfType, true
	}
	return t, false
}

// stripUnderscore removes every underscore from s.
func stripUnderscore(s string) string {
	return strings.ReplaceAll(s, "_", "")
}

// NullUnmarshaller is a decode.Unmarshaler that can handle a nil input. When
// the pointer to a Go type implements it, a null GraphQL value for that type
// is passed to UnmarshalGraphQL instead of producing the type's zero value.
// Nullable serves as a marker method; it is not called in this file.
type NullUnmarshaller interface {
	decode.Unmarshaler
	Nullable()
}

// isNullable reports whether *t implements NullUnmarshaller.
func isNullable(t reflect.Type) bool {
	_, ok := reflect.New(t).Interface().(NullUnmarshaller)
	return ok
}