ws_client.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package network
  2. import (
  3. "crypto/tls"
  4. "sync"
  5. "time"
  6. "leafstalk/log"
  7. "github.com/gorilla/websocket"
  8. )
  // WSClient opens one or more WebSocket connections and runs an Agent
// (created by NewAgent) on each. Configure the exported fields, then call
// Start (or Init followed by StartNewClient), and Close to shut down.
type WSClient struct {
	// Guards conns and closeFlag; held by init, connect and Close when they
	// touch those fields. Embedded, so Lock/Unlock are part of WSClient's API.
	sync.Mutex
	// Addr is the server URL that Start dials for every connection.
	Addr string
	// ConnNum is how many connections Start opens; init resets values <= 0 to 1.
	ConnNum int
	// ConnectInterval is the wait between failed dial attempts and before an
	// automatic reconnect; init resets values <= 0 to 3s.
	ConnectInterval time.Duration
	// PendingWriteNum is passed to newWSConn (presumably the outgoing message
	// queue size — confirm); init resets values <= 0 to 100.
	PendingWriteNum int
	// MaxMsgLen is used as the connection read limit and passed to newWSConn;
	// init resets 0 to 4096.
	MaxMsgLen uint32
	// HandshakeTimeout is the dialer's handshake timeout; init resets values
	// <= 0 to 10s.
	HandshakeTimeout time.Duration
	// AutoReconnect makes connect redial after an agent's Run returns.
	AutoReconnect bool
	// NewAgent builds the agent for each connection. The second argument is the
	// data given to StartNewClient (nil for connections opened by Start).
	// Required: init aborts via log.Fatal when it is nil.
	NewAgent func(*WSConn, interface{}) Agent
	// dialer is built by init from HandshakeTimeout.
	dialer websocket.Dialer
	// conns holds the currently registered connections; created by init and
	// reset to nil by Close.
	conns WebsocketConnSet
	// wg counts running connect goroutines; Close waits on it.
	wg sync.WaitGroup
	// closeFlag is set by Close and cleared by init.
	closeFlag bool
}
  24. func (client *WSClient) Start() {
  25. client.init()
  26. for i := 0; i < client.ConnNum; i++ {
  27. client.wg.Add(1)
  28. go client.connect(client.Addr, nil)
  29. }
  30. }
  31. func (client *WSClient) init() {
  32. client.Lock()
  33. defer client.Unlock()
  34. if client.ConnNum <= 0 {
  35. client.ConnNum = 1
  36. log.Infof("invalid ConnNum, reset to %v", client.ConnNum)
  37. }
  38. if client.ConnectInterval <= 0 {
  39. client.ConnectInterval = 3 * time.Second
  40. log.Infof("invalid ConnectInterval, reset to %v", client.ConnectInterval)
  41. }
  42. if client.PendingWriteNum <= 0 {
  43. client.PendingWriteNum = 100
  44. log.Infof("invalid PendingWriteNum, reset to %v", client.PendingWriteNum)
  45. }
  46. if client.MaxMsgLen <= 0 {
  47. client.MaxMsgLen = 4096
  48. log.Infof("invalid MaxMsgLen, reset to %v", client.MaxMsgLen)
  49. }
  50. if client.HandshakeTimeout <= 0 {
  51. client.HandshakeTimeout = 10 * time.Second
  52. log.Infof("invalid HandshakeTimeout, reset to %v", client.HandshakeTimeout)
  53. }
  54. if client.NewAgent == nil {
  55. log.Fatal("NewAgent must not be nil")
  56. }
  57. if client.conns != nil {
  58. log.Fatal("client is running")
  59. }
  60. client.conns = make(WebsocketConnSet)
  61. client.closeFlag = false
  62. client.dialer = websocket.Dialer{
  63. HandshakeTimeout: client.HandshakeTimeout,
  64. TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
  65. }
  66. }
  67. func (client *WSClient) dial(addr string) *websocket.Conn {
  68. for {
  69. conn, _, err := client.dialer.Dial(addr, nil)
  70. if err == nil || client.closeFlag {
  71. return conn
  72. }
  73. log.Infof("connect to %v error: %v", addr, err)
  74. time.Sleep(client.ConnectInterval)
  75. continue
  76. }
  77. }
  78. func (client *WSClient) connect(addr string, data interface{}) {
  79. defer client.wg.Done()
  80. reconnect:
  81. conn := client.dial(addr)
  82. if conn == nil {
  83. return
  84. }
  85. conn.SetReadLimit(int64(client.MaxMsgLen))
  86. client.Lock()
  87. if client.closeFlag {
  88. client.Unlock()
  89. conn.Close()
  90. return
  91. }
  92. client.conns[conn] = struct{}{}
  93. client.Unlock()
  94. wsConn := newWSConn(conn, client.PendingWriteNum, client.MaxMsgLen)
  95. agent := client.NewAgent(wsConn, data)
  96. agent.Run()
  97. // cleanup
  98. wsConn.Close()
  99. client.Lock()
  100. delete(client.conns, conn)
  101. client.Unlock()
  102. agent.OnClose()
  103. if client.AutoReconnect {
  104. time.Sleep(client.ConnectInterval)
  105. goto reconnect
  106. }
  107. }
  108. func (client *WSClient) Close() {
  109. client.Lock()
  110. client.closeFlag = true
  111. for conn := range client.conns {
  112. conn.Close()
  113. }
  114. client.conns = nil
  115. client.Unlock()
  116. client.wg.Wait()
  117. }
  118. func (client *WSClient) Init() {
  119. client.init()
  120. }
  121. func (client *WSClient) StartNewClient(addr string, data interface{}) {
  122. client.wg.Add(1)
  123. go client.connect(addr, data)
  124. }