Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚀 feat: Add Load-Shedding Middleware #3264

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions middleware/loadshedding/loadshedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package loadshedding

import (
"context"
"time"

"github.com/gofiber/fiber/v3"
)

// New creates a middleware handler enforces a timeout on request processing to manage server load.
// If a request exceeds the specified timeout, a custom load-shedding handler is executed.
func New(timeout time.Duration, loadSheddingHandler fiber.Handler, exclude func(fiber.Ctx) bool) fiber.Handler {
return func(c fiber.Ctx) error {
// Skip load-shedding for excluded requests
if exclude != nil && exclude(c) {
return c.Next()
}

// Create a context with the specified timeout
ctx, cancel := context.WithTimeout(c.Context(), timeout)
defer cancel()

// Channel to signal request completion
done := make(chan error, 1)

// Process the handler in a separate goroutine
go func() {
done <- c.Next()
}()

select {
case <-ctx.Done():
// Timeout occurred; invoke the load-shedding handler
return loadSheddingHandler(c)
case err := <-done:
// Request completed successfully; return any handler error
return err
}
}
}
92 changes: 92 additions & 0 deletions middleware/loadshedding/loadshedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package loadshedding_test

import (
"net/http/httptest"
"testing"
"time"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/loadshedding"
"github.com/stretchr/testify/require"
)

// Helper handlers
func successHandler(c fiber.Ctx) error {
return c.SendString("Request processed successfully!")
}

func timeoutHandler(c fiber.Ctx) error {
time.Sleep(2 * time.Second) // Simulate a long-running request
return c.SendString("This should not appear")
}

func loadSheddingHandler(c fiber.Ctx) error {
return c.Status(fiber.StatusServiceUnavailable).SendString("Service Overloaded")
}

func excludedHandler(c fiber.Ctx) error {
return c.SendString("Excluded route")
}

// go test -run Test_LoadSheddingExcluded
func Test_LoadSheddingExcluded(t *testing.T) {
t.Parallel()
app := fiber.New()

// Middleware with exclusion
app.Use(loadshedding.New(
1*time.Second,
loadSheddingHandler,
func(c fiber.Ctx) bool { return c.Path() == "/excluded" },
))
app.Get("/", successHandler)
app.Get("/excluded", excludedHandler)

// Test excluded route
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/excluded", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// go test -run Test_LoadSheddingTimeout
func Test_LoadSheddingTimeout(t *testing.T) {
t.Parallel()
app := fiber.New()

// Middleware without exclusions
app.Use(loadshedding.New(
1*time.Second, // Set timeout for the middleware
loadSheddingHandler,
nil,
))
app.Get("/", timeoutHandler)

// Create a custom HTTP client with a sufficient timeout

// Test timeout behavior
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req, fiber.TestConfig{
Timeout: 3 * time.Second,
})
require.NoError(t, err)
require.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode)
}

// go test -run Test_LoadSheddingSuccessfulRequest
func Test_LoadSheddingSuccessfulRequest(t *testing.T) {
t.Parallel()
app := fiber.New()

// Middleware with sufficient time for request to complete
app.Use(loadshedding.New(
2*time.Second,
loadSheddingHandler,
nil,
))
app.Get("/", successHandler)

// Test successful request
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
Loading