client rate limiting

This commit is contained in:
2025-08-22 17:45:25 +02:00
parent f80c506a95
commit b909c2f2e4
6 changed files with 120 additions and 33 deletions

2
go.sum
View File

@@ -6,3 +6,5 @@ github.com/wneessen/go-fileperm v0.2.1 h1:VNZT41b8HJDY5zUw4TbwPtfU1DuxZ3lcGH4dXl
github.com/wneessen/go-fileperm v0.2.1/go.mod h1:Isv0pfQJstXAlmGGJjLGqCK0Z6d1ehbbrsO2xmTRsKs=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=

23
main.go
View File

@@ -9,11 +9,14 @@ import (
)
const (
HELP_PORT string = "what port miniws will run on"
HELP_LOGFOLDER string = "the logs folder"
HELP_CONFIGFOLDER string = "the configurations folder"
HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve"
HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split"
HELP_PORT string = "what port miniws will run on"
HELP_LOGFOLDER string = "the logs folder"
HELP_CONFIGFOLDER string = "the configurations folder"
HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve"
HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split"
HELP_MAXCLIENTRATE string = "the maximum number of requests per minute that any particular " +
"client can send. exceeding this rate will cause miniws to reply with HTTP error 429: " +
"Too Many Requests."
)
func main() {
@@ -24,6 +27,7 @@ func main() {
configFolder := parser.String("c", "config-folder", &argparse.Options{Default: "config", Help: HELP_CONFIGFOLDER})
wwwFolder := parser.String("w", "www-folder", &argparse.Options{Default: ".", Help: HELP_WWWFOLDER})
maxLogBytes := parser.Int("b", "max-log-bytes", &argparse.Options{Default: 1048576, Help: HELP_MAXLOGBYTES})
maxClientRatePerMin := parser.Int("r", "max-client-rate", &argparse.Options{Default: 600, Help: HELP_MAXCLIENTRATE})
err := parser.Parse(os.Args)
if err != nil {
@@ -33,6 +37,13 @@ func main() {
return
}
webserver := miniws.NewWebServer(*port, *logFolder, *configFolder, *wwwFolder, int64(*maxLogBytes))
webserver := miniws.NewWebServer(miniws.WebServerConfig{
LogFolder: *logFolder,
ConfigFolder: *configFolder,
WWWFolder: *wwwFolder,
Port: uint16(*port),
MaxBytesPerLogFile: uint64(*maxLogBytes),
MaxConnectionsPerMinute: uint64(*maxClientRatePerMin),
})
webserver.Run()
}

View File

@@ -0,0 +1,51 @@
package miniws
import (
"log"
"maps"
"golang.org/x/time/rate"
)
const CRL_MAX_CLIENT_REMEMBER_SEC float64 = 60.
type clientRateLimiter struct {
limits map[string]*rate.Limiter
maxConnsPerSec float64
}
func newClientRateLimiter(maxConnectionsPerMin float64) *clientRateLimiter {
return &clientRateLimiter{
limits: make(map[string]*rate.Limiter),
maxConnsPerSec: maxConnectionsPerMin / 60.,
}
}
func (crl *clientRateLimiter) canConnect(clientIp string) bool {
limiter, ok := crl.limits[clientIp]
if !ok {
crl.limits[clientIp] = rate.NewLimiter(rate.Limit(crl.maxConnsPerSec), 1)
limiter = crl.limits[clientIp]
log.Println("new client: " + clientIp)
}
allowed := limiter.Allow()
if allowed {
log.Println("client " + clientIp + " was allowed")
} else {
log.Println("client " + clientIp + " has been rate limited")
}
crl._cleanup()
return allowed
}
func (crl *clientRateLimiter) _cleanup() {
maps.DeleteFunc(crl.limits, func(clIp string, limit *rate.Limiter) bool {
// forget the client if they haven't connected in CRL_MAX_CLIENT_REMEMBER_SEC seconds
forgetClient := limit.Tokens() >= CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec
if forgetClient {
log.Println("Forgetting client "+clIp+": accumulated tokens ", limit.Tokens(),
"exceed limit of ", CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec)
}
return forgetClient
})
}

View File

@@ -1,3 +1,7 @@
module github.com/shlldev/miniws/miniws
go 1.22.2
require (
golang.org/x/time v0.12.0 // indirect
)

View File

@@ -18,10 +18,10 @@ const (
type Logger struct {
logFolder string
maxLogBytes int64
maxLogBytes uint64
}
func NewLogger(logFolder_ string, maxLogBytes_ int64) *Logger {
func NewLogger(logFolder_ string, maxLogBytes_ uint64) *Logger {
return &Logger{
logFolder: logFolder_,
maxLogBytes: maxLogBytes_,
@@ -70,7 +70,7 @@ func (l *Logger) writeToLogFileAndRenameIfBig(fileName, content string) {
return
}
if fileinfo.Size() > l.maxLogBytes {
if uint64(fileinfo.Size()) > l.maxLogBytes {
var renamedFiledPath string = fullPath + "." + uuid.NewString()

View File

@@ -4,6 +4,7 @@ import (
"errors"
"log"
"mime"
"net"
"net/http"
"os"
"path/filepath"
@@ -22,40 +23,46 @@ const (
type FilterMode int
type WebServerConfig struct {
LogFolder string
ConfigFolder string
WWWFolder string
Port uint16
MaxBytesPerLogFile uint64
MaxConnectionsPerMinute uint64
}
type WebServer struct {
logger *Logger
port int
configFolder string
wwwFolder string
cfg WebServerConfig
ipFilter []string
userAgentFilter []string
ipFilterMode FilterMode
userAgentFilterMode FilterMode
clientLimiter *clientRateLimiter
}
func NewWebServer(port_ int, logFolder_, configFolder_, wwwFolder_ string, maxLogBytes_ int64) *WebServer {
func NewWebServer(cfg WebServerConfig) *WebServer {
return &WebServer{
logger: NewLogger(logFolder_, maxLogBytes_),
port: port_,
configFolder: configFolder_,
wwwFolder: wwwFolder_,
logger: NewLogger(cfg.LogFolder, cfg.MaxBytesPerLogFile),
cfg: cfg,
ipFilter: make([]string, 0),
userAgentFilter: make([]string, 0),
ipFilterMode: FILTER_MODE_BLACKLIST,
userAgentFilterMode: FILTER_MODE_BLACKLIST,
clientLimiter: newClientRateLimiter(float64(cfg.MaxConnectionsPerMinute)),
}
}
func (ws *WebServer) Run() {
_, err := os.Lstat(ws.wwwFolder)
_, err := os.Lstat(ws.cfg.WWWFolder)
if errors.Is(err, os.ErrNotExist) {
log.Fatalln("Fatal: www folder " + ws.wwwFolder + " does not exist")
log.Fatalln("Fatal: www folder " + ws.cfg.WWWFolder + " does not exist")
} else if err != nil {
log.Fatalln("Fatal: " + err.Error())
}
perms, err := fileperm.New(ws.wwwFolder)
perms, err := fileperm.New(ws.cfg.WWWFolder)
if err != nil {
log.Fatalln("Fatal: " + err.Error())
}
@@ -67,8 +74,9 @@ func (ws *WebServer) Run() {
ws.userAgentFilterMode, ws.userAgentFilter = ws.parseFilterPanics(FILENAME_USERAGENTFILTER)
http.HandleFunc("/", ws.get)
log.Println("Server started on port " + strconv.Itoa(ws.port))
http.ListenAndServe(":"+strconv.Itoa(ws.port), nil)
portStr := strconv.FormatUint(uint64(ws.cfg.Port), 10)
log.Println("Server started on port " + portStr)
http.ListenAndServe(":"+portStr, nil)
}
func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
@@ -76,10 +84,10 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
filterMode := FILTER_MODE_BLACKLIST
filter := make([]string, 0)
os.Mkdir(ws.configFolder, PERMS_MKDIR)
fileinfo, err := os.Stat(filepath.Join(ws.configFolder, fileName))
os.Mkdir(ws.cfg.ConfigFolder, PERMS_MKDIR)
fileinfo, err := os.Stat(filepath.Join(ws.cfg.ConfigFolder, fileName))
fullPath := filepath.Join(ws.configFolder, fileName)
fullPath := filepath.Join(ws.cfg.ConfigFolder, fileName)
if errors.Is(err, os.ErrNotExist) {
os.Create(fullPath)
fileinfo, err = os.Stat(fullPath)
@@ -126,19 +134,23 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
}
func (ws *WebServer) isIpValid(ip string) bool {
ip = strings.Split(ip, ":")[0] // remove port
func (ws *WebServer) isIpAllowed(ip string) bool {
hostIp, _, err := net.SplitHostPort(ip)
if err != nil && !strings.Contains(ip, ":") {
log.Println(err)
return false
}
switch ws.ipFilterMode {
case FILTER_MODE_WHITELIST:
return slices.Contains(ws.ipFilter, ip)
return slices.Contains(ws.ipFilter, hostIp)
case FILTER_MODE_BLACKLIST:
return !slices.Contains(ws.ipFilter, ip)
return !slices.Contains(ws.ipFilter, hostIp)
default:
return false //if something went wrong with conf parsing
}
}
func (ws *WebServer) isUserAgentValid(userAgent string) bool {
func (ws *WebServer) isUserAgentAllowed(userAgent string) bool {
switch ws.userAgentFilterMode {
case FILTER_MODE_WHITELIST:
for _, userAgentFiltered := range ws.userAgentFilter {
@@ -200,12 +212,19 @@ func (ws *WebServer) get(writer http.ResponseWriter, req *http.Request) {
respStatusCode := http.StatusOK
if !ws.isIpValid(req.RemoteAddr) || !ws.isUserAgentValid(req.UserAgent()) {
// check that IP and User Agent of client are whitelisted, or not blacklisted
if !ws.isIpAllowed(req.RemoteAddr) || !ws.isUserAgentAllowed(req.UserAgent()) {
writer.WriteHeader(http.StatusForbidden)
return
}
// check that the client IP has not been sending too many requests recently
if !ws.clientLimiter.canConnect(req.RemoteAddr) {
writer.WriteHeader(http.StatusTooManyRequests)
return
}
fileToFetch := filepath.Join(ws.wwwFolder, req.URL.Path)
fileToFetch := filepath.Join(ws.cfg.WWWFolder, req.URL.Path)
fetchedFile, fetchErr := ws.fetchFile(fileToFetch)
fetchedFileStat, _ := fetchedFile.Stat()
fetchedStat, _ := ws.fetchStat(fileToFetch)