package network

import (
	"errors"
	"net"
	"sync"

	"leafstalk/log"
	"leafstalk/otherutils"

	"github.com/gorilla/websocket"
)

// WebsocketConnSet is a set of websocket connections keyed by connection pointer.
type WebsocketConnSet map[*websocket.Conn]struct{}

// WSConn wraps a websocket connection with an asynchronous write queue.
// Outgoing messages are pushed onto writeChan and written by a dedicated
// goroutine started in newWSConn. The embedded mutex guards closeFlag and
// serializes sends on writeChan.
type WSConn struct {
	sync.Mutex
	conn      *websocket.Conn
	writeChan chan []byte // pending outgoing messages; a nil entry tells the writer goroutine to stop
	maxMsgLen uint32      // largest total payload WriteMsg accepts, in bytes
	closeFlag bool        // true once the connection is closing or closed
	ip        string      // optional address override for RemoteAddr; never assigned in the code shown here
}

// newWSConn wraps conn and starts the writer goroutine. pendingWriteNum is
// the capacity of the write queue and maxMsgLen the largest payload that
// WriteMsg accepts.
func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32) *WSConn {
	wsConn := new(WSConn)
	wsConn.conn = conn
	wsConn.writeChan = make(chan []byte, pendingWriteNum)
	wsConn.maxMsgLen = maxMsgLen

	// Writer goroutine: drains the queue until the channel is closed, a nil
	// sentinel arrives (see Close), or a write fails. It then closes the
	// connection and marks it closed so later writes are dropped.
	go func() {
		for b := range wsConn.writeChan {
			if b == nil {
				break
			}

			b = EncryptByXxtea(b)
			err := conn.WriteMessage(websocket.BinaryMessage, b)
			if err != nil {
				break
			}
		}

		conn.Close()
		wsConn.Lock()
		wsConn.closeFlag = true
		wsConn.Unlock()
	}()

	return wsConn
}

// doDestroy closes the underlying connection and, if the connection has not
// been flagged as closed yet, closes the write queue and sets the flag.
// It reads and writes closeFlag without locking; callers in this file hold
// the WSConn lock.
func (wsConn *WSConn) doDestroy() {
	// wsConn.conn.UnderlyingConn().(*tls.Conn).SetLinger(0)
	// wsConn.conn.UnderlyingConn()
	wsConn.conn.Close()

	if !wsConn.closeFlag {
		close(wsConn.writeChan)
		wsConn.closeFlag = true
	}
}

// Destroy closes the connection immediately; messages still waiting in the
// write queue are not guaranteed to be sent.
func (wsConn *WSConn) Destroy() {
	wsConn.Lock()
	defer wsConn.Unlock()

	wsConn.doDestroy()
}

// Close requests a graceful shutdown: it queues a nil sentinel behind any
// pending messages so the writer goroutine closes the connection after
// processing them. Calling Close on an already closed connection is a no-op.
func (wsConn *WSConn) Close() {
	wsConn.Lock()
	defer wsConn.Unlock()
	if wsConn.closeFlag {
		return
	}

	wsConn.doWrite(nil)
	wsConn.closeFlag = true
}

// doWrite queues b for the writer goroutine. If the queue is already full
// the connection is destroyed instead of blocking the caller. Callers in
// this file hold the WSConn lock, so no other sender can fill the queue
// between the capacity check and the send.
func (wsConn *WSConn) doWrite(b []byte) {
	if len(wsConn.writeChan) == cap(wsConn.writeChan) {
		log.Debug("close conn: channel full")
		wsConn.doDestroy()
		return
	}

	wsConn.writeChan <- b
}

// LocalAddr returns the local network address of the underlying connection.
func (wsConn *WSConn) LocalAddr() net.Addr {
	return wsConn.conn.LocalAddr()
}

// RemoteAddr returns the peer address. When the ip field is non-empty and
// resolves successfully, the resolved address is returned instead of the
// underlying connection's remote address.
func (wsConn *WSConn) RemoteAddr() net.Addr {
	if len(wsConn.ip) > 0 {
		ip, err := net.ResolveIPAddr("ip", wsConn.ip)
		if err == nil {
			return ip
		}
	}
	return wsConn.conn.RemoteAddr()
}

// ReadMsg reads the next message from the connection and returns its
// decrypted payload. On a read error it returns (nil, err); any bytes that
// were returned alongside the error are discarded without being decrypted.
//
// goroutine not safe
func (wsConn *WSConn) ReadMsg() ([]byte, error) {
	_, b, err := wsConn.conn.ReadMessage()
	//log.Debugf("ReadMsg len %d", len(b))
	if err != nil {
		// Never hand an incomplete payload to the cipher or the caller.
		return nil, err
	}
	if b != nil {
		b = DecryptByXxtea(b)
	}
	return b, nil
}

// WriteMsg queues the concatenation of args as a single outgoing message.
// It returns an error when the combined length is zero or exceeds maxMsgLen.
// If the connection is already closed the message is silently dropped and
// nil is returned.
//
// args must not be modified by the others goroutines
func (wsConn *WSConn) WriteMsg(args ...[]byte) error {
	wsConn.Lock()
	defer wsConn.Unlock()
	if wsConn.closeFlag {
		return nil
	}

	// get len
	// Accumulate in 64 bits so that a very large total cannot wrap around
	// and slip past the maxMsgLen check.
	var msgLen uint64
	for _, part := range args {
		msgLen += uint64(len(part))
	}

	// check len
	if msgLen > uint64(wsConn.maxMsgLen) {
		return errors.New("message too long")
	}
	if msgLen < 1 {
		return errors.New("message too short")
	}

	// don't copy
	if len(args) == 1 {
		wsConn.doWrite(args[0])
		return nil
	}

	// merge the args
	msg := make([]byte, 0, msgLen)
	for _, part := range args {
		msg = append(msg, part...)
	}

	wsConn.doWrite(msg)

	return nil
}

var (
	// XXTEA_KEY is the key passed to the encrypt/decrypt helpers below.
	// It is read without synchronization, so change it (via SetXxteaPass)
	// before any connection starts reading or writing.
	XXTEA_KEY = "covenant"
)

// SetXxteaPass replaces the package-wide encryption key.
func SetXxteaPass(pass string) {
	XXTEA_KEY = pass
}

// EncryptByXxtea encrypts data with the current XXTEA_KEY using
// otherutils.Encrypt.
func EncryptByXxtea(data []byte) []byte {
	// return data
	return otherutils.Encrypt(data, []byte(XXTEA_KEY))
}

// DecryptByXxtea decrypts data with the current XXTEA_KEY using
// otherutils.Decrypt.
func DecryptByXxtea(data []byte) []byte {
	// return data
	return otherutils.Decrypt(data, []byte(XXTEA_KEY))
}