Skip to content

Commit c18ce7b

Browse files
committed
fix varint encode bug, add buffer pool.
1 parent ccc8f84 commit c18ce7b

File tree

2 files changed

+141
-55
lines changed

2 files changed

+141
-55
lines changed

pb.c

+12-7
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,11 @@ static void pb_prepbuffer(pb_Buffer *buff, size_t need) {
270270

271271
static void pb_addvarint(pb_Buffer *buff, uint64_t n) {
272272
pb_prepbuffer(buff, 10);
273-
while (n != 0) {
273+
do {
274274
int cur = n & 0x7F;
275275
n >>= 7;
276276
pb_addchar(buff, n != 0 ? cur | 0x80 : cur);
277-
}
277+
} while (n != 0);
278278
}
279279

280280
static void pb_addfixed32(pb_Buffer *buff, uint32_t n) {
@@ -836,11 +836,16 @@ static int Ldec_finished(lua_State *L) {
836836
static int Ldec_tag(lua_State *L) {
837837
pb_Decoder *dec = (pb_Decoder*)luaL_checkudata(L, 1, PB_DECODER);
838838
uint64_t n = 0;
839+
int wiretype;
839840
if (!pb_readvarint(dec, &n)) return 0;
841+
wiretype = (int)(n & 0x7);
840842
lua_pushinteger(L, (lua_Integer)(n >> 3));
841-
lua_pushinteger(L, (lua_Integer)(n & 0x7));
842-
lua_pushstring(L, pb_wiretypes[n & 0x7]);
843-
return 3;
843+
lua_pushinteger(L, (lua_Integer)wiretype);
844+
if (wiretype >= 0 && wiretype < PB_TWCOUNT) {
845+
lua_pushstring(L, pb_wiretypes[wiretype]);
846+
return 3;
847+
}
848+
return 2;
844849
}
845850

846851
static int Ldec_varint(lua_State *L) {
@@ -935,8 +940,8 @@ static int skipvalue(pb_FBDecoder *dec, int wiretype) {
935940

936941
static int Ldec_fetch(lua_State *L) {
937942
pb_FBDecoder dec = check_fbdecoder(L, 1);
938-
int wiretype, extra = get_wiretype(L, dec.dec, 2, &wiretype);
939943
int type = find_type(luaL_optstring(L, 3, NULL));
944+
int wiretype, extra = get_wiretype(L, dec.dec, 2, &wiretype);
940945
if (extra >= 0 && pb_pushscalar(&dec, wiretype, type))
941946
return extra + 1;
942947
restore_decoder(&dec);
@@ -1079,6 +1084,6 @@ LUALIB_API int luaopen_pb_io(lua_State *L) {
10791084
return 1;
10801085
}
10811086

1082-
/* cc: flags+='-mdll -s -O3 -DLUA_BUILD_AS_DLL'
1087+
/* cc: flags+='-s -O3 -mdll -DLUA_BUILD_AS_DLL'
10831088
* xcc: flags+='-ID:\luajit\include' libs+='-LD:\luajit\'
10841089
* cc: output='pb.dll' libs+='-llua53' */

pb.lua

+129-48
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ local function field_type(field)
3333
return realtype
3434
end
3535

36-
local decode, decode_field, decode_unknown_field
37-
local encode
36+
local decode, encode
3837

39-
function decode_unknown_field(t, dec, wiretype, tag)
38+
local function decode_unknown_field(t, dec, wiretype, tag)
4039
local value = dec:fetch(wiretype)
4140
local uf = t.unknown_fields
4241
if not uf then
@@ -53,7 +52,7 @@ function decode_unknown_field(t, dec, wiretype, tag)
5352
end
5453
end
5554

56-
function decode_field(t, dec, wiretype, tag, field)
55+
local function decode_field(t, dec, wiretype, tag, field)
5756
local value
5857
if field.scalar then
5958
value = dec:fetch(wiretype, field.type_name)
@@ -68,7 +67,7 @@ function decode_field(t, dec, wiretype, tag, field)
6867
else
6968
local len = dec:fetch "varint"
7069
local old = dec:len(dec:pos() + len - 1)
71-
value = decode(dec, ftype)
70+
value = decode(dec, ftype, table.concat(field.type_name, "."))
7271
dec:len(old)
7372
end
7473
end
@@ -83,8 +82,9 @@ function decode_field(t, dec, wiretype, tag, field)
8382
end
8483
end
8584

86-
function decode(dec, ptype)
85+
function decode(dec, ptype, tn)
8786
local t = {}
87+
local pos = dec:pos()
8888
while not dec:finished() do
8989
local tag, wiretype = dec:tag()
9090
local field = ptype[tag]
@@ -94,52 +94,131 @@ function decode(dec, ptype)
9494
decode_unknown_field(t, dec, wiretype, tag)
9595
end
9696
end
97+
local size = dec:pos() - pos
9798
return t
9899
end
99100

101+
local buffer_pool = {}
102+
local buffer_used = setmetatable({}, { __mode="k" })
103+
104+
105+
--[[
106+
function buffer.new()
107+
local t = {}
108+
function t:add(tag, type, value)
109+
t[#t+1] = ("[%s %d %s]\n"):format(type, tag, tostring(value))
110+
end
111+
function t:tag(tag, wiretype)
112+
t[#t+1] = ("[%s %d "):format(wiretype, tag)
113+
end
114+
function t:varint(n)
115+
t[#t+1] = ("%d]\n"):format(n)
116+
end
117+
function t:bytes(s)
118+
if type(s) == "table" then
119+
s = table.concat(s)
120+
t[#t+1] = ("\n"..s.."]"):gsub("\n", "\n "):gsub(" ]%s*$", "]\n")
121+
else
122+
t[#t+1] = ("'%s'\n]\n"):format(s)
123+
end
124+
end
125+
function t:clear(len, result)
126+
if result then
127+
result = table.concat(t)
128+
end
129+
for k, v in ipairs(t) do
130+
t[k] = nil
131+
end
132+
return result
133+
end
134+
return t
135+
end
136+
--]]
137+
138+
local function get_buffer()
139+
local buff = next(buffer_pool)
140+
if buff then
141+
buffer_pool[buff] = nil
142+
else
143+
buff = buffer.new()
144+
end
145+
buffer_used[buff] = true
146+
return buff
147+
end
148+
149+
local function put_buffer(buff)
150+
buffer_used[buff] = nil
151+
buffer_pool[buff] = true
152+
end
153+
154+
local function encode_message(buff, tag, msg, ftype)
155+
local inner = get_buffer()
156+
inner:clear()
157+
encode(inner, msg, ftype)
158+
buff:tag(tag, "bytes")
159+
buff:bytes(inner)
160+
inner:clear()
161+
put_buffer(inner)
162+
end
163+
164+
local function encode_enum(buff, tag, enum, ftype)
165+
local value = assert(ftype.map[enum])
166+
buff:tag(tag, "varint")
167+
buff:varint(value)
168+
end
169+
170+
local function encode_field(buff, tag, v, ptype)
171+
--print(("encode_field(%d, %s)"):format(tag,
172+
--require"serpent".block(v)))
173+
local field = ptype[tag]
174+
if not field then return end
175+
176+
if field.scalar then
177+
-- TODO packed repeated
178+
if field.repeated then
179+
for k,v in ipairs(v) do
180+
buff:add(tag, field.type_name, v)
181+
end
182+
else
183+
buff:add(tag, field.type_name, v)
184+
end
185+
return
186+
end
187+
188+
local ftype = field_type(field)
189+
if ftype.type == "message" then
190+
if not inner_buff then
191+
inner_buff = buffer.new()
192+
end
193+
if not field.repeated then
194+
encode_message(buff, tag, v, ftype)
195+
else
196+
for _, v in ipairs(v) do
197+
encode_message(buff, tag, v, ftype)
198+
end
199+
end
200+
return
201+
end
202+
203+
if ftype.type == "enum" then
204+
if not field.repeated then
205+
encode_enum(buff, tag, v, ftype)
206+
else
207+
for _, v in ipairs(v) do
208+
encode_enum(buff, tag, v, ftype)
209+
end
210+
end
211+
return
212+
end
213+
214+
error("unknown type: "..ftype.type)
215+
end
216+
100217
function encode(buff, t, ptype)
101-
lvl = lvl or 1
102-
local lvls = (" "):rep(lvl)
103-
local inner_buff
104218
for k,v in pairs(t) do
105219
local tag = ptype.map[k]
106-
local field = ptype[tag]
107-
if field then
108-
if field.scalar then
109-
if field.repeated then
110-
for k,v in ipairs(v) do
111-
buff:add(tag, field.type_name, v)
112-
end
113-
else
114-
buff:add(tag, field.type_name, v)
115-
end
116-
else
117-
local ftype = field_type(field)
118-
if ftype.type == "message" then
119-
if not inner_buff then
120-
inner_buff = buffer.new()
121-
end
122-
if field.repeated then
123-
for k,v in ipairs(v) do
124-
inner_buff:clear()
125-
encode(inner_buff, v, ftype)
126-
buff:tag(tag, "bytes")
127-
buff:bytes(inner_buff)
128-
end
129-
else
130-
inner_buff:clear()
131-
encode(inner_buff, v, ftype)
132-
buff:tag(tag, "bytes")
133-
buff:bytes(inner_buff)
134-
end
135-
elseif ftype.type == "enum" then
136-
local value = ftype.map[v]
137-
if value then
138-
buff:tag(tag, "varint")
139-
buff:varint(value)
140-
end
141-
end
142-
end
220+
if tag then
221+
encode_field(buff, tag, v, ptype)
143222
end
144223
end
145224
end
@@ -153,12 +232,14 @@ function pb.decode(s, ptype)
153232
return res
154233
end
155234

156-
local buff = buffer.new()
157235
function pb.encode(t, ptype)
158236
local realtype = qualitied_type(ptype)
237+
local buff = get_buffer()
159238
buff:clear()
160239
encode(buff, t, realtype)
161-
return buff:clear(nil, true)
240+
local res = buff:clear(nil, true)
241+
put_buffer(buff)
242+
return res
162243
end
163244

164245
------------------------------------------------------------

0 commit comments

Comments
 (0)