msg.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. package rpc
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "reflect"
  8. "rocommon"
  9. "rocommon/util"
  10. )
  // Sizes, in bytes, of the fixed packet header fields. Each field is written
// big-endian, in this order, and the header is followed by the body.
const (
	lenMaxLen  = 2 // body length, uint16
	msgIdLen   = 2 // message id, uint16
	msgSeqlen  = 4 // sequence number, uint32; used for reconnecting after a disconnect
	msgFlaglen = 2 // crypt flag, uint16: 1 = RSA, 2 = AES
)

// Message ids that the framing layer treats specially. They are untyped so
// they compare directly against both int and uint16 ids.
//
// NOTE(review): SC_HAND_SHAKE_ACKMsgId and CS_HAND_SHAKE_REQMsgId are both 0,
// which looks like a placeholder; code that dispatches on them cannot tell
// the two apart.
const (
	SC_HAND_SHAKE_NTFMsgId = 1006
	SC_HAND_SHAKE_ACKMsgId = 0
	CS_HAND_SHAKE_REQMsgId = 0
	SC_PING_ACKMsgId       = 1001
)
  21. //var SC_HAND_SHAKE_NTFMsgId = MessageInfoByName("SCHandShakeNtf").ID
  22. ///////////////////////
  // ReadMessage reads one framed packet from reader, decrypts its body
// according to the packet's crypt flag, and decodes it into the message type
// registered for its id.
//
// Crypt flag handling:
//   - 1: RSA; the private key is selected by message id (see below).
//   - 2: AES-CTR with *aesKey, which must be non-nil and non-empty.
//   - anything else: the body is used as-is (per the original note,
//     server-internal traffic is not encrypted).
//
// msgSeqId is the sequence number from the packet header; it is 0 whenever
// err is non-nil. Decode failures are wrapped with %w so callers can still
// inspect the underlying codec error.
func ReadMessage(reader io.Reader, maxMsgLen int, aesKey *[]byte) (msg interface{}, msgSeqId uint32, err error) {
	var (
		msgId, flagId uint16
		msgData       []byte
	)
	msgId, msgSeqId, flagId, msgData, err = RecvPackageData(reader, maxMsgLen)
	if err != nil {
		return nil, 0, err
	}

	switch flagId {
	case 1:
		// The branches are tested in order. SC_HAND_SHAKE_ACKMsgId and
		// CS_HAND_SHAKE_REQMsgId currently share id 0, so id 0 always takes the
		// CS_HAND_SHAKE_REQ branch and the SC_HAND_SHAKE_ACK branch is
		// unreachable until the ids diverge. The order is kept on purpose.
		id := int(msgId)
		switch {
		case id == SC_HAND_SHAKE_NTFMsgId:
			msgData, err = RSADecrypt(msgData, PrivateClientKey)
		case id == CS_HAND_SHAKE_REQMsgId:
			msgData, err = RSADecrypt(msgData, PrivateServerKey)
		case id == SC_HAND_SHAKE_ACKMsgId:
			msgData, err = RSADecrypt(msgData, PrivateClientKey)
		default:
			msgData, err = RSADecrypt(msgData, PrivateKey)
		}
		if err != nil {
			return nil, 0, err
		}
	case 2:
		// The flag comes from the peer, so a missing key must be an error
		// rather than a nil-pointer panic.
		if aesKey == nil || len(*aesKey) == 0 {
			return nil, 0, fmt.Errorf("aes-encrypted message %v received without an aes key", msgId)
		}
		// NOTE(review): the variadic argument appears to be the CTR IV; if so,
		// the key doubles as the IV and every message reuses one keystream.
		// Changing it needs a matching change on the sending side and on peers.
		msgData, err = AESCtrDecrypt(msgData, *aesKey, *aesKey...)
		if err != nil {
			return nil, 0, err
		}
	}

	msg, _, err = DecodeMessage(int(msgId), msgData)
	if err != nil {
		return nil, 0, fmt.Errorf("msg decodeMessage failed:%v %w", msgId, err)
	}
	return msg, msgSeqId, nil
}
  // DecodeMessage looks up the message type registered under id, allocates a
// new value of that type via reflection, and unmarshals data into it with the
// registered codec.
//
// On success it returns the decoded message (a pointer to the registered
// type) together with its registration info. An unregistered id yields the
// error "msg not register" with nil info; a codec failure returns the info
// alongside the codec's error.
func DecodeMessage(id int, data []byte) (interface{}, *rocommon.MessageInfo, error) {
	info := rocommon.MessageInfoByID(id)
	if info == nil {
		return nil, nil, errors.New("msg not register")
	}

	target := reflect.New(info.Type).Interface()
	if err := info.Codec.Unmarshal(data, target); err != nil {
		return nil, info, err
	}
	return target, info, nil
}
  // SendMessage frames msg and writes the whole frame to writer.
//
// msg is either a *rocommon.TransmitPacket, whose body, id and sequence number
// are taken from the packet, or a registered message that is serialized with
// EncodeMessage (sequence number 0).
//
// Body encryption:
//   - SC_HAND_SHAKE_NTFMsgId is RSA-encrypted with PublicClientKey (flag 1);
//   - otherwise, if aesKey points at a non-empty key, every message except
//     SC_PING_ACKMsgId is AES-CTR encrypted (flag 2);
//   - everything else is sent unencrypted (flag 0).
//
// Frame layout, big-endian:
// body length uint16 | msg id uint16 | seq uint32 | crypt flag uint16 | body.
//
// A body (after encryption) is logged via util.FatalF and dropped, with nil
// returned so the connection stays up, when it is longer than either:
//   - maxMsgLen, or
//   - what the uint16 length field can express.
//
// nodeName is currently unused.
func SendMessage(writer io.Writer, msg interface{}, aesKey *[]byte, maxMsgLen int, nodeName string) (err error) {
	const (
		headerLen  = lenMaxLen + msgIdLen + msgSeqlen + msgFlaglen
		maxBodyLen = 1<<16 - 1 // largest body the uint16 length field can describe
	)

	var (
		msgData []byte
		msgId   uint16
		seqId   uint32
	)
	switch m := msg.(type) {
	case *rocommon.TransmitPacket:
		msgData = m.MsgData
		msgId = uint16(m.MsgId)
		seqId = m.SeqId
	default:
		var msgInfo *rocommon.MessageInfo
		msgData, msgInfo, err = EncodeMessage(msg)
		if err != nil {
			return err
		}
		msgId = uint16(msgInfo.ID)
	}

	var cryptType uint16
	switch {
	case msgId == SC_HAND_SHAKE_NTFMsgId:
		// Handshake stage.
		cryptType = 1
		if msgData, err = RSAEncrypt(msgData, PublicClientKey); err != nil {
			return err
		}
	case aesKey != nil && len(*aesKey) > 0 && msgId != SC_PING_ACKMsgId:
		cryptType = 2
		if msgData, err = AESCtrEncrypt(msgData, *aesKey, *aesKey...); err != nil {
			return err
		}
	}

	// Without the maxBodyLen cap, a maxMsgLen above 65535 would let
	// uint16(msgLen) below silently truncate and corrupt the frame.
	msgLen := len(msgData)
	limit := maxMsgLen
	if limit > maxBodyLen {
		limit = maxBodyLen
	}
	if msgLen > limit {
		// Deliberate: log and drop the message instead of failing the connection.
		util.FatalF("SendMessage err=%v",
			fmt.Errorf("message too big msgId=%v msglen=%v maxlen=%v", msgId, msgLen, limit))
		return nil
	}

	data := make([]byte, headerLen+msgLen)
	binary.BigEndian.PutUint16(data, uint16(msgLen))
	binary.BigEndian.PutUint16(data[lenMaxLen:], msgId)
	// Echo the client's sequence number back (0 for locally encoded messages).
	binary.BigEndian.PutUint32(data[lenMaxLen+msgIdLen:], seqId)
	binary.BigEndian.PutUint16(data[lenMaxLen+msgIdLen+msgSeqlen:], cryptType)
	copy(data[headerLen:], msgData)

	return util.WriteFull(writer, data)
}
  // EncodeMessage serializes msg with the codec registered for its type.
//
// It returns the encoded bytes and the message's registration info:
//   - If msg's type is not registered, info is nil and err is
//     "msg not register".
//   - If the codec fails, or returns something other than []byte, data is nil
//     and err is non-nil; info is still returned.
func EncodeMessage(msg interface{}) (data []byte, info *rocommon.MessageInfo, err error) {
	info = rocommon.MessageInfoByMsg(msg)
	if info == nil {
		return nil, nil, errors.New("msg not register")
	}

	raw, err := info.Codec.Marshal(msg)
	if err != nil {
		// Check the error before asserting the type: a failing codec may return
		// a nil interface, and a non-comma-ok assertion on nil panics.
		return nil, info, err
	}
	data, ok := raw.([]byte)
	if !ok {
		return nil, info, fmt.Errorf("msg codec returned %T, want []byte", raw)
	}
	return data, info, nil
}
  197. //获取原始包数据(二进制),不做解析处理
  198. func RecvPackageData(reader io.Reader, maxMsgLen int) (msgId uint16, msgSeqId uint32, msgFlagId uint16, msgData []byte, err error) {
  199. bufMsgLen := make([]byte, lenMaxLen)
  200. _, err = io.ReadFull(reader, bufMsgLen)
  201. if err != nil {
  202. //log.Println("[ReadMessage] read message err:", err)
  203. return
  204. }
  205. if len(bufMsgLen) < lenMaxLen {
  206. //err = errors.New("message too short")
  207. return
  208. }
  209. //msgId
  210. bufIdLen := make([]byte, msgIdLen)
  211. _, err = io.ReadFull(reader, bufIdLen)
  212. if err != nil {
  213. //log.Println("[ReadMessage] read message err:", err)
  214. return
  215. }
  216. if len(bufIdLen) < msgIdLen {
  217. //err = errors.New("message too short")
  218. return
  219. }
  220. msgId = binary.BigEndian.Uint16(bufIdLen)
  221. //msgseqid
  222. bufSeqIdLen := make([]byte, msgSeqlen)
  223. _, err = io.ReadFull(reader, bufSeqIdLen)
  224. if err != nil {
  225. //log.Println("[ReadMessage] read message err:", err)
  226. return
  227. }
  228. if len(bufSeqIdLen) < msgSeqlen {
  229. //err = errors.New("message too short")
  230. return
  231. }
  232. msgSeqId = binary.BigEndian.Uint32(bufSeqIdLen)
  233. //msgFlaglen 1表示RSA,2表示AES
  234. bufFlagLen := make([]byte, msgFlaglen)
  235. _, err = io.ReadFull(reader, bufFlagLen)
  236. if err != nil {
  237. return
  238. }
  239. if len(bufFlagLen) < msgFlaglen {
  240. return
  241. }
  242. msgFlagId = binary.BigEndian.Uint16(bufFlagLen)
  243. //BigEndian
  244. msgLen := binary.BigEndian.Uint16(bufMsgLen)
  245. if msgLen > 0 && msgLen > uint16(maxMsgLen) {
  246. //err = errors.New("message too big")
  247. err = errors.New(fmt.Sprintf("message too big msgid=%v mslen=%v maxlen=%v bufMsgLen=%v msgFlagId=%v\n",
  248. msgId, msgLen, maxMsgLen, len(bufMsgLen), msgFlagId))
  249. util.FatalF("RecvPackageData err=%v", err)
  250. err = nil
  251. return
  252. }
  253. //todo 可以使用内存池
  254. if msgLen > 0 {
  255. //body := make([]byte, msgLen)
  256. //if _, err = io.ReadFull(reader, body); err != nil {
  257. // //log.Println("[ReadMessage] read message err:", err)
  258. // return
  259. //}
  260. //if len(body) < int(msgLen) {
  261. // err = errors.New(fmt.Sprintf("message id too short msgid=%v", msgId))
  262. // return
  263. //}
  264. //
  265. ////msgId = binary.BigEndian.Uint16(body)
  266. ////body = body[msgIdLen:]
  267. ////msgSeqId = binary.BigEndian.Uint32(body) //序列号
  268. ////log.Println("readSeqId:", msgSeqId)
  269. ////msgData = body[msgSeqlen:]
  270. //msgData = body
  271. msgData = make([]byte, msgLen)
  272. if _, err = io.ReadFull(reader, msgData); err != nil {
  273. //log.Println("[ReadMessage] read message err:", err)
  274. return
  275. }
  276. if len(msgData) < int(msgLen) {
  277. err = errors.New(fmt.Sprintf("message id too short msgid=%v", msgId))
  278. return
  279. }
  280. }
  281. return
  282. }