proc_rpc_ws.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package model
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "github.com/gorilla/websocket"
  7. "rocommon"
  8. "rocommon/rpc"
  9. _ "rocommon/service"
  10. "rocommon/socket"
  11. "rocommon/util"
  12. _ "roserver/baseserver/model"
  13. _ "roserver/baseserver/router"
  14. _ "roserver/serverproto"
  15. )
  16. const (
  17. MsgIDLen = 2 // uint16
  18. lenMaxLen = 2 //包体大小2个字节 uint16
  19. msgIdLen = 2 //包ID大小2个字节 uint16
  20. msgSeqlen = 4 //发送序列号2个字节大小,用来断线重连
  21. msgFlaglen = 2 //暂定标记,加解密 1表示RSA,2表示AES
  22. MsgBodyIdx = 2 + 4 + 2 + 2 //->10
  23. )
  24. type DirectWSMessageTransmitter struct {
  25. }
  26. //recv直接原始数据传递到后端 / 返回消息给client
  27. func (this *DirectWSMessageTransmitter) OnRecvMsg(s rocommon.Session) (msg interface{}, seqId uint32, err error) {
  28. conn, ok := s.Raw().(*websocket.Conn)
  29. if !ok || conn == nil {
  30. util.InfoF("[DirectWSMessageTransmitter] OnRecvMsg err")
  31. return nil, 0, nil
  32. }
  33. //opt := s.Node().(socket.SocketOption)
  34. //var opt socket.SocketOption
  35. //if s.GetSessionOptFlag() {
  36. // opt = s.Node().(socket.SocketOption)
  37. //} else {
  38. // opt = s.GetSessionOpt().(socket.SocketOption)
  39. //}
  40. messageType, raw, err1 := conn.ReadMessage()
  41. if err1 != nil {
  42. err = err1
  43. util.InfoF("[DirectWSMessageTransmitter] OnRecvMsg ReadMessage err=%v", err)
  44. return nil, 0, err1
  45. }
  46. switch messageType {
  47. case websocket.BinaryMessage:
  48. var msgId uint16
  49. //var seqId uint32 //包序列号,客户端发送时的序列从1开始
  50. var flagId uint16 //加密方式
  51. var msgData []byte
  52. var msgDataLen uint16
  53. if len(raw) < lenMaxLen {
  54. return nil, 0, nil
  55. }
  56. msgDataLen = binary.BigEndian.Uint16(raw) //msgDataLen
  57. raw = raw[lenMaxLen:]
  58. if msgDataLen >= 0 {
  59. //msgIdLen
  60. if len(raw) < msgIdLen {
  61. return
  62. }
  63. msgId = binary.BigEndian.Uint16(raw)
  64. raw = raw[msgIdLen:]
  65. //msgSeqlen
  66. if len(raw) < msgSeqlen {
  67. return
  68. }
  69. seqId = binary.BigEndian.Uint32(raw)
  70. raw = raw[msgSeqlen:]
  71. //msgFlaglen
  72. if len(raw) < msgFlaglen {
  73. return
  74. }
  75. flagId = binary.BigEndian.Uint16(raw)
  76. msgData = raw[msgFlaglen:]
  77. //尝试直接发送到其他后端服务器或者解析
  78. if err == nil {
  79. msg, err = FrontendPackageProc(int(msgId), seqId, flagId, msgData, s)
  80. }
  81. }
  82. }
  83. return
  84. }
  85. //send 直接发往客户端的消息
  86. func (this *DirectWSMessageTransmitter) OnSendMsg(s rocommon.Session, msg interface{}) (err error) {
  87. conn, ok := s.Raw().(*websocket.Conn)
  88. if !ok || conn == nil {
  89. util.InfoF("[DirectWSMessageTransmitter] OnRecvMsg err")
  90. return nil
  91. }
  92. //opt := s.Node().(socket.SocketOption)
  93. var opt socket.SocketOption
  94. if s.GetSessionOptFlag() {
  95. opt = s.Node().(socket.SocketOption)
  96. } else {
  97. opt = s.GetSessionOpt().(socket.SocketOption)
  98. }
  99. aesKey := s.GetAES()
  100. var (
  101. msgData []byte
  102. msgId uint16
  103. seqId uint32
  104. msgInfo *rocommon.MessageInfo
  105. )
  106. switch m := msg.(type) {
  107. case *rocommon.TransmitPacket:
  108. msgData = m.MsgData
  109. msgId = uint16(m.MsgId)
  110. seqId = m.SeqId
  111. default:
  112. msgData, msgInfo, err = rpc.EncodeMessage(msg)
  113. if err != nil {
  114. return err
  115. }
  116. msgId = uint16(msgInfo.ID)
  117. }
  118. //todo
  119. // 注意上层发包不要超过最大值
  120. msgLen := len(msgData)
  121. var cryptType uint16 = 0
  122. //握手阶段
  123. if msgId == uint16(rpc.SC_HAND_SHAKE_NTFMsgId) {
  124. cryptType = 1
  125. msgData, err = rpc.RSAEncrypt(msgData, rpc.PublicClientKey)
  126. if err != nil {
  127. return err
  128. }
  129. msgLen = len(msgData)
  130. } else {
  131. if len(*aesKey) > 0 && msgId != rpc.SC_PING_ACKMsgId {
  132. cryptType = 2
  133. msgData, err = rpc.AESCtrEncrypt(msgData, *aesKey, *aesKey...)
  134. //msgData, err = AESCtrEncrypt(msgData, *aesKey)
  135. if err != nil {
  136. return err
  137. }
  138. msgLen = len(msgData)
  139. }
  140. }
  141. if msgLen > opt.MaxMsgLen() {
  142. err = errors.New(fmt.Sprintf("message too big msgId=%v msglen=%v maxlen=%v", msgId, msgLen, opt.MaxMsgLen()))
  143. util.FatalF("SendMessage err=%v", err)
  144. err = nil
  145. return
  146. }
  147. //data := make([]byte, lenMaxLen + msgIdLen + msgLen)
  148. data := make([]byte, lenMaxLen+msgIdLen+msgSeqlen+msgFlaglen+msgLen) //head + body
  149. //lenMaxLen
  150. binary.BigEndian.PutUint16(data, uint16(msgLen))
  151. //msgIdLen
  152. binary.BigEndian.PutUint16(data[lenMaxLen:], uint16(msgId))
  153. //seq 返回客户端发送的序列号
  154. binary.BigEndian.PutUint32(data[lenMaxLen+msgIdLen:], seqId)
  155. //log.Println("sendSeqId:", seqId)
  156. //使用的加密方式AES
  157. binary.BigEndian.PutUint16(data[lenMaxLen+msgIdLen+msgSeqlen:], cryptType)
  158. //body
  159. if msgLen > 0 {
  160. copy(data[lenMaxLen+msgIdLen+msgSeqlen+msgFlaglen:], msgData)
  161. }
  162. conn.WriteMessage(websocket.BinaryMessage, data)
  163. return
  164. }