123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- package network
import (
	"crypto/tls"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"leafstalk/log"
)
// Request header names consulted by RemoteIp, listed in the order
// RemoteIp checks them.
const (
	// XRealIP is the header name RemoteIp checks first.
	XRealIP = "X-Real-IP"
	// XForwardedFor is the header name RemoteIp checks when XRealIP is absent.
	XForwardedFor = "X-Forwarded-For"
)
- type WSServer struct {
- Addr string
- MaxConnNum int
- PendingWriteNum int
- MaxMsgLen uint32
- HTTPTimeout time.Duration
- CertFile string
- KeyFile string
- NewAgent func(*WSConn) Agent
- ln net.Listener
- handler *WSHandler
- }
- type WSHandler struct {
- maxConnNum int
- pendingWriteNum int
- maxMsgLen uint32
- newAgent func(*WSConn) Agent
- upgrader websocket.Upgrader
- conns WebsocketConnSet
- mutexConns sync.Mutex
- wg sync.WaitGroup
- }
// ServeHTTP upgrades a GET request to a WebSocket connection, registers it,
// and runs an Agent on it until Agent.Run returns. After the upgrade, the
// connection is closed immediately if the server is closing (conns is nil)
// or if maxConnNum connections are already registered.
func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	conn, err := handler.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already written an HTTP error response to the client.
		log.Debugf("upgrade error: %v", err)
		return
	}
	conn.SetReadLimit(int64(handler.maxMsgLen))

	handler.mutexConns.Lock()
	if handler.conns == nil {
		// WSServer.Close has run: refuse the connection.
		handler.mutexConns.Unlock()
		conn.Close()
		return
	}
	if len(handler.conns) >= handler.maxConnNum {
		handler.mutexConns.Unlock()
		conn.Close()
		log.Debug("too many connections")
		return
	}
	handler.conns[conn] = struct{}{}
	// Register with the WaitGroup while still holding the lock, and only
	// after confirming conns is non-nil. WSServer.Close nils conns under
	// this same lock before it calls wg.Wait, so every Add here happens
	// before that Wait, as sync.WaitGroup requires. Calling Add outside the
	// lock could race with a Wait that has already observed a zero counter.
	handler.wg.Add(1)
	handler.mutexConns.Unlock()
	defer handler.wg.Done()

	wsConn := newWSConn(conn, handler.pendingWriteNum, handler.maxMsgLen)
	wsConn.ip = RemoteIp(r)
	agent := handler.newAgent(wsConn)
	agent.Run()

	// Cleanup: close the connection, drop it from the set (a no-op if Close
	// already replaced the set with nil), then notify the agent.
	wsConn.Close()
	handler.mutexConns.Lock()
	delete(handler.conns, conn)
	handler.mutexConns.Unlock()
	agent.OnClose()
}
// Start validates the configuration, replacing non-positive limits with
// defaults. It then listens on Addr, wrapping the listener in TLS when
// CertFile or KeyFile is set, and serves WebSocket connections on a
// background goroutine. A missing NewAgent, an unreadable key pair, or a
// listen failure is reported through log.Fatal/log.Fatalf.
func (server *WSServer) Start() {
	// Validate everything before binding the port, so a bad configuration
	// never leaves a listener behind.
	if server.MaxConnNum <= 0 {
		server.MaxConnNum = 100
		log.Infof("invalid MaxConnNum, reset to %v", server.MaxConnNum)
	}
	if server.PendingWriteNum <= 0 {
		server.PendingWriteNum = 100
		log.Infof("invalid PendingWriteNum, reset to %v", server.PendingWriteNum)
	}
	// MaxMsgLen is unsigned, so zero is the only invalid value.
	if server.MaxMsgLen == 0 {
		server.MaxMsgLen = 4096
		log.Infof("invalid MaxMsgLen, reset to %v", server.MaxMsgLen)
	}
	if server.HTTPTimeout <= 0 {
		server.HTTPTimeout = 10 * time.Second
		log.Infof("invalid HTTPTimeout, reset to %v", server.HTTPTimeout)
	}
	if server.NewAgent == nil {
		log.Fatal("NewAgent must not be nil")
	}

	var tlsConfig *tls.Config
	if server.CertFile != "" || server.KeyFile != "" {
		cert, err := tls.LoadX509KeyPair(server.CertFile, server.KeyFile)
		if err != nil {
			log.Fatalf("%v", err)
		}
		tlsConfig = &tls.Config{
			NextProtos:   []string{"http/1.1"},
			Certificates: []tls.Certificate{cert},
		}
	}

	ln, err := net.Listen("tcp", server.Addr)
	if err != nil {
		log.Fatalf("%v", err)
	}
	if tlsConfig != nil {
		ln = tls.NewListener(ln, tlsConfig)
	}

	server.ln = ln
	server.handler = &WSHandler{
		maxConnNum:      server.MaxConnNum,
		pendingWriteNum: server.PendingWriteNum,
		maxMsgLen:       server.MaxMsgLen,
		newAgent:        server.NewAgent,
		conns:           make(WebsocketConnSet),
		upgrader: websocket.Upgrader{
			HandshakeTimeout: server.HTTPTimeout,
			// Accept upgrades from any Origin.
			CheckOrigin: func(_ *http.Request) bool { return true },
		},
	}

	httpServer := &http.Server{
		Addr:           server.Addr,
		Handler:        server.handler,
		ReadTimeout:    server.HTTPTimeout,
		WriteTimeout:   server.HTTPTimeout,
		MaxHeaderBytes: 1024,
	}
	// Serve returns (with a non-nil error) once Close closes ln; that
	// error is deliberately ignored.
	go httpServer.Serve(ln)
}
- func (server *WSServer) Close() {
- server.ln.Close()
- server.handler.mutexConns.Lock()
- for conn := range server.handler.conns {
- conn.Close()
- }
- server.handler.conns = nil
- server.handler.mutexConns.Unlock()
- server.handler.wg.Wait()
- }
// RemoteIp returns the client IP address for req.
//
// It prefers the X-Real-IP header, then the first (left-most) entry of the
// comma-separated X-Forwarded-For chain, and finally the host part of
// req.RemoteAddr. If RemoteAddr has no port, it is returned as-is. The IPv6
// loopback "::1" is reported as "127.0.0.1".
//
// NOTE: both headers are set by the client unless a trusted reverse proxy
// overwrites them, so do not use the result for security decisions when
// the server is reachable directly.
func RemoteIp(req *http.Request) string {
	ip := strings.TrimSpace(req.Header.Get(XRealIP))
	if ip == "" {
		// X-Forwarded-For has the form "client, proxy1, proxy2"; only the
		// left-most entry names the originating client.
		xff := req.Header.Get(XForwardedFor)
		if i := strings.IndexByte(xff, ','); i >= 0 {
			xff = xff[:i]
		}
		ip = strings.TrimSpace(xff)
	}
	if ip == "" {
		host, _, err := net.SplitHostPort(req.RemoteAddr)
		if err != nil {
			// RemoteAddr without a port: keep it rather than returning "".
			host = req.RemoteAddr
		}
		ip = host
	}
	if ip == "::1" {
		ip = "127.0.0.1"
	}
	return ip
}
|