From d579c53104adace8275889d2b03c4432a79c4319 Mon Sep 17 00:00:00 2001 From: John Howard Date: Thu, 18 Jan 2024 09:14:33 -0800 Subject: [PATCH] Fix marshalling of empty oneOf messages Fixes https://github.com/planetscale/vtprotobuf/issues/61 --- .../internal/conformance/oneof_test.go | 16 ++++++ .../test_messages_proto2_vtproto.pb.go | 14 ++++++ .../test_messages_proto3_vtproto.pb.go | 14 ++++++ features/marshal/marshalto.go | 13 +++-- features/size/size.go | 12 +++-- testproto/pool/pool_with_oneof_vtproto.pb.go | 30 +++++++++++ testproto/unsafe/unsafe_vtproto.pb.go | 50 +++++++++++++++++++ types/known/structpb/struct_vtproto.pb.go | 20 ++++++++ 8 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 conformance/internal/conformance/oneof_test.go diff --git a/conformance/internal/conformance/oneof_test.go b/conformance/internal/conformance/oneof_test.go new file mode 100644 index 00000000..af309657 --- /dev/null +++ b/conformance/internal/conformance/oneof_test.go @@ -0,0 +1,16 @@ +package conformance + +import ( + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestEmptyOneoff(t *testing.T) { + // Regression test for https://github.com/planetscale/vtprotobuf/issues/61 + msg := &TestAllTypesProto3{OneofField: &TestAllTypesProto3_OneofNestedMessage{}} + upstream, _ := proto.Marshal(msg) + vt, _ := msg.MarshalVTStrict() + require.Equal(t, upstream, vt) +} diff --git a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go index afc122f9..c8eb7a3a 100644 --- a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go +++ b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go @@ -3918,6 +3918,12 @@ func (m *TestAllTypesProto2_OneofNestedMessage) MarshalToSizedBufferVT(dAtA []by dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -6038,6 +6044,12 @@ func (m *TestAllTypesProto2_OneofNestedMessage) MarshalToSizedBufferVTStrict(dAt dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -7094,6 +7106,8 @@ func (m *TestAllTypesProto2_OneofNestedMessage) SizeVT() (n int) { if m.OneofNestedMessage != nil { l = m.OneofNestedMessage.SizeVT() n += 2 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go b/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go index 7de867ea..aca5008c 100644 --- a/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go +++ b/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go @@ -4113,6 +4113,12 @@ func (m *TestAllTypesProto3_OneofNestedMessage) MarshalToSizedBufferVT(dAtA []by dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -6317,6 +6323,12 @@ func (m *TestAllTypesProto3_OneofNestedMessage) MarshalToSizedBufferVTStrict(dAt dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -7286,6 +7298,8 @@ func (m *TestAllTypesProto3_OneofNestedMessage) SizeVT() (n int) { if m.OneofNestedMessage != nil { l = m.OneofNestedMessage.SizeVT() n += 2 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/features/marshal/marshalto.go b/features/marshal/marshalto.go index 9f323ced..b65f9b5d 100644 --- a/features/marshal/marshalto.go +++ b/features/marshal/marshalto.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/planetscale/vtprotobuf/generator" - "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/reflect/protoreflect" @@ -520,7 +519,14 @@ func (p *marshal) field(oneof bool, numGen *counter, field *protogen.Field) { default: panic("not implemented") } - if repeated || nullable { + // Empty protobufs should emit a message or compatibility with Golang protobuf; + // See https://github.com/planetscale/vtprotobuf/issues/61 + if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() { + p.P("} else {") + p.P("i = protohelpers.EncodeVarint(dAtA, i, 0)") + p.encodeKey(fieldNumber, wireType) + p.P("}") + } else if repeated || nullable { p.P(`}`) } } @@ -676,7 +682,7 @@ func (p *marshal) message(message *protogen.Message) { p.P(`}`) p.P() - //Generate MarshalToVT methods for oneof fields + // Generate MarshalToVT methods for oneof fields for _, field := range message.Fields { if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() { continue @@ -709,7 +715,6 @@ func (p *marshal) marshalBackwardSize(varInt bool) { if varInt { p.encodeVarint(`size`) } - } func (p *marshal) marshalBackward(varName string, varInt bool, message *protogen.Message) { diff --git a/features/size/size.go b/features/size/size.go index 61bf3da1..54d856a4 100644 --- a/features/size/size.go +++ b/features/size/size.go @@ -8,11 +8,10 @@ package size import ( "strconv" + "github.com/planetscale/vtprotobuf/generator" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/reflect/protoreflect" - - "github.com/planetscale/vtprotobuf/generator" ) func init() { @@ -266,7 +265,12 @@ func (p *size) field(oneof bool, field *protogen.Field, sizeName string) { default: panic("not implemented") } - if repeated || nullable { + // Empty protobufs should emit a message or compatibility with Golang protobuf; + // See https://github.com/planetscale/vtprotobuf/issues/61 + // Size is always 3 so just hardcode that here + if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() { + p.P("} else { n += 3 }") + } else if repeated || nullable { p.P(`}`) } } @@ -310,8 +314,6 @@ func (p *size) message(message *protogen.Message) { } p.P(`}`) } else { - //if _, ok := oneofs[fieldname]; !ok { - //oneofs[fieldname] = struct{}{} p.P(`if vtmsg, ok := m.`, fieldname, `.(interface{ SizeVT() int }); ok {`) p.P(`n+=vtmsg.`, sizeName, `()`) p.P(`}`) diff --git a/testproto/pool/pool_with_oneof_vtproto.pb.go b/testproto/pool/pool_with_oneof_vtproto.pb.go index 1f95bddd..d781b2a7 100644 --- a/testproto/pool/pool_with_oneof_vtproto.pb.go +++ b/testproto/pool/pool_with_oneof_vtproto.pb.go @@ -571,6 +571,10 @@ func (m *OneofTest_Test1_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -590,6 +594,10 @@ func (m *OneofTest_Test2_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -609,6 +617,10 @@ func (m *OneofTest_Test3_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -864,6 +876,10 @@ func (m *OneofTest_Test1_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -883,6 +899,10 @@ func (m *OneofTest_Test2_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -902,6 +922,10 @@ func (m *OneofTest_Test3_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -1114,6 +1138,8 @@ func (m *OneofTest_Test1_) SizeVT() (n int) { if m.Test1 != nil { l = m.Test1.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1126,6 +1152,8 @@ func (m *OneofTest_Test2_) SizeVT() (n int) { if m.Test2 != nil { l = m.Test2.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1138,6 +1166,8 @@ func (m *OneofTest_Test3_) SizeVT() (n int) { if m.Test3 != nil { l = m.Test3.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/testproto/unsafe/unsafe_vtproto.pb.go b/testproto/unsafe/unsafe_vtproto.pb.go index d03d22e5..835f096b 100644 --- a/testproto/unsafe/unsafe_vtproto.pb.go +++ b/testproto/unsafe/unsafe_vtproto.pb.go @@ -880,6 +880,10 @@ func (m *UnsafeTest_Sub1_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -899,6 +903,10 @@ func (m *UnsafeTest_Sub2_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -918,6 +926,10 @@ func (m *UnsafeTest_Sub3_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -937,6 +949,10 @@ func (m *UnsafeTest_Sub4_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x22 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x22 } return len(dAtA) - i, nil } @@ -956,6 +972,10 @@ func (m *UnsafeTest_Sub5_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x2a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x2a } return len(dAtA) - i, nil } @@ -1320,6 +1340,10 @@ func (m *UnsafeTest_Sub1_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -1339,6 +1363,10 @@ func (m *UnsafeTest_Sub2_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -1358,6 +1386,10 @@ func (m *UnsafeTest_Sub3_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -1377,6 +1409,10 @@ func (m *UnsafeTest_Sub4_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x22 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x22 } return len(dAtA) - i, nil } @@ -1396,6 +1432,10 @@ func (m *UnsafeTest_Sub5_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x2a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x2a } return len(dAtA) - i, nil } @@ -1531,6 +1571,8 @@ func (m *UnsafeTest_Sub1_) SizeVT() (n int) { if m.Sub1 != nil { l = m.Sub1.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1543,6 +1585,8 @@ func (m *UnsafeTest_Sub2_) SizeVT() (n int) { if m.Sub2 != nil { l = m.Sub2.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1555,6 +1599,8 @@ func (m *UnsafeTest_Sub3_) SizeVT() (n int) { if m.Sub3 != nil { l = m.Sub3.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1567,6 +1613,8 @@ func (m *UnsafeTest_Sub4_) SizeVT() (n int) { if m.Sub4 != nil { l = m.Sub4.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1579,6 +1627,8 @@ func (m *UnsafeTest_Sub5_) SizeVT() (n int) { if m.Sub5 != nil { l = m.Sub5.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/types/known/structpb/struct_vtproto.pb.go b/types/known/structpb/struct_vtproto.pb.go index 835849e1..be8b40e3 100644 --- a/types/known/structpb/struct_vtproto.pb.go +++ b/types/known/structpb/struct_vtproto.pb.go @@ -569,6 +569,10 @@ func (m *Value_StructValue) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x2a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x2a } return len(dAtA) - i, nil } @@ -588,6 +592,10 @@ func (m *Value_ListValue) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x32 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x32 } return len(dAtA) - i, nil } @@ -832,6 +840,10 @@ func (m *Value_StructValue) MarshalToSizedBufferVTStrict(dAtA []byte) (int, erro i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x2a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x2a } return len(dAtA) - i, nil } @@ -851,6 +863,10 @@ func (m *Value_ListValue) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error) i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x32 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x32 } return len(dAtA) - i, nil } @@ -986,6 +1002,8 @@ func (m *Value_StructValue) SizeVT() (n int) { if m.StructValue != nil { l = (*Struct)(m.StructValue).SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -998,6 +1016,8 @@ func (m *Value_ListValue) SizeVT() (n int) { if m.ListValue != nil { l = (*ListValue)(m.ListValue).SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n }