Optimized WebSocket Gateway For Micro-Services In Golang

Using Epoll Linux Kernel System Call

WebSocket is a communications protocol that provides full-duplex communication channels over a single TCP connection. WebSocket is distinct from HTTP. Both protocols sit at layer 7 of the OSI model and depend on TCP at layer 4. The only connection between the two is that WebSocket is designed to work over HTTP ports 443 and 80, which makes it compatible with existing HTTP infrastructure. To achieve this compatibility, the WebSocket handshake uses the HTTP Upgrade header to switch from the HTTP protocol to the WebSocket protocol.
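To make the Upgrade handshake a little more concrete, here is a minimal sketch of the one computation the server performs during it, as defined in RFC 6455: the Sec-WebSocket-Accept response header is derived from the Sec-WebSocket-Key sent by the client. The helper name acceptKey is our own choice for illustration; in this tutorial the gorilla library takes care of this step for us.

```go
package main

import (
	"crypto/sha1"
	"encoding/base64"
	"fmt"
)

// websocketGUID is the fixed GUID that RFC 6455 defines for the handshake.
const websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// acceptKey (hypothetical helper name) computes the Sec-WebSocket-Accept
// header value for the Sec-WebSocket-Key sent by the client.
func acceptKey(clientKey string) string {
	sum := sha1.Sum([]byte(clientKey + websocketGUID))
	return base64.StdEncoding.EncodeToString(sum[:])
}

func main() {
	// Sample key taken from the handshake example in RFC 6455.
	fmt.Println(acceptKey("dGhlIHNhbXBsZSBub25jZQ=="))
}
```

For the sample key from the RFC, the server is expected to answer with s3pPLMBiTxaQ9kYGzzhZRbK+xOo=.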

When we want to handle a very large number of WebSocket connections on a single node, the WebSocket layer should be used only for data delivery: it takes data from the user and passes it on to other services. It should not process the data itself; it just accepts the data and hands it to the services sitting behind it. In a micro-service architecture, WebSocket should be used only as a gateway. Data enters the system through the WebSocket gateway and is then processed by the responsible services.
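The "deliver, don't process" idea can be sketched roughly as follows. This is only an illustration under our own assumptions: the Gateway type, the Forward method and the backend channel are hypothetical names, and the channel merely stands in for whatever transport (gRPC, a message queue, and so on) connects the gateway to the services behind it.

```go
package main

import (
	"errors"
	"fmt"
)

// Message is an opaque payload; the gateway never inspects Data.
type Message struct {
	ConnID int
	Data   []byte
}

// ErrBackendBusy is returned when the backend queue cannot take more messages.
var ErrBackendBusy = errors.New("backend queue is full")

// Gateway (hypothetical) only hands messages over to a backend queue.
type Gateway struct {
	backend chan Message
}

func NewGateway(queueSize int) *Gateway {
	return &Gateway{backend: make(chan Message, queueSize)}
}

// Forward enqueues the payload without blocking the caller's read loop.
func (g *Gateway) Forward(connID int, data []byte) error {
	select {
	case g.backend <- Message{ConnID: connID, Data: data}:
		return nil
	default:
		return ErrBackendBusy
	}
}

func main() {
	g := NewGateway(1)
	fmt.Println(g.Forward(7, []byte("hello"))) // accepted by the queue
	fmt.Println(g.Forward(7, []byte("again"))) // rejected, queue is full

	msg := <-g.backend // a backend worker would consume messages like this
	fmt.Println(msg.ConnID, string(msg.Data))
}
```

Returning an error instead of blocking when the queue is full is one possible design choice; it keeps a slow backend from stalling the loop that reads from the clients.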

Scaling WebSocket is a tough task for backend developers, because we have to deal with thousands of persistent TCP connections. Keeping each connection open over a long period consumes a lot of memory, especially when every connection gets its own goroutine and buffers. So we need an approach built on an event notification (polling) mechanism, and epoll suits this case best. Epoll is a Linux kernel facility for scalable I/O event notification. It consists of a set of system calls (epoll_create, epoll_ctl and epoll_wait), each taking a file descriptor argument that denotes the configurable kernel object on which they cooperatively operate. Epoll uses a red-black tree to keep track of all file descriptors that are currently being monitored.

This method of scaling WebSocket works only on Linux, because epoll is a Linux-specific API.

As mentioned above, we will use the Go programming language (Golang) for this tutorial; it is well suited to this kind of application. Now let's start the coding part. I'm assuming that the reader has a good understanding of Go, because I'm not going to explain language-specific things in this tutorial.

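First, create a Go project using go mod init <PROJECT_NAME> and create a main.go file. We will use the gorilla library to handle WebSocket connections. Install it using go get github.com/gorilla/websocket. The code also imports golang.org/x/sys/unix, which can be installed using go get golang.org/x/sys/unix.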

Let's write the package name and the required imports, and create two structs. One is MAIN, which connects all the parts of the code. The other is Epoll, which stores information about the epoll instance and the client connections. Also create a global variable of type *Epoll.

package main

import (
    "fmt"
    "log"
    "net/http"
    "reflect"
    "sync"
    "syscall"

    "github.com/gorilla/websocket"
    "golang.org/x/sys/unix"
)

type MAIN struct {
    Lock          *sync.RWMutex
    Name          string
    EpollInstance *Epoll
}

type Epoll struct {
    fd int
    // connections maps a socket file descriptor to its connection. Pointers are
    // stored so that a Conn, which holds internal read/write state, is never copied.
    connections map[int]*websocket.Conn
    lock        *sync.RWMutex
}

var epoller *Epoll

Let's write the Epoll controller functions. The first function is MkEpoll; it creates the epoll instance and returns an instance of the Epoll struct that we defined in the code snippet above. The remaining functions are methods with receiver MAIN. The second function is Add, which registers a WebSocket connection to be monitored by epoll. It takes conn *websocket.Conn as its argument. The third function is Wait, which returns the WebSocket connections that have data ready to be read by the server. Its return value is a slice of *websocket.Conn. The last function is Remove; as the name suggests, it removes a WebSocket connection from epoll.

func MkEpoll() (*Epoll, error) {
    fd, err := unix.EpollCreate1(0)
    if err != nil {
        return nil, err
    }
    return &Epoll{
        fd:          fd,
        lock:        &sync.RWMutex{},
        connections: make(map[int]*websocket.Conn),
    }, nil
}

func (m *MAIN) Add(conn *websocket.Conn) error {
    // Extract file descriptor associated with the connection
    fd := m.websocketFD(conn)
    // Ask epoll to report when the socket is readable or the peer hangs up
    err := unix.EpollCtl(m.EpollInstance.fd, unix.EPOLL_CTL_ADD, fd, &unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLHUP, Fd: int32(fd)})
    if err != nil {
        return err
    }
    m.EpollInstance.lock.Lock()
    defer m.EpollInstance.lock.Unlock()
    m.EpollInstance.connections[fd] = conn
    return nil
}

func (m *MAIN) Wait() ([]*websocket.Conn, error) {
    events := make([]unix.EpollEvent, 100)
    // The last argument is the timeout in milliseconds
    n, err := unix.EpollWait(m.EpollInstance.fd, events, 100)
    if err != nil {
        if err == unix.EINTR {
            // Interrupted by a signal: not a real failure, the caller just waits again
            return nil, nil
        }
        return nil, err
    }
    m.EpollInstance.lock.RLock()
    defer m.EpollInstance.lock.RUnlock()
    connections := make([]*websocket.Conn, 0, n)
    for i := 0; i < n; i++ {
        // Skip descriptors that were removed from the map in the meantime
        if conn, ok := m.EpollInstance.connections[int(events[i].Fd)]; ok {
            connections = append(connections, conn)
        }
    }
    return connections, nil
}

func (m *MAIN) Remove(conn websocket.Conn) error {
    fd := m.websocketFD(&conn)
    err := unix.EpollCtl(m.EpollInstance.fd, syscall.EPOLL_CTL_DEL, fd, nil)
    if err != nil {
        return err
    }
    m.EpollInstance.lock.Lock()
    defer m.EpollInstance.lock.Unlock()
    delete(m.EpollInstance.connections, fd)
    if len(m.EpollInstance.connections)%100 == 0 {
        log.Printf("Total number of connections: %v", len(m.EpollInstance.connections))
    }
    return nil
}

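We also have to write a function that returns the file descriptor (fd) when a *websocket.Conn is provided as its argument. It uses reflection to reach unexported fields, so it depends on the internals of gorilla/websocket and of the net package, and it assumes a plain (non-TLS) TCP connection.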

func (m *MAIN) websocketFD(conn *websocket.Conn) int {
    connVal := reflect.Indirect(reflect.ValueOf(conn)).FieldByName("conn").Elem()
    tcpConn := reflect.Indirect(connVal).FieldByName("conn")
    fdVal := tcpConn.FieldByName("fd")
    pfdVal := reflect.Indirect(fdVal).FieldByName("pfd")
    return int(pfdVal.FieldByName("Sysfd").Int())
}

The next function upgrades the HTTP request to a WebSocket connection and registers it with epoll.

func (m *MAIN) wsHandler(w http.ResponseWriter, r *http.Request) {
    // Upgrade connection
    upgrader := websocket.Upgrader{}
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        return
    }
    if err := m.Add(conn); err != nil {
        log.Println("Failed to add connection: ", err)
        conn.Close()
    }
}

Last but not least, we have to write a function that runs inside a goroutine and handles the incoming data sent by the clients.

func (m *MAIN) Start() {
    var count int = 0
    for {
        connections, err := m.Wait()
        if err != nil {
            log.Println("Error while epollWait: ", err)
            continue
        }

        for _, conn := range connections {
            if conn == nil {
                continue
            }
            // first '_' is the message type and second '_' is the actual data
            // so if you want to print the incoming data, then initialize the second
            // '_' with variable and then print it.
            if _, _, err := conn.ReadMessage(); err != nil {
                // Stop monitoring the connection, then close it so that its
                // file descriptor is released.
                if rmErr := m.Remove(*conn); rmErr != nil {
                    log.Println("Failed to remove connection: ", rmErr)
                }
                conn.Close()
                fmt.Println("[ERROR] -->> ", err.Error())
            } else {
                // fmt.Println("Count: ", count)
                count++
            }
        }
    }
}

If you want to write back to the client, create a function that receives data from other microservices (using gRPC, message queues such as RabbitMQ, etc.) and a method for finding out to whom the data needs to be written. Then look up the client's websocket.Conn object by its file descriptor in the connection map, i.e. m.EpollInstance.connections[<fd>].

Now, let's create the main function and write the code needed to run everything described above. In the main function, we first raise the limit on open file descriptors to the maximum, then declare a variable of type MAIN. Then we call the MkEpoll() function, which initializes epoll and returns an instance of the Epoll struct. We assign the return value of MkEpoll to m.EpollInstance and then call go m.Start(). Finally, we register the WebSocket handler on the home route and start listening.

If you want to handle a lot of WebSocket connections, you have to increase the upper limit on the number of open file descriptors allowed.

func main() {
    fmt.Println("Websocket Optimization Test")

    // Increase resources limitations. This is done synchronously so that the
    // higher limit is in place before the server starts accepting connections.
    var rLimit syscall.Rlimit
    if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
        panic(err)
    }
    rLimit.Cur = rLimit.Max
    if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
        panic(err)
    }

    var m MAIN

    // Start epoll
    var err error
    epoller, err = MkEpoll()
    if err != nil {
        panic(err)
    }
    m.EpollInstance = epoller
    go m.Start()

    // starting websocket handler
    http.HandleFunc("/", m.wsHandler)
    if err := http.ListenAndServe(":8000", nil); err != nil {
        log.Fatal(err)
    }
}

Complete Code:

package main

import (
    "fmt"
    "log"
    "net/http"
    // Registers the /debug/pprof handlers served on localhost:6060 below
    _ "net/http/pprof"
    "reflect"
    "sync"
    "syscall"

    "github.com/gorilla/websocket"
    "golang.org/x/sys/unix"
)

type MAIN struct {
    Lock          *sync.RWMutex
    Name          string
    EpollInstance *Epoll
}

type Epoll struct {
    fd int
    // connections maps a socket file descriptor to its connection. Pointers are
    // stored so that a Conn, which holds internal read/write state, is never copied.
    connections map[int]*websocket.Conn
    lock        *sync.RWMutex
}

var (
    epoller *Epoll
)

func MkEpoll() (*Epoll, error) {
    fd, err := unix.EpollCreate1(0)
    if err != nil {
        return nil, err
    }
    return &Epoll{
        fd:          fd,
        lock:        &sync.RWMutex{},
        connections: make(map[int]*websocket.Conn),
    }, nil
}

func (m *MAIN) websocketFD(conn *websocket.Conn) int {
    connVal := reflect.Indirect(reflect.ValueOf(conn)).FieldByName("conn").Elem()
    tcpConn := reflect.Indirect(connVal).FieldByName("conn")
    fdVal := tcpConn.FieldByName("fd")
    pfdVal := reflect.Indirect(fdVal).FieldByName("pfd")
    return int(pfdVal.FieldByName("Sysfd").Int())
}

func (m *MAIN) Add(conn *websocket.Conn) error {
    // Extract file descriptor associated with the connection
    fd := m.websocketFD(conn)
    // Ask epoll to report when the socket is readable or the peer hangs up
    err := unix.EpollCtl(m.EpollInstance.fd, unix.EPOLL_CTL_ADD, fd, &unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLHUP, Fd: int32(fd)})
    if err != nil {
        return err
    }
    m.EpollInstance.lock.Lock()
    defer m.EpollInstance.lock.Unlock()
    m.EpollInstance.connections[fd] = conn
    return nil
}

func (m *MAIN) Wait() ([]*websocket.Conn, error) {
    events := make([]unix.EpollEvent, 100)
    // The last argument is the timeout in milliseconds
    n, err := unix.EpollWait(m.EpollInstance.fd, events, 100)
    if err != nil {
        if err == unix.EINTR {
            // Interrupted by a signal: not a real failure, the caller just waits again
            return nil, nil
        }
        return nil, err
    }
    m.EpollInstance.lock.RLock()
    defer m.EpollInstance.lock.RUnlock()
    connections := make([]*websocket.Conn, 0, n)
    for i := 0; i < n; i++ {
        // Skip descriptors that were removed from the map in the meantime
        if conn, ok := m.EpollInstance.connections[int(events[i].Fd)]; ok {
            connections = append(connections, conn)
        }
    }
    return connections, nil
}

func (m *MAIN) wsHandler(w http.ResponseWriter, r *http.Request) {
    // Upgrade connection
    upgrader := websocket.Upgrader{}
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        return
    }
    if err := m.Add(conn); err != nil {
        log.Println("Failed to add connection: ", err)
        conn.Close()
    }
}

func (m *MAIN) Remove(conn websocket.Conn) error {
    fd := m.websocketFD(&conn)
    err := unix.EpollCtl(m.EpollInstance.fd, syscall.EPOLL_CTL_DEL, fd, nil)
    if err != nil {
        return err
    }
    m.EpollInstance.lock.Lock()
    defer m.EpollInstance.lock.Unlock()
    delete(m.EpollInstance.connections, fd)
    if len(m.EpollInstance.connections)%100 == 0 {
        log.Printf("Total number of connections: %v", len(m.EpollInstance.connections))
    }
    return nil
}

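func (m *MAIN) Start() {
    var count int = 0
    for {
        connections, err := m.Wait()
        if err != nil {
            log.Println("Error while epollWait: ", err)
            continue
        }

        for _, conn := range connections {
            if conn == nil {
                continue
            }
            if _, _, err := conn.ReadMessage(); err != nil {
                // Stop monitoring the connection, then close it so that its
                // file descriptor is released.
                if rmErr := m.Remove(*conn); rmErr != nil {
                    log.Println("Failed to remove connection: ", rmErr)
                }
                conn.Close()
                fmt.Println("[ERROR] -->> ", err.Error())
            } else {
                // fmt.Println("Count: ", count)
                count++
            }
        }
    }
}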

func main() {
    fmt.Println("Websocket Optimization Test")

    // Increase resources limitations. This is done synchronously so that the
    // higher limit is in place before the server starts accepting connections.
    var rLimit syscall.Rlimit
    if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
        panic(err)
    }
    rLimit.Cur = rLimit.Max
    if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
        panic(err)
    }

    // Enable pprof hooks
    go func() {
        if err := http.ListenAndServe("localhost:6060", nil); err != nil {
            log.Fatalf("pprof failed: %v", err)
        }
    }()

    var m MAIN

    // Start epoll
    var err error
    epoller, err = MkEpoll()
    if err != nil {
        panic(err)
    }
    m.EpollInstance = epoller
    go m.Start()

    // starting websocket handler
    http.HandleFunc("/", m.wsHandler)
    if err := http.ListenAndServe(":8000", nil); err != nil {
        log.Fatal(err)
    }
}

Thank you for reading to the end :)

Feel free to point out any mistakes.