decoder.lua 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. --
  2. --------------------------------------------------------------------------------
  3. -- FILE: decoder.lua
  4. -- DESCRIPTION: protoc-gen-lua
  5. -- Google's Protocol Buffers project, ported to lua.
  6. -- https://code.google.com/p/protoc-gen-lua/
  7. --
  -- Copyright (c) 2010, 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
  9. -- All rights reserved.
  10. --
  11. -- Use, modification and distribution are subject to the "New BSD License"
  12. -- as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.
  13. --
  14. -- COMPANY: NetEase
  -- CREATED: 2010-07-29 19:30:51 CST
  16. --------------------------------------------------------------------------------
  17. --
  18. local string = string
  19. local table = table
  20. local assert = assert
  21. local ipairs = ipairs
  22. local error = error
  23. local print = print
  24. local pb = require "pb"
  25. local encoder = require "protobuf.encoder"
  26. local wire_format = require "protobuf.wire_format"
  27. module "protobuf.decoder"
  -- Varint decoders provided by the native `pb` module. The 32-bit names are
  -- bound to the same pb functions as the general-purpose ones; the 64-bit
  -- names use pb's dedicated *_decoder64 functions.
  local _DecodeVarint, _DecodeSignedVarint =
    pb.varint_decoder, pb.signed_varint_decoder
  local _DecodeVarint32, _DecodeSignedVarint32 =
    pb.varint_decoder, pb.signed_varint_decoder
  local _DecodeVarint64, _DecodeSignedVarint64 =
    pb.varint_decoder64, pb.signed_varint_decoder64

  -- Exported on the module table: alias of the native pb.read_tag.
  ReadTag = pb.read_tag
  --- Build a decoder factory for a scalar field type.
  -- Buffer positions are byte offsets such that the next unread byte is
  -- sub(buffer, pos + 1, pos + 1).
  -- @param wire_type wire type used to build the expected tag of repeated,
  --   non-packed elements.
  -- @param decode_value function(buffer, pos) -> value, new_pos that reads
  --   one raw value.
  -- @return function(field_number, is_repeated, is_packed, key, new_default)
  --   returning a field decoder
  --   function(buffer, pos, pend, message, field_dict) -> new_pos.
  --   new_default(message) creates the container stored in field_dict[key]
  --   the first time a repeated field is seen.
  local function _SimpleDecoder(wire_type, decode_value)
    return function(field_number, is_repeated, is_packed, key, new_default)
      if is_packed then
        local DecodeVarint = _DecodeVarint
        return function(buffer, pos, pend, message, field_dict)
          local value = field_dict[key]
          if value == nil then
            value = new_default(message)
            field_dict[key] = value
          end
          -- Packed payload: a varint byte length, then the concatenated
          -- element encodings.
          local endpoint
          endpoint, pos = DecodeVarint(buffer, pos)
          endpoint = endpoint + pos
          if endpoint > pend then
            error('Truncated message.')
          end
          local element
          while pos < endpoint do
            element, pos = decode_value(buffer, pos)
            -- Go through the container's own append, as the non-packed
            -- branch does, instead of a raw value[#value + 1] store that
            -- bypasses whatever bookkeeping the container performs.
            value:append(element)
          end
          if pos > endpoint then
            -- The last element overran the packed payload: discard it.
            value:remove(#value)
            error('Packed element was truncated.')
          end
          return pos
        end
      elseif is_repeated then
        local tag_bytes = encoder.TagBytes(field_number, wire_type)
        local tag_len = #tag_bytes
        local sub = string.sub
        return function(buffer, pos, pend, message, field_dict)
          local value = field_dict[key]
          if value == nil then
            value = new_default(message)
            field_dict[key] = value
          end
          while true do
            local element, new_pos = decode_value(buffer, pos)
            value:append(element)
            -- Predict that the next tag is another copy of this field; if
            -- so, keep decoding here without returning to the caller.
            pos = new_pos + tag_len
            if sub(buffer, new_pos + 1, pos) ~= tag_bytes or new_pos >= pend then
              if new_pos > pend then
                error('Truncated message.')
              end
              return new_pos
            end
          end
        end
      else
        return function(buffer, pos, pend, message, field_dict)
          field_dict[key], pos = decode_value(buffer, pos)
          if pos > pend then
            -- Do not leave a value decoded from past the end in place.
            field_dict[key] = nil
            error('Truncated message.')
          end
          return pos
        end
      end
    end
  end
  --- Like _SimpleDecoder, but each decoded value is passed through
  -- modify_value before it is stored.
  -- @param wire_type wire type forwarded to _SimpleDecoder.
  -- @param decode_value function(buffer, pos) -> raw, new_pos.
  -- @param modify_value function(raw) -> value (only its first result is kept).
  -- @return a decoder factory with the same shape as _SimpleDecoder's.
  local function _ModifiedDecoder(wire_type, decode_value, modify_value)
    local function decode_and_modify(buffer, pos)
      local raw, next_pos = decode_value(buffer, pos)
      return modify_value(raw), next_pos
    end
    return _SimpleDecoder(wire_type, decode_and_modify)
  end
  --- Build a decoder factory for a fixed-width field type.
  -- @param wire_type wire type forwarded to _SimpleDecoder.
  -- @param value_size number of bytes one encoded value occupies; it is
  --   added to pos to get the offset after the value.
  -- @param format byte value of a pb.struct_unpack format character
  --   (callers pass e.g. string.byte('I')).
  -- @return a decoder factory with the same shape as _SimpleDecoder's.
  -- No bounds check happens here: the decoders built by _SimpleDecoder
  -- compare the returned position against pend.
  local function _StructPackDecoder(wire_type, value_size, format)
    local struct_unpack = pb.struct_unpack
    -- Must be local: a bare `function InnerDecode` would become a field of
    -- the module table (module "protobuf.decoder") and be rebound on every
    -- call to this factory.
    local function InnerDecode(buffer, pos)
      local result = struct_unpack(format, buffer, pos)
      return result, pos + value_size
    end
    return _SimpleDecoder(wire_type, InnerDecode)
  end
  --- Map a decoded varint to a boolean: 0 is false, anything else is true.
  local function _Boolean(value)
    if value == 0 then
      return false
    end
    return true
  end
  -- Field-decoder factories for the scalar protobuf types, exported on the
  -- module table. Each is called as
  -- Decoder(field_number, is_repeated, is_packed, key, new_default).
  do
    local VARINT = wire_format.WIRETYPE_VARINT
    local FIXED32 = wire_format.WIRETYPE_FIXED32
    local FIXED64 = wire_format.WIRETYPE_FIXED64
    local fmt = string.byte

    -- Varint-encoded integers; enums decode exactly like int32.
    Int32Decoder = _SimpleDecoder(VARINT, _DecodeSignedVarint32)
    EnumDecoder = Int32Decoder
    Int64Decoder = _SimpleDecoder(VARINT, _DecodeSignedVarint64)
    UInt32Decoder = _SimpleDecoder(VARINT, _DecodeVarint32)
    UInt64Decoder = _SimpleDecoder(VARINT, _DecodeVarint64)

    -- sint32/sint64: unsigned varints post-processed by wire_format's
    -- ZigZagDecode32/ZigZagDecode64.
    SInt32Decoder = _ModifiedDecoder(VARINT, _DecodeVarint32, wire_format.ZigZagDecode32)
    SInt64Decoder = _ModifiedDecoder(VARINT, _DecodeVarint64, wire_format.ZigZagDecode64)

    -- Fixed-width values: byte size plus a struct_unpack format character.
    Fixed32Decoder = _StructPackDecoder(FIXED32, 4, fmt('I'))
    Fixed64Decoder = _StructPackDecoder(FIXED64, 8, fmt('Q'))
    SFixed32Decoder = _StructPackDecoder(FIXED32, 4, fmt('i'))
    SFixed64Decoder = _StructPackDecoder(FIXED64, 8, fmt('q'))
    FloatDecoder = _StructPackDecoder(FIXED32, 4, fmt('f'))
    DoubleDecoder = _StructPackDecoder(FIXED64, 8, fmt('d'))

    -- bool: a varint where any non-zero value means true.
    BoolDecoder = _ModifiedDecoder(VARINT, _DecodeVarint, _Boolean)
  end
  --- Field-decoder factory for `string` fields.
  -- The wire value is a varint byte count followed by that many bytes, which
  -- are stored as a Lua string without further validation.
  -- Packed encoding is not allowed (asserted).
  -- @return function(buffer, pos, pend, message, field_dict) -> offset just
  --   past the last string consumed.
  function StringDecoder(field_number, is_repeated, is_packed, key, new_default)
    local DecodeVarint = _DecodeVarint
    local sub = string.sub
    assert(not is_packed)

    if not is_repeated then
      return function(buffer, pos, pend, message, field_dict)
        local length, start = DecodeVarint(buffer, pos)
        local stop = start + length
        if stop > pend then
          error('Truncated string.')
        end
        field_dict[key] = sub(buffer, start + 1, stop)
        return stop
      end
    end

    local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local tag_len = #tag_bytes
    return function(buffer, pos, pend, message, field_dict)
      local items = field_dict[key]
      if items == nil then
        items = new_default(message)
        field_dict[key] = items
      end
      local cursor = pos
      while true do
        local length, start = DecodeVarint(buffer, cursor)
        local stop = start + length
        if stop > pend then
          error('Truncated string.')
        end
        items:append(sub(buffer, start + 1, stop))
        -- Keep going only while the next bytes are this field's tag again.
        local next_tag_end = stop + tag_len
        if stop == pend or sub(buffer, stop + 1, next_tag_end) ~= tag_bytes then
          return stop
        end
        cursor = next_tag_end
      end
    end
  end
  --- Field-decoder factory for `bytes` fields.
  -- Identical wire handling to strings: a varint byte count followed by the
  -- raw bytes, stored as a Lua string. Packed encoding is not allowed.
  -- @return function(buffer, pos, pend, message, field_dict) -> offset just
  --   past the last value consumed.
  function BytesDecoder(field_number, is_repeated, is_packed, key, new_default)
    local DecodeVarint = _DecodeVarint
    local sub = string.sub
    assert(not is_packed)

    if is_repeated then
      local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local tag_len = #tag_bytes

      local function decode_repeated(buffer, pos, pend, message, field_dict)
        local list = field_dict[key]
        if list == nil then
          list = new_default(message)
          field_dict[key] = list
        end
        while true do
          local byte_count, data_start = DecodeVarint(buffer, pos)
          local data_end = data_start + byte_count
          if data_end > pend then
            error('Truncated string.')
          end
          list:append(sub(buffer, data_start + 1, data_end))
          -- Stop at the end of the buffer, or when the following bytes are
          -- not another tag for this same field.
          pos = data_end + tag_len
          if data_end == pend or sub(buffer, data_end + 1, pos) ~= tag_bytes then
            return data_end
          end
        end
      end

      return decode_repeated
    end

    local function decode_single(buffer, pos, pend, message, field_dict)
      local byte_count, data_start = DecodeVarint(buffer, pos)
      local data_end = data_start + byte_count
      if data_end > pend then
        error('Truncated string.')
      end
      field_dict[key] = sub(buffer, data_start + 1, data_end)
      return data_end
    end

    return decode_single
  end
  --- Field-decoder factory for embedded message fields.
  -- The wire value is a varint byte count followed by the serialized
  -- sub-message, which is parsed in place via
  -- _InternalParse(buffer, start, stop). Packed encoding is not allowed.
  -- If _InternalParse returns any offset other than `stop`, the decoder
  -- raises 'Unexpected end-group tag.'.
  -- @return function(buffer, pos, pend, message, field_dict) -> offset just
  --   past the last sub-message consumed.
  function MessageDecoder(field_number, is_repeated, is_packed, key, new_default)
    local DecodeVarint = _DecodeVarint
    local sub = string.sub
    assert(not is_packed)

    -- Read the length prefix at pos; return the sub-message's [start, stop]
    -- offsets, raising if it extends beyond pend.
    local function sub_message_bounds(buffer, pos, pend)
      local size, start = DecodeVarint(buffer, pos)
      local stop = start + size
      if stop > pend then
        error('Truncated message.')
      end
      return start, stop
    end

    if not is_repeated then
      return function(buffer, pos, pend, message, field_dict)
        local sub_message = field_dict[key]
        if sub_message == nil then
          sub_message = new_default(message)
          field_dict[key] = sub_message
        end
        local start, stop = sub_message_bounds(buffer, pos, pend)
        if sub_message:_InternalParse(buffer, start, stop) ~= stop then
          error('Unexpected end-group tag.')
        end
        return stop
      end
    end

    local tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local tag_len = #tag_bytes
    return function(buffer, pos, pend, message, field_dict)
      local list = field_dict[key]
      if list == nil then
        list = new_default(message)
        field_dict[key] = list
      end
      local cursor = pos
      while true do
        local start, stop = sub_message_bounds(buffer, cursor, pend)
        if list:add():_InternalParse(buffer, start, stop) ~= stop then
          error('Unexpected end-group tag.')
        end
        -- Continue only while the next bytes are another tag for this field.
        cursor = stop + tag_len
        if stop == pend or sub(buffer, stop + 1, cursor) ~= tag_bytes then
          return stop
        end
      end
    end
  end
  --- Skip over a varint-encoded value.
  -- @param buffer encoded data.
  -- @param pos offset of the varint.
  -- @param pend end offset of the enclosing message.
  -- @return offset just past the varint.
  -- Raises 'Truncated message.' if the varint ends beyond pend, matching the
  -- other _Skip* helpers (previously the overrun position was returned
  -- unchecked).
  function _SkipVarint(buffer, pos, pend)
    local _, new_pos = _DecodeVarint(buffer, pos)
    if new_pos > pend then
      error('Truncated message.')
    end
    return new_pos
  end
  --- Skip over an 8-byte fixed-width value.
  -- @return pos + 8; raises 'Truncated message.' if that is beyond pend.
  function _SkipFixed64(buffer, pos, pend)
    local new_pos = pos + 8
    if new_pos <= pend then
      return new_pos
    end
    error('Truncated message.')
  end
  --- Skip over a length-delimited value (varint byte count plus payload).
  -- @return offset just past the payload; raises 'Truncated message.' if the
  --   payload extends beyond pend.
  function _SkipLengthDelimited(buffer, pos, pend)
    local length, body_start = _DecodeVarint(buffer, pos)
    local body_end = body_start + length
    if body_end > pend then
      error('Truncated message.')
    end
    return body_end
  end
  --- Skip over a 4-byte fixed-width value.
  -- @return pos + 4; raises 'Truncated message.' if that is beyond pend.
  function _SkipFixed32(buffer, pos, pend)
    local new_pos = pos + 4
    if new_pos <= pend then
      return new_pos
    end
    error('Truncated message.')
  end
  --- Skipper installed for wire types 6 and 7, which have no defined
  -- encoding; always raises 'Tag had invalid wire type.'.
  _RaiseInvalidWireType = function(buffer, pos, pend)
    error('Tag had invalid wire type.')
  end
  --- Build SkipField(buffer, pos, pend, tag_bytes) -> new_pos, which skips one
  -- field based on the wire type found in its tag.
  -- The skipper table is indexed by wire_type + 1 (Lua arrays are 1-based).
  function _FieldSkipper()
    -- _SkipGroup/_EndGroup are not defined before this runs. Previously their
    -- slots were nil, so a group tag failed with "attempt to call a nil
    -- value". Use them if present, otherwise raise a clear decode error.
    local function group_not_supported(buffer, pos, pend)
      error('Group wire types are not supported.')
    end

    local skippers = {
      _SkipVarint,                        -- 0: varint
      _SkipFixed64,                       -- 1: 64-bit
      _SkipLengthDelimited,               -- 2: length-delimited
      _SkipGroup or group_not_supported,  -- 3: start group
      _EndGroup or group_not_supported,   -- 4: end group
      _SkipFixed32,                       -- 5: 32-bit
      _RaiseInvalidWireType,              -- 6: undefined
      _RaiseInvalidWireType,              -- 7: undefined
    }
    -- Still published on the module table, as before, for any external readers.
    WIRETYPE_TO_SKIPPER = skippers

    local ord = string.byte
    return function(buffer, pos, pend, tag_bytes)
      -- The low 3 bits of the tag's first byte hold the wire type.
      local wire_type = ord(tag_bytes, 1) % 8 + 1
      return skippers[wire_type](buffer, pos, pend)
    end
  end
  SkipField = _FieldSkipper()