Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling pointer struct fields of all types as quasi options via 2-type union [null, ...] #20

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ thumbs.db
*fuzz.zip
fuzzes/**/crashers
fuzzes/**/suppressions

# IntelliJ
*.iml
25 changes: 20 additions & 5 deletions datum_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,18 @@ func (reader sDatumReader) mapArray(field Schema, reflectField reflect.Value, de
return reflect.ValueOf(arrayLength), err
}

array := reflect.MakeSlice(reflectField.Type(), 0, 0)
indirectArrayType := reflectField.Type()
if reflectField.Type().Kind() == reflect.Ptr {
indirectArrayType = indirectArrayType.Elem()
}
array := reflect.MakeSlice(indirectArrayType, 0, 0)
pointer := reflectField.Type().Elem().Kind() == reflect.Ptr
for {
if arrayLength == 0 {
break
}

arrayPart := reflect.MakeSlice(reflectField.Type(), int(arrayLength), int(arrayLength))
arrayPart := reflect.MakeSlice(indirectArrayType, int(arrayLength), int(arrayLength))
var i int64
for ; i < arrayLength; i++ {
current := arrayPart.Index(int(i))
Expand Down Expand Up @@ -282,8 +286,12 @@ func (reader sDatumReader) mapMap(field Schema, reflectField reflect.Value, dec
return reflect.ValueOf(mapLength), err
}
elemType := reflectField.Type().Elem()
elemIsPointer := (elemType.Kind() == reflect.Ptr)
resultMap := reflect.MakeMap(reflectField.Type())
elemIsPointer := elemType.Kind() == reflect.Ptr
indirectMapType := reflectField.Type()
if reflectField.Type().Kind() == reflect.Ptr {
indirectMapType = indirectMapType.Elem()
}
resultMap := reflect.MakeMap(indirectMapType)

// dest is an element type value used as the destination for reading values into.
// This is required for using non-primitive types as map values, because map values are not addressable
Expand Down Expand Up @@ -360,7 +368,14 @@ func (reader sDatumReader) mapUnion(field Schema, reflectField reflect.Value, de
if unionIndex < 0 || int(unionIndex) >= len(types) {
return reflect.Value{}, fmt.Errorf("Invalid union index %d", unionIndex)
}
return reader.readValue(types[unionIndex], reflectField, dec)

value, err := reader.readValue(types[unionIndex], reflectField, dec)
if reflectField.Kind() == reflect.Ptr && value.Kind() != reflect.Ptr && value.IsValid() {
ref := reflect.New(reflectField.Type().Elem())
ref.Elem().Set(value)
value = ref
}
return value, err
}

func (reader sDatumReader) mapFixed(field Schema, dec Decoder) (reflect.Value, error) {
Expand Down
51 changes: 40 additions & 11 deletions datum_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (writer *SpecificDatumWriter) writeBoolean(v reflect.Value, enc Encoder, s
return fmt.Errorf("Invalid boolean value: %v", v.Interface())
}

enc.WriteBoolean(v.Interface().(bool))
enc.WriteBoolean(dereference(v).Interface().(bool))
return nil
}

Expand All @@ -154,7 +154,7 @@ func (writer *SpecificDatumWriter) writeInt(v reflect.Value, enc Encoder, s Sche
return fmt.Errorf("Invalid int value: %v", v.Interface())
}

enc.WriteInt(v.Interface().(int32))
enc.WriteInt(dereference(v).Interface().(int32))
return nil
}

Expand All @@ -163,7 +163,7 @@ func (writer *SpecificDatumWriter) writeLong(v reflect.Value, enc Encoder, s Sch
return fmt.Errorf("Invalid long value: %v", v.Interface())
}

enc.WriteLong(v.Interface().(int64))
enc.WriteLong(dereference(v).Interface().(int64))
return nil
}

Expand All @@ -172,7 +172,7 @@ func (writer *SpecificDatumWriter) writeFloat(v reflect.Value, enc Encoder, s Sc
return fmt.Errorf("Invalid float value: %v", v.Interface())
}

enc.WriteFloat(v.Interface().(float32))
enc.WriteFloat(dereference(v).Interface().(float32))
return nil
}

Expand All @@ -181,7 +181,7 @@ func (writer *SpecificDatumWriter) writeDouble(v reflect.Value, enc Encoder, s S
return fmt.Errorf("Invalid double value: %v", v.Interface())
}

enc.WriteDouble(v.Interface().(float64))
enc.WriteDouble(dereference(v).Interface().(float64))
return nil
}

Expand All @@ -190,7 +190,7 @@ func (writer *SpecificDatumWriter) writeBytes(v reflect.Value, enc Encoder, s Sc
return fmt.Errorf("Invalid bytes value: %v", v.Interface())
}

enc.WriteBytes(v.Interface().([]byte))
enc.WriteBytes(dereference(v).Interface().([]byte))
return nil
}

Expand All @@ -199,14 +199,17 @@ func (writer *SpecificDatumWriter) writeString(v reflect.Value, enc Encoder, s S
return fmt.Errorf("Invalid string value: %v", v.Interface())
}

enc.WriteString(v.Interface().(string))
enc.WriteString(dereference(v).Interface().(string))
return nil
}

func (writer *SpecificDatumWriter) writeArray(v reflect.Value, enc Encoder, s Schema) error {
if !s.Validate(v) {
return fmt.Errorf("Invalid array value: %v", v.Interface())
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

if v.Len() == 0 {
enc.WriteArrayNext(0)
Expand All @@ -229,7 +232,9 @@ func (writer *SpecificDatumWriter) writeMap(v reflect.Value, enc Encoder, s Sche
if !s.Validate(v) {
return fmt.Errorf("Invalid map value: %v", v.Interface())
}

if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Len() == 0 {
enc.WriteMapNext(0)
return nil
Expand Down Expand Up @@ -280,7 +285,7 @@ func (writer *SpecificDatumWriter) writeFixed(v reflect.Value, enc Encoder, s Sc
}

// Write the raw bytes. The length is known by the schema
enc.WriteRaw(v.Interface().([]byte))
enc.WriteRaw(dereference(v).Interface().([]byte))
return nil
}

Expand Down Expand Up @@ -370,6 +375,8 @@ func (writer *GenericDatumWriter) writeBoolean(v interface{}, enc Encoder) error
switch value := v.(type) {
case bool:
enc.WriteBoolean(value)
case *bool:
enc.WriteBoolean(*value)
default:
return fmt.Errorf("%v is not a boolean", v)
}
Expand All @@ -381,6 +388,8 @@ func (writer *GenericDatumWriter) writeInt(v interface{}, enc Encoder) error {
switch value := v.(type) {
case int32:
enc.WriteInt(value)
case *int32:
enc.WriteInt(*value)
default:
return fmt.Errorf("%v is not an int32", v)
}
Expand All @@ -392,6 +401,8 @@ func (writer *GenericDatumWriter) writeLong(v interface{}, enc Encoder) error {
switch value := v.(type) {
case int64:
enc.WriteLong(value)
case *int64:
enc.WriteLong(*value)
default:
return fmt.Errorf("%v is not an int64", v)
}
Expand All @@ -403,6 +414,8 @@ func (writer *GenericDatumWriter) writeFloat(v interface{}, enc Encoder) error {
switch value := v.(type) {
case float32:
enc.WriteFloat(value)
case *float32:
enc.WriteFloat(*value)
default:
return fmt.Errorf("%v is not a float32", v)
}
Expand All @@ -414,6 +427,8 @@ func (writer *GenericDatumWriter) writeDouble(v interface{}, enc Encoder) error
switch value := v.(type) {
case float64:
enc.WriteDouble(value)
case *float64:
enc.WriteDouble(*value)
default:
return fmt.Errorf("%v is not a float64", v)
}
Expand All @@ -425,6 +440,8 @@ func (writer *GenericDatumWriter) writeBytes(v interface{}, enc Encoder) error {
switch value := v.(type) {
case []byte:
enc.WriteBytes(value)
case *[]byte:
enc.WriteBytes(*value)
default:
return fmt.Errorf("%v is not a []byte", v)
}
Expand All @@ -436,6 +453,8 @@ func (writer *GenericDatumWriter) writeString(v interface{}, enc Encoder) error
switch value := v.(type) {
case string:
enc.WriteString(value)
case *string:
enc.WriteString(*value)
default:
return fmt.Errorf("%v is not a string", v)
}
Expand All @@ -445,6 +464,9 @@ func (writer *GenericDatumWriter) writeString(v interface{}, enc Encoder) error

func (writer *GenericDatumWriter) writeArray(v interface{}, enc Encoder, s Schema) error {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array {
return errors.New("Not a slice or array type")
}
Expand All @@ -469,6 +491,9 @@ func (writer *GenericDatumWriter) writeArray(v interface{}, enc Encoder, s Schem

func (writer *GenericDatumWriter) writeMap(v interface{}, enc Encoder, s Schema) error {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
if rv.Kind() != reflect.Map {
return errors.New("Not a map type")
}
Expand Down Expand Up @@ -584,9 +609,13 @@ func (writer *GenericDatumWriter) writeFixed(v interface{}, enc Encoder, s Schem
if !fs.Validate(reflect.ValueOf(v)) {
return fmt.Errorf("Invalid fixed value: %v", v)
}
switch value := v.(type) {
case []byte:
enc.WriteRaw(value)
case *[]byte:
enc.WriteRaw(*value)
}

// Write the raw bytes. The length is known by the schema
enc.WriteRaw(v.([]byte))
return nil
}

Expand Down
4 changes: 3 additions & 1 deletion generic_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ limitations under the License. */

package avro

import "encoding/json"
import (
"encoding/json"
)

// AvroRecord is an interface for anything that has an Avro schema and can be serialized/deserialized by this library.
type AvroRecord interface {
Expand Down
Loading