client rate limiting

This commit is contained in:
2025-08-22 17:45:25 +02:00
parent f80c506a95
commit b909c2f2e4
6 changed files with 120 additions and 33 deletions

2
go.sum
View File

@@ -6,3 +6,5 @@ github.com/wneessen/go-fileperm v0.2.1 h1:VNZT41b8HJDY5zUw4TbwPtfU1DuxZ3lcGH4dXl
github.com/wneessen/go-fileperm v0.2.1/go.mod h1:Isv0pfQJstXAlmGGJjLGqCK0Z6d1ehbbrsO2xmTRsKs= github.com/wneessen/go-fileperm v0.2.1/go.mod h1:Isv0pfQJstXAlmGGJjLGqCK0Z6d1ehbbrsO2xmTRsKs=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=

23
main.go
View File

@@ -9,11 +9,14 @@ import (
) )
const ( const (
HELP_PORT string = "what port miniws will run on" HELP_PORT string = "what port miniws will run on"
HELP_LOGFOLDER string = "the logs folder" HELP_LOGFOLDER string = "the logs folder"
HELP_CONFIGFOLDER string = "the configurations folder" HELP_CONFIGFOLDER string = "the configurations folder"
HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve" HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve"
HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split" HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split"
HELP_MAXCLIENTRATE string = "the maximum number of requests per minute that any particular " +
"client can send. exceeding this rate will cause miniws to reply with HTTP error 429: " +
"Too Many Requests."
) )
func main() { func main() {
@@ -24,6 +27,7 @@ func main() {
configFolder := parser.String("c", "config-folder", &argparse.Options{Default: "config", Help: HELP_CONFIGFOLDER}) configFolder := parser.String("c", "config-folder", &argparse.Options{Default: "config", Help: HELP_CONFIGFOLDER})
wwwFolder := parser.String("w", "www-folder", &argparse.Options{Default: ".", Help: HELP_WWWFOLDER}) wwwFolder := parser.String("w", "www-folder", &argparse.Options{Default: ".", Help: HELP_WWWFOLDER})
maxLogBytes := parser.Int("b", "max-log-bytes", &argparse.Options{Default: 1048576, Help: HELP_MAXLOGBYTES}) maxLogBytes := parser.Int("b", "max-log-bytes", &argparse.Options{Default: 1048576, Help: HELP_MAXLOGBYTES})
maxClientRatePerMin := parser.Int("r", "max-client-rate", &argparse.Options{Default: 600, Help: HELP_MAXCLIENTRATE})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@@ -33,6 +37,13 @@ func main() {
return return
} }
webserver := miniws.NewWebServer(*port, *logFolder, *configFolder, *wwwFolder, int64(*maxLogBytes)) webserver := miniws.NewWebServer(miniws.WebServerConfig{
LogFolder: *logFolder,
ConfigFolder: *configFolder,
WWWFolder: *wwwFolder,
Port: uint16(*port),
MaxBytesPerLogFile: uint64(*maxLogBytes),
MaxConnectionsPerMinute: uint64(*maxClientRatePerMin),
})
webserver.Run() webserver.Run()
} }

View File

@@ -0,0 +1,51 @@
package miniws
import (
"log"
"maps"
"golang.org/x/time/rate"
)
const CRL_MAX_CLIENT_REMEMBER_SEC float64 = 60.
type clientRateLimiter struct {
limits map[string]*rate.Limiter
maxConnsPerSec float64
}
func newClientRateLimiter(maxConnectionsPerMin float64) *clientRateLimiter {
return &clientRateLimiter{
limits: make(map[string]*rate.Limiter),
maxConnsPerSec: maxConnectionsPerMin / 60.,
}
}
func (crl *clientRateLimiter) canConnect(clientIp string) bool {
limiter, ok := crl.limits[clientIp]
if !ok {
crl.limits[clientIp] = rate.NewLimiter(rate.Limit(crl.maxConnsPerSec), 1)
limiter = crl.limits[clientIp]
log.Println("new client: " + clientIp)
}
allowed := limiter.Allow()
if allowed {
log.Println("client " + clientIp + " was allowed")
} else {
log.Println("client " + clientIp + " has been rate limited")
}
crl._cleanup()
return allowed
}
func (crl *clientRateLimiter) _cleanup() {
maps.DeleteFunc(crl.limits, func(clIp string, limit *rate.Limiter) bool {
// forget the client if they haven't connected in CRL_MAX_CLIENT_REMEMBER_SEC seconds
forgetClient := limit.Tokens() >= CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec
if forgetClient {
log.Println("Forgetting client "+clIp+": accumulated tokens ", limit.Tokens(),
"exceed limit of ", CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec)
}
return forgetClient
})
}

View File

@@ -1,3 +1,7 @@
module github.com/shlldev/miniws/miniws module github.com/shlldev/miniws/miniws
go 1.22.2 go 1.22.2
require (
golang.org/x/time v0.12.0 // indirect
)

View File

@@ -18,10 +18,10 @@ const (
type Logger struct { type Logger struct {
logFolder string logFolder string
maxLogBytes int64 maxLogBytes uint64
} }
func NewLogger(logFolder_ string, maxLogBytes_ int64) *Logger { func NewLogger(logFolder_ string, maxLogBytes_ uint64) *Logger {
return &Logger{ return &Logger{
logFolder: logFolder_, logFolder: logFolder_,
maxLogBytes: maxLogBytes_, maxLogBytes: maxLogBytes_,
@@ -70,7 +70,7 @@ func (l *Logger) writeToLogFileAndRenameIfBig(fileName, content string) {
return return
} }
if fileinfo.Size() > l.maxLogBytes { if uint64(fileinfo.Size()) > l.maxLogBytes {
var renamedFiledPath string = fullPath + "." + uuid.NewString() var renamedFiledPath string = fullPath + "." + uuid.NewString()

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"log" "log"
"mime" "mime"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@@ -22,40 +23,46 @@ const (
type FilterMode int type FilterMode int
type WebServerConfig struct {
LogFolder string
ConfigFolder string
WWWFolder string
Port uint16
MaxBytesPerLogFile uint64
MaxConnectionsPerMinute uint64
}
type WebServer struct { type WebServer struct {
logger *Logger logger *Logger
port int cfg WebServerConfig
configFolder string
wwwFolder string
ipFilter []string ipFilter []string
userAgentFilter []string userAgentFilter []string
ipFilterMode FilterMode ipFilterMode FilterMode
userAgentFilterMode FilterMode userAgentFilterMode FilterMode
clientLimiter *clientRateLimiter
} }
func NewWebServer(port_ int, logFolder_, configFolder_, wwwFolder_ string, maxLogBytes_ int64) *WebServer { func NewWebServer(cfg WebServerConfig) *WebServer {
return &WebServer{ return &WebServer{
logger: NewLogger(logFolder_, maxLogBytes_), logger: NewLogger(cfg.LogFolder, cfg.MaxBytesPerLogFile),
port: port_, cfg: cfg,
configFolder: configFolder_,
wwwFolder: wwwFolder_,
ipFilter: make([]string, 0), ipFilter: make([]string, 0),
userAgentFilter: make([]string, 0), userAgentFilter: make([]string, 0),
ipFilterMode: FILTER_MODE_BLACKLIST, ipFilterMode: FILTER_MODE_BLACKLIST,
userAgentFilterMode: FILTER_MODE_BLACKLIST, userAgentFilterMode: FILTER_MODE_BLACKLIST,
clientLimiter: newClientRateLimiter(float64(cfg.MaxConnectionsPerMinute)),
} }
} }
func (ws *WebServer) Run() { func (ws *WebServer) Run() {
_, err := os.Lstat(ws.wwwFolder) _, err := os.Lstat(ws.cfg.WWWFolder)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
log.Fatalln("Fatal: www folder " + ws.wwwFolder + " does not exist") log.Fatalln("Fatal: www folder " + ws.cfg.WWWFolder + " does not exist")
} else if err != nil { } else if err != nil {
log.Fatalln("Fatal: " + err.Error()) log.Fatalln("Fatal: " + err.Error())
} }
perms, err := fileperm.New(ws.wwwFolder) perms, err := fileperm.New(ws.cfg.WWWFolder)
if err != nil { if err != nil {
log.Fatalln("Fatal: " + err.Error()) log.Fatalln("Fatal: " + err.Error())
} }
@@ -67,8 +74,9 @@ func (ws *WebServer) Run() {
ws.userAgentFilterMode, ws.userAgentFilter = ws.parseFilterPanics(FILENAME_USERAGENTFILTER) ws.userAgentFilterMode, ws.userAgentFilter = ws.parseFilterPanics(FILENAME_USERAGENTFILTER)
http.HandleFunc("/", ws.get) http.HandleFunc("/", ws.get)
log.Println("Server started on port " + strconv.Itoa(ws.port)) portStr := strconv.FormatUint(uint64(ws.cfg.Port), 10)
http.ListenAndServe(":"+strconv.Itoa(ws.port), nil) log.Println("Server started on port " + portStr)
http.ListenAndServe(":"+portStr, nil)
} }
func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) { func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
@@ -76,10 +84,10 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
filterMode := FILTER_MODE_BLACKLIST filterMode := FILTER_MODE_BLACKLIST
filter := make([]string, 0) filter := make([]string, 0)
os.Mkdir(ws.configFolder, PERMS_MKDIR) os.Mkdir(ws.cfg.ConfigFolder, PERMS_MKDIR)
fileinfo, err := os.Stat(filepath.Join(ws.configFolder, fileName)) fileinfo, err := os.Stat(filepath.Join(ws.cfg.ConfigFolder, fileName))
fullPath := filepath.Join(ws.configFolder, fileName) fullPath := filepath.Join(ws.cfg.ConfigFolder, fileName)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
os.Create(fullPath) os.Create(fullPath)
fileinfo, err = os.Stat(fullPath) fileinfo, err = os.Stat(fullPath)
@@ -126,19 +134,23 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
} }
func (ws *WebServer) isIpValid(ip string) bool { func (ws *WebServer) isIpAllowed(ip string) bool {
ip = strings.Split(ip, ":")[0] // remove port hostIp, _, err := net.SplitHostPort(ip)
if err != nil && !strings.Contains(ip, ":") {
log.Println(err)
return false
}
switch ws.ipFilterMode { switch ws.ipFilterMode {
case FILTER_MODE_WHITELIST: case FILTER_MODE_WHITELIST:
return slices.Contains(ws.ipFilter, ip) return slices.Contains(ws.ipFilter, hostIp)
case FILTER_MODE_BLACKLIST: case FILTER_MODE_BLACKLIST:
return !slices.Contains(ws.ipFilter, ip) return !slices.Contains(ws.ipFilter, hostIp)
default: default:
return false //if something went wrong with conf parsing return false //if something went wrong with conf parsing
} }
} }
func (ws *WebServer) isUserAgentValid(userAgent string) bool { func (ws *WebServer) isUserAgentAllowed(userAgent string) bool {
switch ws.userAgentFilterMode { switch ws.userAgentFilterMode {
case FILTER_MODE_WHITELIST: case FILTER_MODE_WHITELIST:
for _, userAgentFiltered := range ws.userAgentFilter { for _, userAgentFiltered := range ws.userAgentFilter {
@@ -200,12 +212,19 @@ func (ws *WebServer) get(writer http.ResponseWriter, req *http.Request) {
respStatusCode := http.StatusOK respStatusCode := http.StatusOK
if !ws.isIpValid(req.RemoteAddr) || !ws.isUserAgentValid(req.UserAgent()) { // check that IP and User Agent of client are whitelisted, or not blacklisted
if !ws.isIpAllowed(req.RemoteAddr) || !ws.isUserAgentAllowed(req.UserAgent()) {
writer.WriteHeader(http.StatusForbidden) writer.WriteHeader(http.StatusForbidden)
return return
} }
// check that the client IP has not been sending too many requests recently
if !ws.clientLimiter.canConnect(req.RemoteAddr) {
writer.WriteHeader(http.StatusTooManyRequests)
return
}
fileToFetch := filepath.Join(ws.wwwFolder, req.URL.Path) fileToFetch := filepath.Join(ws.cfg.WWWFolder, req.URL.Path)
fetchedFile, fetchErr := ws.fetchFile(fileToFetch) fetchedFile, fetchErr := ws.fetchFile(fileToFetch)
fetchedFileStat, _ := fetchedFile.Stat() fetchedFileStat, _ := fetchedFile.Stat()
fetchedStat, _ := ws.fetchStat(fileToFetch) fetchedStat, _ := ws.fetchStat(fileToFetch)