package protobuf import ( "encoding/binary" "errors" "fmt" "leafstalk/log" "leafstalk/module/handler" "math" "reflect" "google.golang.org/protobuf/proto" // "github.com/golang/protobuf/proto" ) // ------------------------- // | id | protobuf message | // ------------------------- type Processor struct { littleEndian bool msgInfo []*MsgInfo msgID map[reflect.Type]uint16 } type MsgInfo struct { msgType reflect.Type msgRouter *handler.Server msgHandler MsgHandler msgRawHandler MsgHandler } type MsgHandler func([]interface{}) type MsgRaw struct { msgID uint16 msgRawData []byte } func NewProcessor() *Processor { p := new(Processor) p.littleEndian = false p.msgID = make(map[reflect.Type]uint16) return p } // It's dangerous to call the method on routing or marshaling (unmarshaling) func (p *Processor) SetByteOrder(littleEndian bool) { p.littleEndian = littleEndian } // It's dangerous to call the method on routing or marshaling (unmarshaling) func (p *Processor) Register(msg proto.Message) uint16 { msgType := reflect.TypeOf(msg) if msgType == nil || msgType.Kind() != reflect.Ptr { log.Fatal("protobuf message pointer required") } if _, ok := p.msgID[msgType]; ok { log.Fatalf("message %s is already registered", msgType) } if len(p.msgInfo) >= math.MaxUint16 { log.Fatalf("too many protobuf messages (max = %v)", math.MaxUint16) } i := new(MsgInfo) i.msgType = msgType p.msgInfo = append(p.msgInfo, i) id := uint16(len(p.msgInfo) - 1) p.msgID[msgType] = id return id } // func (p *Processor) Register2(msg interface{}) uint16 { // msgType := reflect.TypeOf(msg) // if msgType == nil || msgType.Kind() != reflect.Ptr { // log.Fatal("protobuf message pointer required") // } // if _, ok := p.msgID[msgType]; ok { // log.Fatalf("message %s is already registered", msgType) // } // if len(p.msgInfo) >= math.MaxUint16 { // log.Fatalf("too many protobuf messages (max = %v)", math.MaxUint16) // } // i := new(MsgInfo) // i.msgType = msgType // p.msgInfo = append(p.msgInfo, i) // id := uint16(len(p.msgInfo) - 1) // p.msgID[msgType] = id // return id // } // func (p *Processor) IsRegisted(msg proto.Message) (bool, error) { // msgType := reflect.TypeOf(msg) // if msgType == nil || msgType.Kind() != reflect.Ptr { // return false, errors.New("protobuf message pointer required") // } // if len(p.msgInfo) >= math.MaxUint16 { // return false, fmt.Errorf("too many protobuf messages (max = %v)", math.MaxUint16) // } // if _, ok := p.msgID[msgType]; ok { // return true, nil // } // return false, nil // } // It's dangerous to call the method on routing or marshaling (unmarshaling) func (p *Processor) SetRouter(msg proto.Message, msgRouter *handler.Server) { msgType := reflect.TypeOf(msg) id, ok := p.msgID[msgType] if !ok { log.Fatalf("message %s not registered", msgType) } p.msgInfo[id].msgRouter = msgRouter } // It's dangerous to call the method on routing or marshaling (unmarshaling) func (p *Processor) SetHandler(msg proto.Message, msgHandler MsgHandler) { msgType := reflect.TypeOf(msg) id, ok := p.msgID[msgType] if !ok { log.Fatalf("message %s not registered", msgType) } p.msgInfo[id].msgHandler = msgHandler } // It's dangerous to call the method on routing or marshaling (unmarshaling) func (p *Processor) SetRawHandler(id uint16, msgRawHandler MsgHandler) { if id >= uint16(len(p.msgInfo)) { log.Fatalf("message id %v not registered", id) } p.msgInfo[id].msgRawHandler = msgRawHandler } // goroutine safe func (p *Processor) Route(msg interface{}, userData interface{}) error { // raw if msgRaw, ok := msg.(MsgRaw); ok { if msgRaw.msgID >= uint16(len(p.msgInfo)) { return fmt.Errorf("message id %v not registered", msgRaw.msgID) } i := p.msgInfo[msgRaw.msgID] if i.msgRawHandler != nil { i.msgRawHandler([]interface{}{msgRaw.msgID, msgRaw.msgRawData, userData}) } return nil } // protobuf msgType := reflect.TypeOf(msg) id, ok := p.msgID[msgType] if !ok { return fmt.Errorf("message %s not registered", msgType) } i := p.msgInfo[id] if i.msgHandler != nil { i.msgHandler([]interface{}{msg, userData}) } if i.msgRouter != nil { i.msgRouter.Go(msgType, msg, userData) } return nil } // goroutine safe func (p *Processor) Unmarshal(data []byte) (interface{}, error) { if len(data) < 2 { return nil, errors.New("protobuf data too short") } // id var id uint16 if p.littleEndian { id = binary.LittleEndian.Uint16(data) } else { id = binary.BigEndian.Uint16(data) } if id >= uint16(len(p.msgInfo)) { return nil, fmt.Errorf("message id %v not registered", id) } // msg i := p.msgInfo[id] if i.msgRawHandler != nil { return MsgRaw{id, data[2:]}, nil } else { msg := reflect.New(i.msgType.Elem()).Interface() // return msg, proto.UnmarshalMerge(data[2:], msg.(proto.Message)) return msg, proto.Unmarshal(data[2:], msg.(proto.Message)) } } // goroutine safe func (p *Processor) Marshal(msg interface{}) ([][]byte, error) { msgType := reflect.TypeOf(msg) // id _id, ok := p.msgID[msgType] if !ok { err := fmt.Errorf("message %s not registered", msgType) return nil, err } id := make([]byte, 2) if p.littleEndian { binary.LittleEndian.PutUint16(id, _id) } else { binary.BigEndian.PutUint16(id, _id) } // data data, err := proto.Marshal(msg.(proto.Message)) return [][]byte{id, data}, err } // goroutine safe func (p *Processor) Range(f func(id uint16, t reflect.Type)) { for id, i := range p.msgInfo { f(uint16(id), i.msgType) } }