msg.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. package rpc
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "reflect"
  8. "rocommon"
  9. "rocommon/util"
  10. )
  11. const (
  12. lenMaxLen = 2 //包体大小2个字节 uint16
  13. msgIdLen = 2 //包ID大小2个字节 uint16
  14. msgSeqlen = 4 //发送序列号2个字节大小,用来断线重连
  15. msgFlaglen = 2 //暂定标记,加解密 1表示RSA,2表示AES
  16. SC_HAND_SHAKE_NTFMsgId = 1006
  17. SC_HAND_SHAKE_ACKMsgId = 0
  18. CS_HAND_SHAKE_REQMsgId = 0
  19. SC_PING_ACKMsgId = 1001
  20. )
  21. //var SC_HAND_SHAKE_NTFMsgId = MessageInfoByName("SCHandShakeNtf").ID
  22. // /////////////////////
  23. func ReadMessage(reader io.Reader, maxMsgLen int, aesKey *[]byte) (msg interface{}, msgSeqId uint32, err error) {
  24. var msgId, flagId uint16 = 0, 0
  25. var msgData []byte
  26. msgId, msgSeqId, flagId, msgData, err = RecvPackageData(reader, maxMsgLen)
  27. if err != nil {
  28. return nil, 0, err
  29. }
  30. switch flagId {
  31. case 1:
  32. if int(msgId) == SC_HAND_SHAKE_NTFMsgId { //SC_HAND_SHAKE_NTF
  33. msgData, err = RSADecrypt(msgData, PrivateClientKey)
  34. if err != nil {
  35. return nil, 0, err
  36. }
  37. } else if int(msgId) == CS_HAND_SHAKE_REQMsgId { //CS_HAND_SHAKE_REQ
  38. msgData, err = RSADecrypt(msgData, PrivateServerKey)
  39. if err != nil {
  40. return nil, 0, err
  41. }
  42. } else if int(msgId) == SC_HAND_SHAKE_ACKMsgId { //SC_HAND_SHAKE_ACK
  43. msgData, err = RSADecrypt(msgData, PrivateClientKey)
  44. if err != nil {
  45. return nil, 0, err
  46. }
  47. } else {
  48. msgData, err = RSADecrypt(msgData, PrivateKey)
  49. if err != nil {
  50. return nil, 0, err
  51. }
  52. }
  53. case 2:
  54. msgData, err = AESCtrDecrypt(msgData, *aesKey, *aesKey...)
  55. //msgData, err = AESCtrDecrypt(msgData, *aesKey)
  56. if err != nil {
  57. return nil, 0, err
  58. }
  59. }
  60. //服务器内部不做加密处理
  61. msg, _, err = DecodeMessage(int(msgId), msgData)
  62. if err != nil {
  63. //log.Println("[DecodeMessage] err:", err)
  64. return nil, 0, errors.New(fmt.Sprintf("msg decodeMessage failed:%v %v", msgId, err))
  65. }
  66. /*
  67. bufMsgLen := make([]byte, lenMaxLen)
  68. _, err = io.ReadFull(reader, bufMsgLen)
  69. if err != nil {
  70. //log.Println("[ReadMessage] read message err:", err)
  71. return
  72. }
  73. if len(bufMsgLen) < lenMaxLen {
  74. err = errors.New("message too short")
  75. return
  76. }
  77. msgLen := binary.BigEndian.Uint16(bufMsgLen)
  78. if(msgLen > 0 && msgLen > uint16(maxMsgLen)) || msgLen < lenMaxLen{
  79. err = errors.New(fmt.Sprintf("message too big33:%v %v\n",msgLen, maxMsgLen))
  80. return
  81. }
  82. msgData := make([]byte, msgLen - lenMaxLen)
  83. if _, err = io.ReadFull(reader, msgData); err != nil {
  84. //log.Println("[ReadMessage] read message err:", err)
  85. return
  86. }
  87. if len(msgData) < msgIdLen{
  88. return nil, 0, errors.New("message id too short")
  89. }
  90. msgId := binary.BigEndian.Uint16(msgData)
  91. msgData = msgData[msgIdLen:]
  92. msgSeqId = binary.BigEndian.Uint32(msgData) //序列号
  93. //log.Println("readSeqId:", msgSeqId)
  94. body := msgData[msgSeqlen:]
  95. msg, _, err = DecodeMessage(int(msgId), body)
  96. if err != nil {
  97. //log.Println("[DecodeMessage] err:", err)
  98. return nil, 0, errors.New(fmt.Sprintf("msg decodeMessage failed:%v %v",msgId, err))
  99. }
  100. */
  101. return
  102. }
  103. // 消息反序列化
  104. func DecodeMessage(id int, data []byte) (interface{}, *rocommon.MessageInfo, error) {
  105. msgInfo := rocommon.MessageInfoByID(id)
  106. if msgInfo == nil {
  107. return nil, nil, errors.New("msg not register")
  108. }
  109. msg := reflect.New(msgInfo.Type).Interface()
  110. //解码操作这边直接用protobuf即可
  111. err := msgInfo.Codec.Unmarshal(data, msg)
  112. if err != nil {
  113. return nil, msgInfo, err
  114. }
  115. return msg, msgInfo, nil
  116. }
  117. func SendMessage(writer io.Writer, msg interface{}, aesKey *[]byte, maxMsgLen int, nodeName string) (err error) {
  118. var (
  119. msgData []byte
  120. msgId uint16
  121. seqId uint32
  122. msgInfo *rocommon.MessageInfo
  123. )
  124. switch m := msg.(type) {
  125. case *rocommon.TransmitPacket:
  126. msgData = m.MsgData
  127. msgId = uint16(m.MsgId)
  128. seqId = m.SeqId
  129. default:
  130. msgData, msgInfo, err = EncodeMessage(msg)
  131. if err != nil {
  132. return err
  133. }
  134. msgId = uint16(msgInfo.ID)
  135. }
  136. if msgId != 1001 && msgId != 15 && msgId != 3 {
  137. util.InfoF("SendMessage msgId=%v", msgId)
  138. //log.Printf("SendMessage msgData=%v", string(msgData))
  139. }
  140. //todo
  141. // 注意上层发包不要超过最大值
  142. msgLen := len(msgData)
  143. var cryptType uint16 = 0
  144. //握手阶段
  145. if msgId == uint16(SC_HAND_SHAKE_NTFMsgId) {
  146. cryptType = 1
  147. msgData, err = RSAEncrypt(msgData, PublicClientKey)
  148. if err != nil {
  149. return err
  150. }
  151. msgLen = len(msgData)
  152. } else {
  153. if len(*aesKey) > 0 && msgId != SC_PING_ACKMsgId {
  154. cryptType = 2
  155. msgData, err = AESCtrEncrypt(msgData, *aesKey, *aesKey...)
  156. //msgData, err = AESCtrEncrypt(msgData, *aesKey)
  157. if err != nil {
  158. return err
  159. }
  160. msgLen = len(msgData)
  161. }
  162. }
  163. if msgLen > maxMsgLen {
  164. err = errors.New(fmt.Sprintf("message too big msgId=%v msglen=%v maxlen=%v", msgId, msgLen, maxMsgLen))
  165. util.FatalF("SendMessage err=%v", err)
  166. err = nil
  167. return
  168. }
  169. //data := make([]byte, lenMaxLen + msgIdLen + msgLen)
  170. data := make([]byte, lenMaxLen+msgIdLen+msgSeqlen+msgFlaglen+msgLen) //head + body
  171. //lenMaxLen
  172. binary.BigEndian.PutUint16(data, uint16(msgLen))
  173. //msgIdLen
  174. binary.BigEndian.PutUint16(data[lenMaxLen:], msgId)
  175. //seq 返回客户端发送的序列号
  176. binary.BigEndian.PutUint32(data[lenMaxLen+msgIdLen:], seqId)
  177. //log.Println("sendSeqId:", seqId)
  178. //使用的加密方式AES
  179. binary.BigEndian.PutUint16(data[lenMaxLen+msgIdLen+msgSeqlen:], cryptType)
  180. //body
  181. if msgLen > 0 {
  182. copy(data[lenMaxLen+msgIdLen+msgSeqlen+msgFlaglen:], msgData)
  183. }
  184. //ioutil.go
  185. err = util.WriteFull(writer, data)
  186. //todo...使用内存池是否data数据
  187. return err
  188. }
  189. // 消息序列化
  190. func EncodeMessage(msg interface{}) (data []byte, info *rocommon.MessageInfo, err error) {
  191. info = rocommon.MessageInfoByMsg(msg)
  192. if info == nil {
  193. return nil, nil, errors.New("msg not register")
  194. }
  195. //log.Println("EncodeMessage:", msg)
  196. tempData, e := info.Codec.Marshal(msg)
  197. data = tempData.([]byte)
  198. err = e
  199. return
  200. }
  201. // 获取原始包数据(二进制),不做解析处理
  202. func RecvPackageData(reader io.Reader, maxMsgLen int) (msgId uint16, msgSeqId uint32, msgFlagId uint16, msgData []byte, err error) {
  203. bufMsgLen := make([]byte, lenMaxLen)
  204. _, err = io.ReadFull(reader, bufMsgLen)
  205. if err != nil {
  206. //log.Println("[ReadMessage] read message err:", err)
  207. return
  208. }
  209. if len(bufMsgLen) < lenMaxLen {
  210. //err = errors.New("message too short")
  211. return
  212. }
  213. //msgId
  214. bufIdLen := make([]byte, msgIdLen)
  215. _, err = io.ReadFull(reader, bufIdLen)
  216. if err != nil {
  217. //log.Println("[ReadMessage] read message err:", err)
  218. return
  219. }
  220. if len(bufIdLen) < msgIdLen {
  221. //err = errors.New("message too short")
  222. return
  223. }
  224. msgId = binary.BigEndian.Uint16(bufIdLen)
  225. //msgseqid
  226. bufSeqIdLen := make([]byte, msgSeqlen)
  227. _, err = io.ReadFull(reader, bufSeqIdLen)
  228. if err != nil {
  229. //log.Println("[ReadMessage] read message err:", err)
  230. return
  231. }
  232. if len(bufSeqIdLen) < msgSeqlen {
  233. //err = errors.New("message too short")
  234. return
  235. }
  236. msgSeqId = binary.BigEndian.Uint32(bufSeqIdLen)
  237. //msgFlaglen 1表示RSA,2表示AES
  238. bufFlagLen := make([]byte, msgFlaglen)
  239. _, err = io.ReadFull(reader, bufFlagLen)
  240. if err != nil {
  241. return
  242. }
  243. if len(bufFlagLen) < msgFlaglen {
  244. return
  245. }
  246. msgFlagId = binary.BigEndian.Uint16(bufFlagLen)
  247. //BigEndian
  248. msgLen := binary.BigEndian.Uint16(bufMsgLen)
  249. if msgLen > 0 && msgLen > uint16(maxMsgLen) {
  250. //err = errors.New("message too big")
  251. err = errors.New(fmt.Sprintf("message too big msgid=%v mslen=%v maxlen=%v bufMsgLen=%v msgFlagId=%v\n",
  252. msgId, msgLen, maxMsgLen, len(bufMsgLen), msgFlagId))
  253. util.FatalF("RecvPackageData err=%v", err)
  254. err = nil
  255. return
  256. }
  257. //todo 可以使用内存池
  258. if msgLen > 0 {
  259. //body := make([]byte, msgLen)
  260. //if _, err = io.ReadFull(reader, body); err != nil {
  261. // //log.Println("[ReadMessage] read message err:", err)
  262. // return
  263. //}
  264. //if len(body) < int(msgLen) {
  265. // err = errors.New(fmt.Sprintf("message id too short msgid=%v", msgId))
  266. // return
  267. //}
  268. //
  269. ////msgId = binary.BigEndian.Uint16(body)
  270. ////body = body[msgIdLen:]
  271. ////msgSeqId = binary.BigEndian.Uint32(body) //序列号
  272. ////log.Println("readSeqId:", msgSeqId)
  273. ////msgData = body[msgSeqlen:]
  274. //msgData = body
  275. msgData = make([]byte, msgLen)
  276. if _, err = io.ReadFull(reader, msgData); err != nil {
  277. //log.Println("[ReadMessage] read message err:", err)
  278. return
  279. }
  280. if len(msgData) < int(msgLen) {
  281. err = errors.New(fmt.Sprintf("message id too short msgid=%v", msgId))
  282. return
  283. }
  284. }
  285. return
  286. }