| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- package rpc
- import (
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "reflect"
- "rocommon"
- "rocommon/util"
- )
// Sizes, in bytes, of the fixed-width fields that make up a packet header.
const (
	lenMaxLen  = 2 // body-length field (uint16)
	msgIdLen   = 2 // message-ID field (uint16)
	msgSeqlen  = 4 // send-sequence-number field (written as a uint32); used when reconnecting after a dropped connection
	msgFlaglen = 2 // provisional encryption-flag field: 1 means RSA, 2 means AES
)

// Message IDs that the read/write paths in this file treat specially.
// NOTE(review): SC_HAND_SHAKE_ACKMsgId and CS_HAND_SHAKE_REQMsgId are both 0,
// which looks like a placeholder — confirm the real IDs.
const (
	SC_HAND_SHAKE_NTFMsgId = 1006
	SC_HAND_SHAKE_ACKMsgId = 0
	CS_HAND_SHAKE_REQMsgId = 0
	SC_PING_ACKMsgId       = 1001
)
- //var SC_HAND_SHAKE_NTFMsgId = MessageInfoByName("SCHandShakeNtf").ID
- // /////////////////////
- func ReadMessage(reader io.Reader, maxMsgLen int, aesKey *[]byte) (msg interface{}, msgSeqId uint32, err error) {
- var msgId, flagId uint16 = 0, 0
- var msgData []byte
- msgId, msgSeqId, flagId, msgData, err = RecvPackageData(reader, maxMsgLen)
- if err != nil {
- return nil, 0, err
- }
- switch flagId {
- case 1:
- if int(msgId) == SC_HAND_SHAKE_NTFMsgId { //SC_HAND_SHAKE_NTF
- msgData, err = RSADecrypt(msgData, PrivateClientKey)
- if err != nil {
- return nil, 0, err
- }
- } else if int(msgId) == CS_HAND_SHAKE_REQMsgId { //CS_HAND_SHAKE_REQ
- msgData, err = RSADecrypt(msgData, PrivateServerKey)
- if err != nil {
- return nil, 0, err
- }
- } else if int(msgId) == SC_HAND_SHAKE_ACKMsgId { //SC_HAND_SHAKE_ACK
- msgData, err = RSADecrypt(msgData, PrivateClientKey)
- if err != nil {
- return nil, 0, err
- }
- } else {
- msgData, err = RSADecrypt(msgData, PrivateKey)
- if err != nil {
- return nil, 0, err
- }
- }
- case 2:
- msgData, err = AESCtrDecrypt(msgData, *aesKey, *aesKey...)
- //msgData, err = AESCtrDecrypt(msgData, *aesKey)
- if err != nil {
- return nil, 0, err
- }
- }
- //服务器内部不做加密处理
- msg, _, err = DecodeMessage(int(msgId), msgData)
- if err != nil {
- //log.Println("[DecodeMessage] err:", err)
- return nil, 0, errors.New(fmt.Sprintf("msg decodeMessage failed:%v %v", msgId, err))
- }
- /*
- bufMsgLen := make([]byte, lenMaxLen)
- _, err = io.ReadFull(reader, bufMsgLen)
- if err != nil {
- //log.Println("[ReadMessage] read message err:", err)
- return
- }
- if len(bufMsgLen) < lenMaxLen {
- err = errors.New("message too short")
- return
- }
- msgLen := binary.BigEndian.Uint16(bufMsgLen)
- if(msgLen > 0 && msgLen > uint16(maxMsgLen)) || msgLen < lenMaxLen{
- err = errors.New(fmt.Sprintf("message too big33:%v %v\n",msgLen, maxMsgLen))
- return
- }
- msgData := make([]byte, msgLen - lenMaxLen)
- if _, err = io.ReadFull(reader, msgData); err != nil {
- //log.Println("[ReadMessage] read message err:", err)
- return
- }
- if len(msgData) < msgIdLen{
- return nil, 0, errors.New("message id too short")
- }
- msgId := binary.BigEndian.Uint16(msgData)
- msgData = msgData[msgIdLen:]
- msgSeqId = binary.BigEndian.Uint32(msgData) //序列号
- //log.Println("readSeqId:", msgSeqId)
- body := msgData[msgSeqlen:]
- msg, _, err = DecodeMessage(int(msgId), body)
- if err != nil {
- //log.Println("[DecodeMessage] err:", err)
- return nil, 0, errors.New(fmt.Sprintf("msg decodeMessage failed:%v %v",msgId, err))
- }
- */
- return
- }
- // 消息反序列化
- func DecodeMessage(id int, data []byte) (interface{}, *rocommon.MessageInfo, error) {
- msgInfo := rocommon.MessageInfoByID(id)
- if msgInfo == nil {
- return nil, nil, errors.New("msg not register")
- }
- msg := reflect.New(msgInfo.Type).Interface()
- //解码操作这边直接用protobuf即可
- err := msgInfo.Codec.Unmarshal(data, msg)
- if err != nil {
- return nil, msgInfo, err
- }
- return msg, msgInfo, nil
- }
- func SendMessage(writer io.Writer, msg interface{}, aesKey *[]byte, maxMsgLen int, nodeName string) (err error) {
- var (
- msgData []byte
- msgId uint16
- seqId uint32
- msgInfo *rocommon.MessageInfo
- )
- switch m := msg.(type) {
- case *rocommon.TransmitPacket:
- msgData = m.MsgData
- msgId = uint16(m.MsgId)
- seqId = m.SeqId
- default:
- msgData, msgInfo, err = EncodeMessage(msg)
- if err != nil {
- return err
- }
- msgId = uint16(msgInfo.ID)
- }
- if msgId != 1001 && msgId != 15 && msgId != 3 {
- util.InfoF("SendMessage msgId=%v", msgId)
- //log.Printf("SendMessage msgData=%v", string(msgData))
- }
- //todo
- // 注意上层发包不要超过最大值
- msgLen := len(msgData)
- var cryptType uint16 = 0
- //握手阶段
- if msgId == uint16(SC_HAND_SHAKE_NTFMsgId) {
- cryptType = 1
- msgData, err = RSAEncrypt(msgData, PublicClientKey)
- if err != nil {
- return err
- }
- msgLen = len(msgData)
- } else {
- if len(*aesKey) > 0 && msgId != SC_PING_ACKMsgId {
- cryptType = 2
- msgData, err = AESCtrEncrypt(msgData, *aesKey, *aesKey...)
- //msgData, err = AESCtrEncrypt(msgData, *aesKey)
- if err != nil {
- return err
- }
- msgLen = len(msgData)
- }
- }
- if msgLen > maxMsgLen {
- err = errors.New(fmt.Sprintf("message too big msgId=%v msglen=%v maxlen=%v", msgId, msgLen, maxMsgLen))
- util.FatalF("SendMessage err=%v", err)
- err = nil
- return
- }
- //data := make([]byte, lenMaxLen + msgIdLen + msgLen)
- data := make([]byte, lenMaxLen+msgIdLen+msgSeqlen+msgFlaglen+msgLen) //head + body
- //lenMaxLen
- binary.BigEndian.PutUint16(data, uint16(msgLen))
- //msgIdLen
- binary.BigEndian.PutUint16(data[lenMaxLen:], msgId)
- //seq 返回客户端发送的序列号
- binary.BigEndian.PutUint32(data[lenMaxLen+msgIdLen:], seqId)
- //log.Println("sendSeqId:", seqId)
- //使用的加密方式AES
- binary.BigEndian.PutUint16(data[lenMaxLen+msgIdLen+msgSeqlen:], cryptType)
- //body
- if msgLen > 0 {
- copy(data[lenMaxLen+msgIdLen+msgSeqlen+msgFlaglen:], msgData)
- }
- //ioutil.go
- err = util.WriteFull(writer, data)
- //todo...使用内存池是否data数据
- return err
- }
- // 消息序列化
- func EncodeMessage(msg interface{}) (data []byte, info *rocommon.MessageInfo, err error) {
- info = rocommon.MessageInfoByMsg(msg)
- if info == nil {
- return nil, nil, errors.New("msg not register")
- }
- //log.Println("EncodeMessage:", msg)
- tempData, e := info.Codec.Marshal(msg)
- data = tempData.([]byte)
- err = e
- return
- }
- // 获取原始包数据(二进制),不做解析处理
- func RecvPackageData(reader io.Reader, maxMsgLen int) (msgId uint16, msgSeqId uint32, msgFlagId uint16, msgData []byte, err error) {
- bufMsgLen := make([]byte, lenMaxLen)
- _, err = io.ReadFull(reader, bufMsgLen)
- if err != nil {
- //log.Println("[ReadMessage] read message err:", err)
- return
- }
- if len(bufMsgLen) < lenMaxLen {
- //err = errors.New("message too short")
- return
- }
- //msgId
- bufIdLen := make([]byte, msgIdLen)
- _, err = io.ReadFull(reader, bufIdLen)
- if err != nil {
- //log.Println("[ReadMessage] read message err:", err)
- return
- }
- if len(bufIdLen) < msgIdLen {
- //err = errors.New("message too short")
- return
- }
- msgId = binary.BigEndian.Uint16(bufIdLen)
- //msgseqid
- bufSeqIdLen := make([]byte, msgSeqlen)
- _, err = io.ReadFull(reader, bufSeqIdLen)
- if err != nil {
- //log.Println("[ReadMessage] read message err:", err)
- return
- }
- if len(bufSeqIdLen) < msgSeqlen {
- //err = errors.New("message too short")
- return
- }
- msgSeqId = binary.BigEndian.Uint32(bufSeqIdLen)
- //msgFlaglen 1表示RSA,2表示AES
- bufFlagLen := make([]byte, msgFlaglen)
- _, err = io.ReadFull(reader, bufFlagLen)
- if err != nil {
- return
- }
- if len(bufFlagLen) < msgFlaglen {
- return
- }
- msgFlagId = binary.BigEndian.Uint16(bufFlagLen)
- //BigEndian
- msgLen := binary.BigEndian.Uint16(bufMsgLen)
- if msgLen > 0 && msgLen > uint16(maxMsgLen) {
- //err = errors.New("message too big")
- err = errors.New(fmt.Sprintf("message too big msgid=%v mslen=%v maxlen=%v bufMsgLen=%v msgFlagId=%v\n",
- msgId, msgLen, maxMsgLen, len(bufMsgLen), msgFlagId))
- util.FatalF("RecvPackageData err=%v", err)
- err = nil
- return
- }
- //todo 可以使用内存池
- if msgLen > 0 {
- //body := make([]byte, msgLen)
- //if _, err = io.ReadFull(reader, body); err != nil {
- // //log.Println("[ReadMessage] read message err:", err)
- // return
- //}
- //if len(body) < int(msgLen) {
- // err = errors.New(fmt.Sprintf("message id too short msgid=%v", msgId))
- // return
- //}
- //
- ////msgId = binary.BigEndian.Uint16(body)
- ////body = body[msgIdLen:]
- ////msgSeqId = binary.BigEndian.Uint32(body) //序列号
- ////log.Println("readSeqId:", msgSeqId)
- ////msgData = body[msgSeqlen:]
- //msgData = body
- msgData = make([]byte, msgLen)
- if _, err = io.ReadFull(reader, msgData); err != nil {
- //log.Println("[ReadMessage] read message err:", err)
- return
- }
- if len(msgData) < int(msgLen) {
- err = errors.New(fmt.Sprintf("message id too short msgid=%v", msgId))
- return
- }
- }
- return
- }
|