123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- package network
- import (
- "crypto/tls"
- "sync"
- "time"
- "leafstalk/log"
- "github.com/gorilla/websocket"
- )
// WSClient maintains one or more outgoing websocket connections, runs an
// Agent on each of them and, when AutoReconnect is set, redials after a
// connection's agent finishes. Zero/negative settings are replaced with
// defaults by init (called from Init and Start).
type WSClient struct {
	// The embedded mutex guards conns and closeFlag.
	sync.Mutex

	// Addr is the URL dialed by Start for each of its ConnNum connections.
	Addr string

	// ConnNum is the number of connections Start opens (default 1).
	ConnNum int

	// ConnectInterval is the pause after a failed dial attempt and before an
	// automatic reconnect (default 3s).
	ConnectInterval time.Duration

	// PendingWriteNum is passed to newWSConn for each connection (default
	// 100); presumably the outgoing message queue size — TODO confirm.
	PendingWriteNum int

	// MaxMsgLen is the read limit set on every connection and is also passed
	// to newWSConn (default 4096).
	MaxMsgLen uint32

	// HandshakeTimeout is used as the dialer's HandshakeTimeout (default 10s).
	HandshakeTimeout time.Duration

	// AutoReconnect makes a connection goroutine dial again after its
	// agent's Run returns.
	AutoReconnect bool

	// NewAgent builds the Agent for a new connection. Its second argument is
	// the data given to StartNewClient (nil for connections opened by Start).
	// It must not be nil.
	NewAgent func(*WSConn, interface{}) Agent

	dialer websocket.Dialer

	// conns is the set of registered live connections; nil before init and
	// after Close.
	conns WebsocketConnSet

	// wg counts running connection goroutines; Close waits on it.
	wg sync.WaitGroup

	// closeFlag is set by Close to stop dialing and reconnecting.
	closeFlag bool
}
- func (client *WSClient) Start() {
- client.init()
- for i := 0; i < client.ConnNum; i++ {
- client.wg.Add(1)
- go client.connect(client.Addr, nil)
- }
- }
- func (client *WSClient) init() {
- client.Lock()
- defer client.Unlock()
- if client.ConnNum <= 0 {
- client.ConnNum = 1
- log.Infof("invalid ConnNum, reset to %v", client.ConnNum)
- }
- if client.ConnectInterval <= 0 {
- client.ConnectInterval = 3 * time.Second
- log.Infof("invalid ConnectInterval, reset to %v", client.ConnectInterval)
- }
- if client.PendingWriteNum <= 0 {
- client.PendingWriteNum = 100
- log.Infof("invalid PendingWriteNum, reset to %v", client.PendingWriteNum)
- }
- if client.MaxMsgLen <= 0 {
- client.MaxMsgLen = 4096
- log.Infof("invalid MaxMsgLen, reset to %v", client.MaxMsgLen)
- }
- if client.HandshakeTimeout <= 0 {
- client.HandshakeTimeout = 10 * time.Second
- log.Infof("invalid HandshakeTimeout, reset to %v", client.HandshakeTimeout)
- }
- if client.NewAgent == nil {
- log.Fatal("NewAgent must not be nil")
- }
- if client.conns != nil {
- log.Fatal("client is running")
- }
- client.conns = make(WebsocketConnSet)
- client.closeFlag = false
- client.dialer = websocket.Dialer{
- HandshakeTimeout: client.HandshakeTimeout,
- TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
- }
- }
- func (client *WSClient) dial(addr string) *websocket.Conn {
- for {
- conn, _, err := client.dialer.Dial(addr, nil)
- if err == nil || client.closeFlag {
- return conn
- }
- log.Infof("connect to %v error: %v", addr, err)
- time.Sleep(client.ConnectInterval)
- continue
- }
- }
- func (client *WSClient) connect(addr string, data interface{}) {
- defer client.wg.Done()
- reconnect:
- conn := client.dial(addr)
- if conn == nil {
- return
- }
- conn.SetReadLimit(int64(client.MaxMsgLen))
- client.Lock()
- if client.closeFlag {
- client.Unlock()
- conn.Close()
- return
- }
- client.conns[conn] = struct{}{}
- client.Unlock()
- wsConn := newWSConn(conn, client.PendingWriteNum, client.MaxMsgLen)
- agent := client.NewAgent(wsConn, data)
- agent.Run()
- // cleanup
- wsConn.Close()
- client.Lock()
- delete(client.conns, conn)
- client.Unlock()
- agent.OnClose()
- if client.AutoReconnect {
- time.Sleep(client.ConnectInterval)
- goto reconnect
- }
- }
- func (client *WSClient) Close() {
- client.Lock()
- client.closeFlag = true
- for conn := range client.conns {
- conn.Close()
- }
- client.conns = nil
- client.Unlock()
- client.wg.Wait()
- }
- func (client *WSClient) Init() {
- client.init()
- }
- func (client *WSClient) StartNewClient(addr string, data interface{}) {
- client.wg.Add(1)
- go client.connect(addr, data)
- }
|