mirror of
https://github.com/shlldev/miniws.git
synced 2025-09-02 19:00:59 +02:00
client rate limiting
This commit is contained in:
2
go.sum
2
go.sum
@@ -6,3 +6,5 @@ github.com/wneessen/go-fileperm v0.2.1 h1:VNZT41b8HJDY5zUw4TbwPtfU1DuxZ3lcGH4dXl
|
|||||||
github.com/wneessen/go-fileperm v0.2.1/go.mod h1:Isv0pfQJstXAlmGGJjLGqCK0Z6d1ehbbrsO2xmTRsKs=
|
github.com/wneessen/go-fileperm v0.2.1/go.mod h1:Isv0pfQJstXAlmGGJjLGqCK0Z6d1ehbbrsO2xmTRsKs=
|
||||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
|
23
main.go
23
main.go
@@ -9,11 +9,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HELP_PORT string = "what port miniws will run on"
|
HELP_PORT string = "what port miniws will run on"
|
||||||
HELP_LOGFOLDER string = "the logs folder"
|
HELP_LOGFOLDER string = "the logs folder"
|
||||||
HELP_CONFIGFOLDER string = "the configurations folder"
|
HELP_CONFIGFOLDER string = "the configurations folder"
|
||||||
HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve"
|
HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve"
|
||||||
HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split"
|
HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split"
|
||||||
|
HELP_MAXCLIENTRATE string = "the maximum number of requests per minute that any particular " +
|
||||||
|
"client can send. exceeding this rate will cause miniws to reply with HTTP error 429: " +
|
||||||
|
"Too Many Requests."
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -24,6 +27,7 @@ func main() {
|
|||||||
configFolder := parser.String("c", "config-folder", &argparse.Options{Default: "config", Help: HELP_CONFIGFOLDER})
|
configFolder := parser.String("c", "config-folder", &argparse.Options{Default: "config", Help: HELP_CONFIGFOLDER})
|
||||||
wwwFolder := parser.String("w", "www-folder", &argparse.Options{Default: ".", Help: HELP_WWWFOLDER})
|
wwwFolder := parser.String("w", "www-folder", &argparse.Options{Default: ".", Help: HELP_WWWFOLDER})
|
||||||
maxLogBytes := parser.Int("b", "max-log-bytes", &argparse.Options{Default: 1048576, Help: HELP_MAXLOGBYTES})
|
maxLogBytes := parser.Int("b", "max-log-bytes", &argparse.Options{Default: 1048576, Help: HELP_MAXLOGBYTES})
|
||||||
|
maxClientRatePerMin := parser.Int("r", "max-client-rate", &argparse.Options{Default: 600, Help: HELP_MAXCLIENTRATE})
|
||||||
|
|
||||||
err := parser.Parse(os.Args)
|
err := parser.Parse(os.Args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -33,6 +37,13 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
webserver := miniws.NewWebServer(*port, *logFolder, *configFolder, *wwwFolder, int64(*maxLogBytes))
|
webserver := miniws.NewWebServer(miniws.WebServerConfig{
|
||||||
|
LogFolder: *logFolder,
|
||||||
|
ConfigFolder: *configFolder,
|
||||||
|
WWWFolder: *wwwFolder,
|
||||||
|
Port: uint16(*port),
|
||||||
|
MaxBytesPerLogFile: uint64(*maxLogBytes),
|
||||||
|
MaxConnectionsPerMinute: uint64(*maxClientRatePerMin),
|
||||||
|
})
|
||||||
webserver.Run()
|
webserver.Run()
|
||||||
}
|
}
|
||||||
|
51
miniws/client_ratelimit.go
Normal file
51
miniws/client_ratelimit.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package miniws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"maps"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
const CRL_MAX_CLIENT_REMEMBER_SEC float64 = 60.
|
||||||
|
|
||||||
|
type clientRateLimiter struct {
|
||||||
|
limits map[string]*rate.Limiter
|
||||||
|
maxConnsPerSec float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientRateLimiter(maxConnectionsPerMin float64) *clientRateLimiter {
|
||||||
|
return &clientRateLimiter{
|
||||||
|
limits: make(map[string]*rate.Limiter),
|
||||||
|
maxConnsPerSec: maxConnectionsPerMin / 60.,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (crl *clientRateLimiter) canConnect(clientIp string) bool {
|
||||||
|
limiter, ok := crl.limits[clientIp]
|
||||||
|
if !ok {
|
||||||
|
crl.limits[clientIp] = rate.NewLimiter(rate.Limit(crl.maxConnsPerSec), 1)
|
||||||
|
limiter = crl.limits[clientIp]
|
||||||
|
log.Println("new client: " + clientIp)
|
||||||
|
}
|
||||||
|
allowed := limiter.Allow()
|
||||||
|
if allowed {
|
||||||
|
log.Println("client " + clientIp + " was allowed")
|
||||||
|
} else {
|
||||||
|
log.Println("client " + clientIp + " has been rate limited")
|
||||||
|
}
|
||||||
|
crl._cleanup()
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (crl *clientRateLimiter) _cleanup() {
|
||||||
|
maps.DeleteFunc(crl.limits, func(clIp string, limit *rate.Limiter) bool {
|
||||||
|
// forget the client if they haven't connected in CRL_MAX_CLIENT_REMEMBER_SEC seconds
|
||||||
|
forgetClient := limit.Tokens() >= CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec
|
||||||
|
if forgetClient {
|
||||||
|
log.Println("Forgetting client "+clIp+": accumulated tokens ", limit.Tokens(),
|
||||||
|
"exceed limit of ", CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec)
|
||||||
|
}
|
||||||
|
return forgetClient
|
||||||
|
})
|
||||||
|
}
|
@@ -1,3 +1,7 @@
|
|||||||
module github.com/shlldev/miniws/miniws
|
module github.com/shlldev/miniws/miniws
|
||||||
|
|
||||||
go 1.22.2
|
go 1.22.2
|
||||||
|
|
||||||
|
require (
|
||||||
|
golang.org/x/time v0.12.0 // indirect
|
||||||
|
)
|
@@ -18,10 +18,10 @@ const (
|
|||||||
|
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
logFolder string
|
logFolder string
|
||||||
maxLogBytes int64
|
maxLogBytes uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLogger(logFolder_ string, maxLogBytes_ int64) *Logger {
|
func NewLogger(logFolder_ string, maxLogBytes_ uint64) *Logger {
|
||||||
return &Logger{
|
return &Logger{
|
||||||
logFolder: logFolder_,
|
logFolder: logFolder_,
|
||||||
maxLogBytes: maxLogBytes_,
|
maxLogBytes: maxLogBytes_,
|
||||||
@@ -70,7 +70,7 @@ func (l *Logger) writeToLogFileAndRenameIfBig(fileName, content string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if fileinfo.Size() > l.maxLogBytes {
|
if uint64(fileinfo.Size()) > l.maxLogBytes {
|
||||||
|
|
||||||
var renamedFiledPath string = fullPath + "." + uuid.NewString()
|
var renamedFiledPath string = fullPath + "." + uuid.NewString()
|
||||||
|
|
||||||
|
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"mime"
|
"mime"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -22,40 +23,46 @@ const (
|
|||||||
|
|
||||||
type FilterMode int
|
type FilterMode int
|
||||||
|
|
||||||
|
type WebServerConfig struct {
|
||||||
|
LogFolder string
|
||||||
|
ConfigFolder string
|
||||||
|
WWWFolder string
|
||||||
|
Port uint16
|
||||||
|
MaxBytesPerLogFile uint64
|
||||||
|
MaxConnectionsPerMinute uint64
|
||||||
|
}
|
||||||
|
|
||||||
type WebServer struct {
|
type WebServer struct {
|
||||||
logger *Logger
|
logger *Logger
|
||||||
port int
|
cfg WebServerConfig
|
||||||
configFolder string
|
|
||||||
wwwFolder string
|
|
||||||
ipFilter []string
|
ipFilter []string
|
||||||
userAgentFilter []string
|
userAgentFilter []string
|
||||||
ipFilterMode FilterMode
|
ipFilterMode FilterMode
|
||||||
userAgentFilterMode FilterMode
|
userAgentFilterMode FilterMode
|
||||||
|
clientLimiter *clientRateLimiter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWebServer(port_ int, logFolder_, configFolder_, wwwFolder_ string, maxLogBytes_ int64) *WebServer {
|
func NewWebServer(cfg WebServerConfig) *WebServer {
|
||||||
|
|
||||||
return &WebServer{
|
return &WebServer{
|
||||||
logger: NewLogger(logFolder_, maxLogBytes_),
|
logger: NewLogger(cfg.LogFolder, cfg.MaxBytesPerLogFile),
|
||||||
port: port_,
|
cfg: cfg,
|
||||||
configFolder: configFolder_,
|
|
||||||
wwwFolder: wwwFolder_,
|
|
||||||
ipFilter: make([]string, 0),
|
ipFilter: make([]string, 0),
|
||||||
userAgentFilter: make([]string, 0),
|
userAgentFilter: make([]string, 0),
|
||||||
ipFilterMode: FILTER_MODE_BLACKLIST,
|
ipFilterMode: FILTER_MODE_BLACKLIST,
|
||||||
userAgentFilterMode: FILTER_MODE_BLACKLIST,
|
userAgentFilterMode: FILTER_MODE_BLACKLIST,
|
||||||
|
clientLimiter: newClientRateLimiter(float64(cfg.MaxConnectionsPerMinute)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebServer) Run() {
|
func (ws *WebServer) Run() {
|
||||||
|
|
||||||
_, err := os.Lstat(ws.wwwFolder)
|
_, err := os.Lstat(ws.cfg.WWWFolder)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
log.Fatalln("Fatal: www folder " + ws.wwwFolder + " does not exist")
|
log.Fatalln("Fatal: www folder " + ws.cfg.WWWFolder + " does not exist")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.Fatalln("Fatal: " + err.Error())
|
log.Fatalln("Fatal: " + err.Error())
|
||||||
}
|
}
|
||||||
perms, err := fileperm.New(ws.wwwFolder)
|
perms, err := fileperm.New(ws.cfg.WWWFolder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln("Fatal: " + err.Error())
|
log.Fatalln("Fatal: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -67,8 +74,9 @@ func (ws *WebServer) Run() {
|
|||||||
ws.userAgentFilterMode, ws.userAgentFilter = ws.parseFilterPanics(FILENAME_USERAGENTFILTER)
|
ws.userAgentFilterMode, ws.userAgentFilter = ws.parseFilterPanics(FILENAME_USERAGENTFILTER)
|
||||||
|
|
||||||
http.HandleFunc("/", ws.get)
|
http.HandleFunc("/", ws.get)
|
||||||
log.Println("Server started on port " + strconv.Itoa(ws.port))
|
portStr := strconv.FormatUint(uint64(ws.cfg.Port), 10)
|
||||||
http.ListenAndServe(":"+strconv.Itoa(ws.port), nil)
|
log.Println("Server started on port " + portStr)
|
||||||
|
http.ListenAndServe(":"+portStr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
|
func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
|
||||||
@@ -76,10 +84,10 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
|
|||||||
filterMode := FILTER_MODE_BLACKLIST
|
filterMode := FILTER_MODE_BLACKLIST
|
||||||
filter := make([]string, 0)
|
filter := make([]string, 0)
|
||||||
|
|
||||||
os.Mkdir(ws.configFolder, PERMS_MKDIR)
|
os.Mkdir(ws.cfg.ConfigFolder, PERMS_MKDIR)
|
||||||
fileinfo, err := os.Stat(filepath.Join(ws.configFolder, fileName))
|
fileinfo, err := os.Stat(filepath.Join(ws.cfg.ConfigFolder, fileName))
|
||||||
|
|
||||||
fullPath := filepath.Join(ws.configFolder, fileName)
|
fullPath := filepath.Join(ws.cfg.ConfigFolder, fileName)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
os.Create(fullPath)
|
os.Create(fullPath)
|
||||||
fileinfo, err = os.Stat(fullPath)
|
fileinfo, err = os.Stat(fullPath)
|
||||||
@@ -126,19 +134,23 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebServer) isIpValid(ip string) bool {
|
func (ws *WebServer) isIpAllowed(ip string) bool {
|
||||||
ip = strings.Split(ip, ":")[0] // remove port
|
hostIp, _, err := net.SplitHostPort(ip)
|
||||||
|
if err != nil && !strings.Contains(ip, ":") {
|
||||||
|
log.Println(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
switch ws.ipFilterMode {
|
switch ws.ipFilterMode {
|
||||||
case FILTER_MODE_WHITELIST:
|
case FILTER_MODE_WHITELIST:
|
||||||
return slices.Contains(ws.ipFilter, ip)
|
return slices.Contains(ws.ipFilter, hostIp)
|
||||||
case FILTER_MODE_BLACKLIST:
|
case FILTER_MODE_BLACKLIST:
|
||||||
return !slices.Contains(ws.ipFilter, ip)
|
return !slices.Contains(ws.ipFilter, hostIp)
|
||||||
default:
|
default:
|
||||||
return false //if something went wrong with conf parsing
|
return false //if something went wrong with conf parsing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebServer) isUserAgentValid(userAgent string) bool {
|
func (ws *WebServer) isUserAgentAllowed(userAgent string) bool {
|
||||||
switch ws.userAgentFilterMode {
|
switch ws.userAgentFilterMode {
|
||||||
case FILTER_MODE_WHITELIST:
|
case FILTER_MODE_WHITELIST:
|
||||||
for _, userAgentFiltered := range ws.userAgentFilter {
|
for _, userAgentFiltered := range ws.userAgentFilter {
|
||||||
@@ -200,12 +212,19 @@ func (ws *WebServer) get(writer http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
respStatusCode := http.StatusOK
|
respStatusCode := http.StatusOK
|
||||||
|
|
||||||
if !ws.isIpValid(req.RemoteAddr) || !ws.isUserAgentValid(req.UserAgent()) {
|
// check that IP and User Agent of client are whitelisted, or not blacklisted
|
||||||
|
if !ws.isIpAllowed(req.RemoteAddr) || !ws.isUserAgentAllowed(req.UserAgent()) {
|
||||||
|
|
||||||
writer.WriteHeader(http.StatusForbidden)
|
writer.WriteHeader(http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// check that the client IP has not been sending too many requests recently
|
||||||
|
if !ws.clientLimiter.canConnect(req.RemoteAddr) {
|
||||||
|
writer.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
fileToFetch := filepath.Join(ws.wwwFolder, req.URL.Path)
|
fileToFetch := filepath.Join(ws.cfg.WWWFolder, req.URL.Path)
|
||||||
fetchedFile, fetchErr := ws.fetchFile(fileToFetch)
|
fetchedFile, fetchErr := ws.fetchFile(fileToFetch)
|
||||||
fetchedFileStat, _ := fetchedFile.Stat()
|
fetchedFileStat, _ := fetchedFile.Stat()
|
||||||
fetchedStat, _ := ws.fetchStat(fileToFetch)
|
fetchedStat, _ := ws.fetchStat(fileToFetch)
|
||||||
|
Reference in New Issue
Block a user