Rate Limiting using Redis + Golang (Token Bucket algorithm)

What is rate limiting?

Imagine someone making a request to an API. Now, let's say there are many requests coming in from different people or applications, and the server that handles these requests gets overwhelmed. It can't handle all the requests at once, which can lead to poor performance, crashes, or even security issues. To prevent this, rate limits come into play. Rate limits are like traffic rules for API requests. They help the server maintain a balanced flow of requests and ensure fair usage for everyone. Just as traffic lights regulate the number of cars passing through an intersection, rate limits control the number of API requests allowed within a certain time period.

What will happen if we don't use rate limiting?

If rate limiting is not implemented, several issues can arise, leading to undesirable consequences. Here are some potential outcomes of not using rate limiting:

  1. Server Overload: Rate limiting prevents server overload by limiting the number of requests that can be processed within a certain time. Without rate limits, a sudden surge of requests can overwhelm the server, causing it to slow down or crash. Rate limits ensure that the server can handle requests in a controlled and efficient manner.

  2. Denial of Service (DoS) Attacks: Rate limiting acts as a defense against Denial of Service attacks. These attacks aim to flood a server with a massive number of requests, rendering it unavailable to legitimate users. By implementing rate limits, excessive requests can be detected and blocked, mitigating the impact of such attacks and safeguarding the server's availability.

  3. API Abuse and Misuse: Without rate limits, malicious users or automated bots can abuse the API by sending a large number of requests in a short period. This can result in data scraping, unauthorized access attempts, or other forms of abuse. API providers may incur higher operational costs, experience security vulnerabilities, or even face legal consequences if sensitive data is compromised.

What is Token Bucket Algorithm?

There are many algorithms for rate limiting. For this blog, I chose the Token Bucket algorithm because it handles burst traffic well.

Imagine you have a bucket that can hold a certain number of tokens. Each token represents permission to perform an action, such as sending a request. New tokens are added to the bucket at a fixed rate, and that rate controls how frequently someone can perform actions. To perform an action, they must take a token from the bucket. If a token is available, they can proceed; if the bucket is empty, they have to wait until new tokens are added. For example, say the bucket is refilled with one token every second and can hold at most 10 tokens. Starting from a full bucket, someone can perform the action up to 10 times in a row (a burst). If they try more than 10 times in quick succession, the extra attempts must wait for the bucket to refill at one token per second. The token bucket algorithm thus controls the rate at which actions can be performed: it allows short bursts while preventing an excessive number of actions from happening too quickly. This helps maintain fairness, prevent overloading a system, and protect against abuse.

Let's implement this using Redis + Golang

In this blog, we will implement middleware for our API routes that rate-limits incoming traffic, using Redis and the Gin web framework in Go.

Setup Redis

To ensure an efficient and reproducible Redis setup, I am utilizing a Docker image through Docker Compose instead of manual configuration on my local machine. This approach simplifies the setup process and allows for consistent Redis deployments across different environments.

version: "3.9"
services:
  redis:
    image: "redis:alpine"
    ports:
      - "6379:6379"

Creating the middleware

Now, let's dive into the Go code and get our hands dirty. First, we need an identifier to track and enforce rate limits on a per-client basis. We use the client's IP address, taken from the X-Real-Ip or X-Forwarded-For request header and falling back to the connection's remote address. Next, we execute the rate limiting logic as a Redis script written in a file called "script.lua" (why a Redis script? That is explained at the end of this post). We pass the script the bucket capacity (the maximum number of tokens it can hold), the refill rate (tokens added per second), the current time, and the number of tokens the request needs (1). The script returns either 0 or 1. On 1, the request proceeds; otherwise, the middleware responds with HTTP status 429 (Too Many Requests) and a JSON message indicating the request overflowed.

Here's the code:

package main

import (
    "log"
    "net"
    "net/http"
    "os"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/go-redis/redis/v8"
)

const (
    redisAddress = "localhost:6379"
    rate         = 10 // tokens added per second
    capacity     = 10 // burst of up to 10 requests
)

type RedisClient struct {
    client     *redis.Client
    takeScript *redis.Script
}

// GetRedisClient connects to Redis and loads script.lua once at startup,
// instead of reading the file on every request.
func GetRedisClient() (*RedisClient, error) {
    script, err := os.ReadFile("script.lua")
    if err != nil {
        return nil, err
    }

    return &RedisClient{
        client: redis.NewClient(&redis.Options{
            Addr: redisAddress,
            DB:   0,
        }),
        takeScript: redis.NewScript(string(script)),
    }, nil
}

// clientIP returns X-Real-Ip, the first X-Forwarded-For entry, or the remote
// address without its port. The headers can be spoofed unless a trusted proxy sets them.
func clientIP(c *gin.Context) string {
    if ip := c.GetHeader("X-Real-Ip"); ip != "" {
        return ip
    }
    if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
        return strings.TrimSpace(strings.Split(xff, ",")[0])
    }
    host, _, err := net.SplitHostPort(c.Request.RemoteAddr)
    if err != nil {
        return c.Request.RemoteAddr
    }
    return host
}

func (rc *RedisClient) RateLimitMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        key := clientIP(c)
        now := time.Now().UnixMicro()

        // Run the token bucket script atomically; it returns 1 (allowed) or 0.
        allowed, err := rc.takeScript.Run(c.Request.Context(), rc.client,
            []string{key}, capacity, rate, now, 1).Int64()
        if err != nil {
            log.Println("rate limiter error:", err)
            c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"status": false, "message": "rate limiter unavailable"})
            return
        }

        if allowed != 1 {
            // Out of tokens: reject the request.
            c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"status": false, "message": "request overflowed"})
            return
        }

        c.Next()
    }
}

func PingHandler(c *gin.Context) {
    c.String(http.StatusOK, "Pong")
}

func main() {
    redisClient, err := GetRedisClient()
    if err != nil {
        log.Fatal("failed to set up Redis client: ", err)
    }

    router := gin.Default()
    router.Use(redisClient.RateLimitMiddleware())

    router.GET("/ping", PingHandler)

    if err := router.Run(":8080"); err != nil {
        log.Fatal("failed to start server: ", err)
    }
}

Implementing Token Bucket Algorithm

The script receives one key, KEYS[1] (the client's IP address), from which it derives two Redis keys: tokens_key stores the current token count and last_access_key stores the last access time. The four arguments in ARGV are: capacity, the maximum number of tokens the bucket can hold; rate, the number of tokens generated per second; now, the current timestamp in microseconds; and requested, the number of tokens the operation needs (always 1 in our middleware).

Here's a brief overview of how the script works:

  1. It fetches the current token count (last_tokens) and last access time (last_access) from Redis. If they don't exist, it initializes them with default values (capacity for tokens and 0 for last access time).

  2. The script calculates the time elapsed since the last access and, from that and the token generation rate, the number of whole tokens to be added (add_tokens).

  3. It calculates the new token count (new_tokens) by adding add_tokens to the previous count, capped at the bucket's capacity.

  4. The script determines the new last access time (new_access_time) by advancing last_access only by the time needed to generate those add_tokens whole tokens, rather than jumping to now. This way, partial progress toward the next token is not lost to rounding.

  5. It checks if enough tokens have been accumulated (new_tokens >= requested) to allow the operation. If so, it subtracts the requested tokens from the new token count.

  6. The script updates the state in Redis by setting the new token count and new last access time with a TTL of 60 seconds (adjustable based on your requirements).

  7. Finally, the script returns 1 if the operation is allowed or 0 if it is not.
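To check the arithmetic in these steps without a Redis server, here is a Go translation of the script's logic as a pure function. It is a sketch under stated assumptions: the take function and bucketState type are hypothetical names, TTLs are ignored, a nil state stands for "keys not in Redis yet", and it uses int64 integer arithmetic where the Lua script uses floating-point numbers.

```go
package main

import "fmt"

// bucketState mirrors what the script stores under <key>:tokens and <key>:last_access.
type bucketState struct {
	Tokens     int64
	LastAccess int64 // microseconds
}

// take (hypothetical) follows script.lua step by step; nil means the keys do not exist yet.
func take(s *bucketState, capacity, rate, now, requested int64) (bucketState, bool) {
	last := bucketState{Tokens: capacity, LastAccess: 0} // step 1 defaults
	if s != nil {
		last = *s
	}
	elapsed := now - last.LastAccess // step 2
	if elapsed < 0 {
		elapsed = 0
	}
	addTokens := elapsed * rate / 1_000_000 // integer division == floor for non-negative values
	newTokens := last.Tokens + addTokens    // step 3
	if newTokens > capacity {
		newTokens = capacity
	}
	// Step 4: ceil(addTokens * 1e6 / rate) in integer arithmetic.
	newAccess := last.LastAccess + (addTokens*1_000_000+rate-1)/rate
	allowed := newTokens >= requested // step 5
	if allowed {
		newTokens -= requested
	}
	return bucketState{Tokens: newTokens, LastAccess: newAccess}, allowed // steps 6-7
}

func main() {
	var s *bucketState        // first request: keys don't exist yet
	now := int64(1_000_000_000) // arbitrary timestamp in microseconds
	allowed := 0
	for i := 0; i < 12; i++ {
		next, ok := take(s, 10, 10, now, 1)
		s = &next
		if ok {
			allowed++
		}
	}
	fmt.Println("allowed in burst:", allowed) // 10

	next, ok := take(s, 10, 10, now+250_000, 1) // 250 ms later
	fmt.Println(ok, next.Tokens, next.LastAccess-now) // true 1 200000
}
```

Note the last line: 250 ms after the burst, two tokens have been earned, but last_access only moves forward 200 ms, so the extra 50 ms counts toward the next token.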

-- Input parameters
local tokens_key = KEYS[1]..":tokens"           -- Key for the bucket's token counter
local last_access_key = KEYS[1]..":last_access" -- Key for the bucket's last access time

local capacity = tonumber(ARGV[1])  -- Maximum number of tokens in the bucket
local rate = tonumber(ARGV[2])      -- Rate of token generation (tokens/second)
local now = tonumber(ARGV[3])       -- Current timestamp in microseconds
local requested = tonumber(ARGV[4]) -- Number of tokens requested for the operation

-- Fetch the current token count
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
    last_tokens = capacity
end

-- Fetch the last access time
local last_access = tonumber(redis.call("get", last_access_key))
if last_access == nil then
    last_access = 0
end

-- Calculate the number of tokens to be added due to the elapsed time since the
-- last access. We cap the number at the capacity of the bucket.
local elapsed = math.max(0, now - last_access)
local add_tokens = math.floor(elapsed * rate / 1000000)
local new_tokens = math.min(capacity, last_tokens + add_tokens)

-- Calculate the new last access time. We don't want to use the current time as
-- the new last access time, because that would result in a rounding error.
local new_access_time = last_access + math.ceil(add_tokens * 1000000 / rate)

-- Check if enough tokens have been accumulated
local allowed = new_tokens >= requested
if allowed then
    new_tokens = new_tokens - requested
end

-- Update state
redis.call("setex", tokens_key, 60, new_tokens)
redis.call("setex", last_access_key, 60, new_access_time)

-- Return 1 if the operation is allowed, 0 otherwise.
return allowed and 1 or 0
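Written as formulas, with C = capacity, r = rate (tokens per second), t = now and t_0 = last_access (both in microseconds), and n_0 = last_tokens, one call of the script computes:

$$
\begin{aligned}
\Delta &= \max(0,\ t - t_0) \\
a &= \left\lfloor \frac{\Delta \cdot r}{10^6} \right\rfloor \\
n &= \min(C,\ n_0 + a) \\
t_0' &= t_0 + \left\lceil \frac{a \cdot 10^6}{r} \right\rceil
\end{aligned}
$$

For example, with C = 10, r = 10 and an empty bucket (n_0 = 0), a request arriving 250 ms after t_0 gives:

$$
\Delta = 250\,000,\quad a = \left\lfloor \frac{250\,000 \cdot 10}{10^6} \right\rfloor = \lfloor 2.5 \rfloor = 2,\quad n = 2,\quad t_0' = t_0 + \left\lceil \frac{2 \cdot 10^6}{10} \right\rceil = t_0 + 200\,000
$$

The leftover 50,000 µs (half a token) stays between t_0' and t, so the next token becomes available at t_0 + 300,000 µs. Setting t_0' = t instead would discard that half token and delay the next one to t_0 + 350,000 µs, which is the rounding error the script's comment refers to.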

Why do we need Lua scripting (Redis) for our rate limiting logic?

First, we need to understand what a Redis script is.
In Redis, a script is a piece of Lua code that is executed on the server. Scripting lets you execute multiple commands as a single atomic operation, giving transaction-like semantics. Atomicity means an operation is performed as one indivisible unit of work, so no other client can observe an intermediate or partial state; this preserves data integrity and consistency under concurrent access. Scripts also let you perform complex operations that are not possible with individual Redis commands alone. Redis guarantees this atomicity by blocking all other server activity for the script's entire runtime.

Now let's look at the main reason for writing our rate limiting logic as a Redis script. The limiter performs a read-modify-write sequence: read the token count, check it, then write the new value. If the application issued these as separate Redis commands, concurrent requests (for example, from several server instances in a distributed system) could interleave and leave the data inconsistent. Suppose the bucket has one token left: two requests both read "1", both decide they are allowed, and both are processed, even though only one token was available. That is exactly what a rate limiter must prevent. We could guard the sequence with locks, but that increases the complexity of our code and slows the server down. A Redis script (written in Lua) solves the problem more simply: Redis runs the whole script atomically, so the read, the check, and the update can never interleave with another request.
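The race described above can be replayed in plain Go. In this sketch, store is a hypothetical in-memory stand-in for Redis: get and set model separate GET and SET commands, while takeAtomic models the Lua script by doing the read, check, and write under one lock. The mutex is only an analogy; Redis achieves atomicity by running the script without interleaving any other command.

```go
package main

import (
	"fmt"
	"sync"
)

// store is a hypothetical in-memory stand-in for Redis holding one token counter.
type store struct {
	mu     sync.Mutex
	tokens int
}

// get and set model two separate Redis commands (GET, then SET).
func (s *store) get() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens
}

func (s *store) set(v int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens = v
}

// takeAtomic models the Lua script: read, check, and write as one unit.
func (s *store) takeAtomic() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.tokens > 0 {
		s.tokens--
		return true
	}
	return false
}

// racyPair replays the bad interleaving deterministically: both requests read
// before either writes. It returns how many of the two requests were allowed.
func racyPair(s *store) int {
	a := s.get() // request A reads the count
	b := s.get() // request B reads the same count
	allowed := 0
	if a > 0 {
		s.set(a - 1)
		allowed++
	}
	if b > 0 {
		s.set(b - 1)
		allowed++
	}
	return allowed
}

func main() {
	// One token left, separate GET/SET: both requests get through.
	fmt.Println("separate GET/SET:", racyPair(&store{tokens: 1})) // 2

	// Ten tokens, 100 concurrent requests, atomic take: exactly 10 succeed.
	s := &store{tokens: 10}
	var wg sync.WaitGroup
	var mu sync.Mutex
	allowed := 0
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if s.takeAtomic() {
				mu.Lock()
				allowed++
				mu.Unlock()
			}
		}()
	}
	wg.Wait()
	fmt.Println("atomic script:", allowed) // 10
}
```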