diff --git a/go.sum b/go.sum index fe202c5..db659b9 100644 --- a/go.sum +++ b/go.sum @@ -6,3 +6,5 @@ github.com/wneessen/go-fileperm v0.2.1 h1:VNZT41b8HJDY5zUw4TbwPtfU1DuxZ3lcGH4dXl github.com/wneessen/go-fileperm v0.2.1/go.mod h1:Isv0pfQJstXAlmGGJjLGqCK0Z6d1ehbbrsO2xmTRsKs= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= diff --git a/main.go b/main.go index 5bf0e05..c9d5eb1 100644 --- a/main.go +++ b/main.go @@ -9,11 +9,14 @@ import ( ) const ( - HELP_PORT string = "what port miniws will run on" - HELP_LOGFOLDER string = "the logs folder" - HELP_CONFIGFOLDER string = "the configurations folder" - HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve" - HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split" + HELP_PORT string = "what port miniws will run on" + HELP_LOGFOLDER string = "the logs folder" + HELP_CONFIGFOLDER string = "the configurations folder" + HELP_WWWFOLDER string = "the www folder where miniws will look for files to serve" + HELP_MAXLOGBYTES string = "the maximum bytes after which the log files get split" + HELP_MAXCLIENTRATE string = "the maximum number of requests per minute that any particular " + + "client can send. exceeding this rate will cause miniws to reply with HTTP error 429: " + + "Too Many Requests." ) func main() { @@ -24,6 +27,7 @@ func main() { configFolder := parser.String("c", "config-folder", &argparse.Options{Default: "config", Help: HELP_CONFIGFOLDER}) wwwFolder := parser.String("w", "www-folder", &argparse.Options{Default: ".", Help: HELP_WWWFOLDER}) maxLogBytes := parser.Int("b", "max-log-bytes", &argparse.Options{Default: 1048576, Help: HELP_MAXLOGBYTES}) + maxClientRatePerMin := parser.Int("r", "max-client-rate", &argparse.Options{Default: 600, Help: HELP_MAXCLIENTRATE}) err := parser.Parse(os.Args) if err != nil { @@ -33,6 +37,13 @@ func main() { return } - webserver := miniws.NewWebServer(*port, *logFolder, *configFolder, *wwwFolder, int64(*maxLogBytes)) + webserver := miniws.NewWebServer(miniws.WebServerConfig{ + LogFolder: *logFolder, + ConfigFolder: *configFolder, + WWWFolder: *wwwFolder, + Port: uint16(*port), + MaxBytesPerLogFile: uint64(*maxLogBytes), + MaxConnectionsPerMinute: uint64(*maxClientRatePerMin), + }) webserver.Run() } diff --git a/miniws/client_ratelimit.go b/miniws/client_ratelimit.go new file mode 100644 index 0000000..5d4dd18 --- /dev/null +++ b/miniws/client_ratelimit.go @@ -0,0 +1,51 @@ +package miniws + +import ( + "log" + "maps" + + "golang.org/x/time/rate" +) + +const CRL_MAX_CLIENT_REMEMBER_SEC float64 = 60. + +type clientRateLimiter struct { + limits map[string]*rate.Limiter + maxConnsPerSec float64 +} + +func newClientRateLimiter(maxConnectionsPerMin float64) *clientRateLimiter { + return &clientRateLimiter{ + limits: make(map[string]*rate.Limiter), + maxConnsPerSec: maxConnectionsPerMin / 60., + } +} + +func (crl *clientRateLimiter) canConnect(clientIp string) bool { + limiter, ok := crl.limits[clientIp] + if !ok { + crl.limits[clientIp] = rate.NewLimiter(rate.Limit(crl.maxConnsPerSec), 1) + limiter = crl.limits[clientIp] + log.Println("new client: " + clientIp) + } + allowed := limiter.Allow() + if allowed { + log.Println("client " + clientIp + " was allowed") + } else { + log.Println("client " + clientIp + " has been rate limited") + } + crl._cleanup() + return allowed +} + +func (crl *clientRateLimiter) _cleanup() { + maps.DeleteFunc(crl.limits, func(clIp string, limit *rate.Limiter) bool { + // forget the client if they haven't connected in CRL_MAX_CLIENT_REMEMBER_SEC seconds + forgetClient := limit.Tokens() >= CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec + if forgetClient { + log.Println("Forgetting client "+clIp+": accumulated tokens ", limit.Tokens(), + "exceed limit of ", CRL_MAX_CLIENT_REMEMBER_SEC*crl.maxConnsPerSec) + } + return forgetClient + }) +} diff --git a/miniws/go.mod b/miniws/go.mod index 22a0144..29ea303 100644 --- a/miniws/go.mod +++ b/miniws/go.mod @@ -1,3 +1,7 @@ module github.com/shlldev/miniws/miniws go 1.22.2 + +require ( + golang.org/x/time v0.12.0 // indirect +) \ No newline at end of file diff --git a/miniws/logger.go b/miniws/logger.go index e7e335d..9a09f02 100644 --- a/miniws/logger.go +++ b/miniws/logger.go @@ -18,10 +18,10 @@ const ( type Logger struct { logFolder string - maxLogBytes int64 + maxLogBytes uint64 } -func NewLogger(logFolder_ string, maxLogBytes_ int64) *Logger { +func NewLogger(logFolder_ string, maxLogBytes_ uint64) *Logger { return &Logger{ logFolder: logFolder_, maxLogBytes: maxLogBytes_, @@ -70,7 +70,7 @@ func (l *Logger) writeToLogFileAndRenameIfBig(fileName, content string) { return } - if fileinfo.Size() > l.maxLogBytes { + if uint64(fileinfo.Size()) > l.maxLogBytes { var renamedFiledPath string = fullPath + "." + uuid.NewString() diff --git a/miniws/webserver.go b/miniws/webserver.go index f3ac1fc..6ebf36f 100644 --- a/miniws/webserver.go +++ b/miniws/webserver.go @@ -4,6 +4,7 @@ import ( "errors" "log" "mime" + "net" "net/http" "os" "path/filepath" @@ -22,40 +23,46 @@ const ( type FilterMode int +type WebServerConfig struct { + LogFolder string + ConfigFolder string + WWWFolder string + Port uint16 + MaxBytesPerLogFile uint64 + MaxConnectionsPerMinute uint64 +} + type WebServer struct { logger *Logger - port int - configFolder string - wwwFolder string + cfg WebServerConfig ipFilter []string userAgentFilter []string ipFilterMode FilterMode userAgentFilterMode FilterMode + clientLimiter *clientRateLimiter } -func NewWebServer(port_ int, logFolder_, configFolder_, wwwFolder_ string, maxLogBytes_ int64) *WebServer { - +func NewWebServer(cfg WebServerConfig) *WebServer { return &WebServer{ - logger: NewLogger(logFolder_, maxLogBytes_), - port: port_, - configFolder: configFolder_, - wwwFolder: wwwFolder_, + logger: NewLogger(cfg.LogFolder, cfg.MaxBytesPerLogFile), + cfg: cfg, ipFilter: make([]string, 0), userAgentFilter: make([]string, 0), ipFilterMode: FILTER_MODE_BLACKLIST, userAgentFilterMode: FILTER_MODE_BLACKLIST, + clientLimiter: newClientRateLimiter(float64(cfg.MaxConnectionsPerMinute)), } } func (ws *WebServer) Run() { - _, err := os.Lstat(ws.wwwFolder) + _, err := os.Lstat(ws.cfg.WWWFolder) if errors.Is(err, os.ErrNotExist) { - log.Fatalln("Fatal: www folder " + ws.wwwFolder + " does not exist") + log.Fatalln("Fatal: www folder " + ws.cfg.WWWFolder + " does not exist") } else if err != nil { log.Fatalln("Fatal: " + err.Error()) } - perms, err := fileperm.New(ws.wwwFolder) + perms, err := fileperm.New(ws.cfg.WWWFolder) if err != nil { log.Fatalln("Fatal: " + err.Error()) } @@ -67,8 +74,9 @@ func (ws *WebServer) Run() { ws.userAgentFilterMode, ws.userAgentFilter = ws.parseFilterPanics(FILENAME_USERAGENTFILTER) http.HandleFunc("/", ws.get) - log.Println("Server started on port " + strconv.Itoa(ws.port)) - http.ListenAndServe(":"+strconv.Itoa(ws.port), nil) + portStr := strconv.FormatUint(uint64(ws.cfg.Port), 10) + log.Println("Server started on port " + portStr) + http.ListenAndServe(":"+portStr, nil) } func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) { @@ -76,10 +84,10 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) { filterMode := FILTER_MODE_BLACKLIST filter := make([]string, 0) - os.Mkdir(ws.configFolder, PERMS_MKDIR) - fileinfo, err := os.Stat(filepath.Join(ws.configFolder, fileName)) + os.Mkdir(ws.cfg.ConfigFolder, PERMS_MKDIR) + fileinfo, err := os.Stat(filepath.Join(ws.cfg.ConfigFolder, fileName)) - fullPath := filepath.Join(ws.configFolder, fileName) + fullPath := filepath.Join(ws.cfg.ConfigFolder, fileName) if errors.Is(err, os.ErrNotExist) { os.Create(fullPath) fileinfo, err = os.Stat(fullPath) @@ -126,19 +134,23 @@ func (ws *WebServer) parseFilterPanics(fileName string) (FilterMode, []string) { } -func (ws *WebServer) isIpValid(ip string) bool { - ip = strings.Split(ip, ":")[0] // remove port +func (ws *WebServer) isIpAllowed(ip string) bool { + hostIp, _, err := net.SplitHostPort(ip) + if err != nil && !strings.Contains(ip, ":") { + log.Println(err) + return false + } switch ws.ipFilterMode { case FILTER_MODE_WHITELIST: - return slices.Contains(ws.ipFilter, ip) + return slices.Contains(ws.ipFilter, hostIp) case FILTER_MODE_BLACKLIST: - return !slices.Contains(ws.ipFilter, ip) + return !slices.Contains(ws.ipFilter, hostIp) default: return false //if something went wrong with conf parsing } } -func (ws *WebServer) isUserAgentValid(userAgent string) bool { +func (ws *WebServer) isUserAgentAllowed(userAgent string) bool { switch ws.userAgentFilterMode { case FILTER_MODE_WHITELIST: for _, userAgentFiltered := range ws.userAgentFilter { @@ -200,12 +212,19 @@ func (ws *WebServer) get(writer http.ResponseWriter, req *http.Request) { respStatusCode := http.StatusOK - if !ws.isIpValid(req.RemoteAddr) || !ws.isUserAgentValid(req.UserAgent()) { + // check that IP and User Agent of client are whitelisted, or not blacklisted + if !ws.isIpAllowed(req.RemoteAddr) || !ws.isUserAgentAllowed(req.UserAgent()) { + writer.WriteHeader(http.StatusForbidden) return } + // check that the client IP has not been sending too many requests recently + if !ws.clientLimiter.canConnect(req.RemoteAddr) { + writer.WriteHeader(http.StatusTooManyRequests) + return + } - fileToFetch := filepath.Join(ws.wwwFolder, req.URL.Path) + fileToFetch := filepath.Join(ws.cfg.WWWFolder, req.URL.Path) fetchedFile, fetchErr := ws.fetchFile(fileToFetch) fetchedFileStat, _ := fetchedFile.Stat() fetchedStat, _ := ws.fetchStat(fileToFetch)