123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- package protobuf
- import (
- "encoding/binary"
- "errors"
- "fmt"
- "leafstalk/log"
- "leafstalk/module/handler"
- "math"
- "reflect"
- "google.golang.org/protobuf/proto"
- // "github.com/golang/protobuf/proto"
- )
- // -------------------------
- // | id | protobuf message |
- // -------------------------
- type Processor struct {
- littleEndian bool
- msgInfo []*MsgInfo
- msgID map[reflect.Type]uint16
- }
- type MsgInfo struct {
- msgType reflect.Type
- msgRouter *handler.Server
- msgHandler MsgHandler
- msgRawHandler MsgHandler
- }
- type MsgHandler func([]interface{})
- type MsgRaw struct {
- msgID uint16
- msgRawData []byte
- }
- func NewProcessor() *Processor {
- p := new(Processor)
- p.littleEndian = false
- p.msgID = make(map[reflect.Type]uint16)
- return p
- }
- // It's dangerous to call the method on routing or marshaling (unmarshaling)
- func (p *Processor) SetByteOrder(littleEndian bool) {
- p.littleEndian = littleEndian
- }
- // It's dangerous to call the method on routing or marshaling (unmarshaling)
- func (p *Processor) Register(msg proto.Message) uint16 {
- msgType := reflect.TypeOf(msg)
- if msgType == nil || msgType.Kind() != reflect.Ptr {
- log.Fatal("protobuf message pointer required")
- }
- if _, ok := p.msgID[msgType]; ok {
- log.Fatalf("message %s is already registered", msgType)
- }
- if len(p.msgInfo) >= math.MaxUint16 {
- log.Fatalf("too many protobuf messages (max = %v)", math.MaxUint16)
- }
- i := new(MsgInfo)
- i.msgType = msgType
- p.msgInfo = append(p.msgInfo, i)
- id := uint16(len(p.msgInfo) - 1)
- p.msgID[msgType] = id
- return id
- }
- // func (p *Processor) Register2(msg interface{}) uint16 {
- // msgType := reflect.TypeOf(msg)
- // if msgType == nil || msgType.Kind() != reflect.Ptr {
- // log.Fatal("protobuf message pointer required")
- // }
- // if _, ok := p.msgID[msgType]; ok {
- // log.Fatalf("message %s is already registered", msgType)
- // }
- // if len(p.msgInfo) >= math.MaxUint16 {
- // log.Fatalf("too many protobuf messages (max = %v)", math.MaxUint16)
- // }
- // i := new(MsgInfo)
- // i.msgType = msgType
- // p.msgInfo = append(p.msgInfo, i)
- // id := uint16(len(p.msgInfo) - 1)
- // p.msgID[msgType] = id
- // return id
- // }
- // func (p *Processor) IsRegisted(msg proto.Message) (bool, error) {
- // msgType := reflect.TypeOf(msg)
- // if msgType == nil || msgType.Kind() != reflect.Ptr {
- // return false, errors.New("protobuf message pointer required")
- // }
- // if len(p.msgInfo) >= math.MaxUint16 {
- // return false, fmt.Errorf("too many protobuf messages (max = %v)", math.MaxUint16)
- // }
- // if _, ok := p.msgID[msgType]; ok {
- // return true, nil
- // }
- // return false, nil
- // }
- // It's dangerous to call the method on routing or marshaling (unmarshaling)
- func (p *Processor) SetRouter(msg proto.Message, msgRouter *handler.Server) {
- msgType := reflect.TypeOf(msg)
- id, ok := p.msgID[msgType]
- if !ok {
- log.Fatalf("message %s not registered", msgType)
- }
- p.msgInfo[id].msgRouter = msgRouter
- }
- // It's dangerous to call the method on routing or marshaling (unmarshaling)
- func (p *Processor) SetHandler(msg proto.Message, msgHandler MsgHandler) {
- msgType := reflect.TypeOf(msg)
- id, ok := p.msgID[msgType]
- if !ok {
- log.Fatalf("message %s not registered", msgType)
- }
- p.msgInfo[id].msgHandler = msgHandler
- }
- // It's dangerous to call the method on routing or marshaling (unmarshaling)
- func (p *Processor) SetRawHandler(id uint16, msgRawHandler MsgHandler) {
- if id >= uint16(len(p.msgInfo)) {
- log.Fatalf("message id %v not registered", id)
- }
- p.msgInfo[id].msgRawHandler = msgRawHandler
- }
- // goroutine safe
- func (p *Processor) Route(msg interface{}, userData interface{}) error {
- // raw
- if msgRaw, ok := msg.(MsgRaw); ok {
- if msgRaw.msgID >= uint16(len(p.msgInfo)) {
- return fmt.Errorf("message id %v not registered", msgRaw.msgID)
- }
- i := p.msgInfo[msgRaw.msgID]
- if i.msgRawHandler != nil {
- i.msgRawHandler([]interface{}{msgRaw.msgID, msgRaw.msgRawData, userData})
- }
- return nil
- }
- // protobuf
- msgType := reflect.TypeOf(msg)
- id, ok := p.msgID[msgType]
- if !ok {
- return fmt.Errorf("message %s not registered", msgType)
- }
- i := p.msgInfo[id]
- if i.msgHandler != nil {
- i.msgHandler([]interface{}{msg, userData})
- }
- if i.msgRouter != nil {
- i.msgRouter.Go(msgType, msg, userData)
- }
- return nil
- }
- // goroutine safe
- func (p *Processor) Unmarshal(data []byte) (interface{}, error) {
- if len(data) < 2 {
- return nil, errors.New("protobuf data too short")
- }
- // id
- var id uint16
- if p.littleEndian {
- id = binary.LittleEndian.Uint16(data)
- } else {
- id = binary.BigEndian.Uint16(data)
- }
- if id >= uint16(len(p.msgInfo)) {
- return nil, fmt.Errorf("message id %v not registered", id)
- }
- // msg
- i := p.msgInfo[id]
- if i.msgRawHandler != nil {
- return MsgRaw{id, data[2:]}, nil
- } else {
- msg := reflect.New(i.msgType.Elem()).Interface()
- // return msg, proto.UnmarshalMerge(data[2:], msg.(proto.Message))
- return msg, proto.Unmarshal(data[2:], msg.(proto.Message))
- }
- }
- // goroutine safe
- func (p *Processor) Marshal(msg interface{}) ([][]byte, error) {
- msgType := reflect.TypeOf(msg)
- // id
- _id, ok := p.msgID[msgType]
- if !ok {
- err := fmt.Errorf("message %s not registered", msgType)
- return nil, err
- }
- id := make([]byte, 2)
- if p.littleEndian {
- binary.LittleEndian.PutUint16(id, _id)
- } else {
- binary.BigEndian.PutUint16(id, _id)
- }
- // data
- data, err := proto.Marshal(msg.(proto.Message))
- return [][]byte{id, data}, err
- }
- // goroutine safe
- func (p *Processor) Range(f func(id uint16, t reflect.Type)) {
- for id, i := range p.msgInfo {
- f(uint16(id), i.msgType)
- }
- }
|