package network import ( "crypto/tls" "sync" "time" "leafstalk/log" "github.com/gorilla/websocket" ) type WSClient struct { sync.Mutex Addr string ConnNum int ConnectInterval time.Duration PendingWriteNum int MaxMsgLen uint32 HandshakeTimeout time.Duration AutoReconnect bool NewAgent func(*WSConn, interface{}) Agent dialer websocket.Dialer conns WebsocketConnSet wg sync.WaitGroup closeFlag bool } func (client *WSClient) Start() { client.init() for i := 0; i < client.ConnNum; i++ { client.wg.Add(1) go client.connect(client.Addr, nil) } } func (client *WSClient) init() { client.Lock() defer client.Unlock() if client.ConnNum <= 0 { client.ConnNum = 1 log.Infof("invalid ConnNum, reset to %v", client.ConnNum) } if client.ConnectInterval <= 0 { client.ConnectInterval = 3 * time.Second log.Infof("invalid ConnectInterval, reset to %v", client.ConnectInterval) } if client.PendingWriteNum <= 0 { client.PendingWriteNum = 100 log.Infof("invalid PendingWriteNum, reset to %v", client.PendingWriteNum) } if client.MaxMsgLen <= 0 { client.MaxMsgLen = 4096 log.Infof("invalid MaxMsgLen, reset to %v", client.MaxMsgLen) } if client.HandshakeTimeout <= 0 { client.HandshakeTimeout = 10 * time.Second log.Infof("invalid HandshakeTimeout, reset to %v", client.HandshakeTimeout) } if client.NewAgent == nil { log.Fatal("NewAgent must not be nil") } if client.conns != nil { log.Fatal("client is running") } client.conns = make(WebsocketConnSet) client.closeFlag = false client.dialer = websocket.Dialer{ HandshakeTimeout: client.HandshakeTimeout, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } } func (client *WSClient) dial(addr string) *websocket.Conn { for { conn, _, err := client.dialer.Dial(addr, nil) if err == nil || client.closeFlag { return conn } log.Infof("connect to %v error: %v", addr, err) time.Sleep(client.ConnectInterval) continue } } func (client *WSClient) connect(addr string, data interface{}) { defer client.wg.Done() reconnect: conn := client.dial(addr) if conn == nil { return } conn.SetReadLimit(int64(client.MaxMsgLen)) client.Lock() if client.closeFlag { client.Unlock() conn.Close() return } client.conns[conn] = struct{}{} client.Unlock() wsConn := newWSConn(conn, client.PendingWriteNum, client.MaxMsgLen) agent := client.NewAgent(wsConn, data) agent.Run() // cleanup wsConn.Close() client.Lock() delete(client.conns, conn) client.Unlock() agent.OnClose() if client.AutoReconnect { time.Sleep(client.ConnectInterval) goto reconnect } } func (client *WSClient) Close() { client.Lock() client.closeFlag = true for conn := range client.conns { conn.Close() } client.conns = nil client.Unlock() client.wg.Wait() } func (client *WSClient) Init() { client.init() } func (client *WSClient) StartNewClient(addr string, data interface{}) { client.wg.Add(1) go client.connect(addr, data) }