ws_server.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package network
import (
	"crypto/tls"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"leafstalk/log"

	"github.com/gorilla/websocket"
)
  11. const (
  12. XForwardedFor = "X-Forwarded-For"
  13. XRealIP = "X-Real-IP"
  14. )
  15. type WSServer struct {
  16. Addr string
  17. MaxConnNum int
  18. PendingWriteNum int
  19. MaxMsgLen uint32
  20. HTTPTimeout time.Duration
  21. CertFile string
  22. KeyFile string
  23. NewAgent func(*WSConn) Agent
  24. ln net.Listener
  25. handler *WSHandler
  26. }
  27. type WSHandler struct {
  28. maxConnNum int
  29. pendingWriteNum int
  30. maxMsgLen uint32
  31. newAgent func(*WSConn) Agent
  32. upgrader websocket.Upgrader
  33. conns WebsocketConnSet
  34. mutexConns sync.Mutex
  35. wg sync.WaitGroup
  36. }
  37. func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  38. if r.Method != "GET" {
  39. http.Error(w, "Method not allowed", 405)
  40. return
  41. }
  42. conn, err := handler.upgrader.Upgrade(w, r, nil)
  43. if err != nil {
  44. log.Debugf("upgrade error: %v", err)
  45. return
  46. }
  47. conn.SetReadLimit(int64(handler.maxMsgLen))
  48. handler.wg.Add(1)
  49. defer handler.wg.Done()
  50. handler.mutexConns.Lock()
  51. if handler.conns == nil {
  52. handler.mutexConns.Unlock()
  53. conn.Close()
  54. return
  55. }
  56. if len(handler.conns) >= handler.maxConnNum {
  57. handler.mutexConns.Unlock()
  58. conn.Close()
  59. log.Debug("too many connections")
  60. return
  61. }
  62. handler.conns[conn] = struct{}{}
  63. handler.mutexConns.Unlock()
  64. ip := RemoteIp(r)
  65. wsConn := newWSConn(conn, handler.pendingWriteNum, handler.maxMsgLen)
  66. wsConn.ip = ip
  67. agent := handler.newAgent(wsConn)
  68. agent.Run()
  69. // cleanup
  70. wsConn.Close()
  71. handler.mutexConns.Lock()
  72. delete(handler.conns, conn)
  73. handler.mutexConns.Unlock()
  74. agent.OnClose()
  75. }
  76. func (server *WSServer) Start() {
  77. ln, err := net.Listen("tcp", server.Addr)
  78. if err != nil {
  79. log.Fatalf("%v", err)
  80. }
  81. if server.MaxConnNum <= 0 {
  82. server.MaxConnNum = 100
  83. log.Infof("invalid MaxConnNum, reset to %v", server.MaxConnNum)
  84. }
  85. if server.PendingWriteNum <= 0 {
  86. server.PendingWriteNum = 100
  87. log.Infof("invalid PendingWriteNum, reset to %v", server.PendingWriteNum)
  88. }
  89. if server.MaxMsgLen <= 0 {
  90. server.MaxMsgLen = 4096
  91. log.Infof("invalid MaxMsgLen, reset to %v", server.MaxMsgLen)
  92. }
  93. if server.HTTPTimeout <= 0 {
  94. server.HTTPTimeout = 10 * time.Second
  95. log.Infof("invalid HTTPTimeout, reset to %v", server.HTTPTimeout)
  96. }
  97. if server.NewAgent == nil {
  98. log.Fatal("NewAgent must not be nil")
  99. }
  100. if server.CertFile != "" || server.KeyFile != "" {
  101. config := &tls.Config{}
  102. config.NextProtos = []string{"http/1.1"}
  103. var err error
  104. config.Certificates = make([]tls.Certificate, 1)
  105. config.Certificates[0], err = tls.LoadX509KeyPair(server.CertFile, server.KeyFile)
  106. if err != nil {
  107. log.Fatalf("%v", err)
  108. }
  109. ln = tls.NewListener(ln, config)
  110. }
  111. server.ln = ln
  112. server.handler = &WSHandler{
  113. maxConnNum: server.MaxConnNum,
  114. pendingWriteNum: server.PendingWriteNum,
  115. maxMsgLen: server.MaxMsgLen,
  116. newAgent: server.NewAgent,
  117. conns: make(WebsocketConnSet),
  118. upgrader: websocket.Upgrader{
  119. HandshakeTimeout: server.HTTPTimeout,
  120. CheckOrigin: func(_ *http.Request) bool { return true },
  121. },
  122. }
  123. httpServer := &http.Server{
  124. Addr: server.Addr,
  125. Handler: server.handler,
  126. ReadTimeout: server.HTTPTimeout,
  127. WriteTimeout: server.HTTPTimeout,
  128. MaxHeaderBytes: 1024,
  129. }
  130. go httpServer.Serve(ln)
  131. }
  132. func (server *WSServer) Close() {
  133. server.ln.Close()
  134. server.handler.mutexConns.Lock()
  135. for conn := range server.handler.conns {
  136. conn.Close()
  137. }
  138. server.handler.conns = nil
  139. server.handler.mutexConns.Unlock()
  140. server.handler.wg.Wait()
  141. }
  142. func RemoteIp(req *http.Request) string {
  143. remoteAddr := req.RemoteAddr
  144. if ip := req.Header.Get(XRealIP); ip != "" {
  145. remoteAddr = ip
  146. } else if ip = req.Header.Get(XForwardedFor); ip != "" {
  147. remoteAddr = ip
  148. } else {
  149. remoteAddr, _, _ = net.SplitHostPort(remoteAddr)
  150. }
  151. if remoteAddr == "::1" {
  152. remoteAddr = "127.0.0.1"
  153. }
  154. return remoteAddr
  155. }