From 4669b2602c310a4f1efef1f550bc16c214622b26 Mon Sep 17 00:00:00 2001 From: Niccolo Borgioli Date: Sat, 25 May 2024 18:27:53 +0200 Subject: [PATCH] add tests --- app/file.go | 20 ++++--- app/http.go | 39 +++++++------ app/routes.go | 72 +++++++++++++++++++++++ app/server.go | 98 +++++++------------------------- app/server_test.go | 138 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 262 insertions(+), 105 deletions(-) create mode 100644 app/routes.go create mode 100644 app/server_test.go diff --git a/app/file.go b/app/file.go index c64414c..facf823 100644 --- a/app/file.go +++ b/app/file.go @@ -6,15 +6,21 @@ import ( "path/filepath" ) +var DIR string = "" + func getFilepath(filename string) string { - if len(os.Args) != 3 { - log.Fatal("Not enough args") + if DIR == "" { + + if len(os.Args) != 3 { + log.Fatal("Not enough args") + } + dir, err := filepath.Abs(os.Args[2]) + if err != nil { + log.Fatal(err) + } + DIR = dir } - dir, err := filepath.Abs(os.Args[2]) - if err != nil { - log.Fatal(err) - } - return filepath.Join(dir, filename) + return filepath.Join(DIR, filename) } func readFile(filename string) ([]byte, bool) { diff --git a/app/http.go b/app/http.go index fac2882..488bd6a 100644 --- a/app/http.go +++ b/app/http.go @@ -13,17 +13,17 @@ const ( HTTPDelimiter = "\r\n" ) -type Header struct { - Name string - Value string -} +// type Header struct { +// Name string +// Value string +// } type Request struct { Method string Path string Version string Body string BodyRaw []byte - Headers []Header + Headers map[string]string } type HttpCode struct { @@ -43,7 +43,7 @@ type Response struct { Version string Body string BodyRaw []byte - Headers []Header + Headers map[string]string } type StringRoute struct { @@ -63,27 +63,27 @@ type Routes struct { regexpRoutes []RegexRoute } -func Respond(conn net.Conn, response Response) { - fmt.Fprintf(conn, "%s %d %s%s", response.Version, response.Code.Code, response.Code.Message, HTTPDelimiter) +func Respond(conn net.Conn, req Request, res Response) { + fmt.Fprintf(conn, "%s %d %s%s", res.Version, res.Code.Code, res.Code.Message, HTTPDelimiter) bodySize := 0 - if response.Body != "" { - bodySize = len(response.Body) + if res.Body != "" { + bodySize = len(res.Body) } else { - bodySize = len(response.BodyRaw) + bodySize = len(res.BodyRaw) } if bodySize > 0 { - response.Headers = append(response.Headers, Header{Name: "Content-Length", Value: strconv.Itoa(bodySize)}) + res.Headers["Content-Length"] = strconv.Itoa(bodySize) } - for _, header := range response.Headers { - fmt.Fprintf(conn, "%s: %s%s", header.Name, header.Value, HTTPDelimiter) + for header, value := range res.Headers { + fmt.Fprintf(conn, "%s: %s%s", header, value, HTTPDelimiter) } fmt.Fprint(conn, HTTPDelimiter) if bodySize > 0 { - if response.Body != "" { - fmt.Fprint(conn, response.Body) + if res.Body != "" { + fmt.Fprint(conn, res.Body) } else { - conn.Write(response.BodyRaw) + conn.Write(res.BodyRaw) } } } @@ -97,7 +97,7 @@ func parseRequest(conn net.Conn) (Request, bool) { contents := string(buffer[:n]) parts := strings.Split(contents, HTTPDelimiter) - request := Request{} + request := Request{Headers: map[string]string{}} isBody := false for i, part := range parts { if i == 0 { @@ -122,8 +122,7 @@ func parseRequest(conn net.Conn) (Request, bool) { continue } h := strings.SplitN(part, ": ", 2) - header := Header{Name: h[0], Value: h[1]} - request.Headers = append(request.Headers, header) + request.Headers[h[0]] = h[1] } return request, true diff --git a/app/routes.go b/app/routes.go new file mode 100644 index 0000000..3d0514f --- /dev/null +++ b/app/routes.go @@ -0,0 +1,72 @@ +package main + +import ( + "regexp" +) + +var routes = Routes{ + stringRoutes: []StringRoute{ + // ROOT + {path: "/", method: "GET", handler: func(req Request) Response { + return Response{Version: req.Version, Code: OK} + }}, + + // USER AGENT + {path: "/user-agent", method: "GET", handler: func(req Request) Response { + userAgent := req.Headers["User-Agent"] + if userAgent == "" { + return Response{Version: req.Version, Code: BadRequest} + } + + return Response{ + Version: req.Version, + Code: OK, + Body: userAgent, + Headers: map[string]string{"Content-Type": "text/plain"}, + } + }}, + }, + + regexpRoutes: []RegexRoute{ + + // PATH PARAMETER + { + regex: regexp.MustCompile(`^/echo/([A-Za-z]+)$`), + method: "GET", + handler: func(req Request, matches []string) Response { + return Response{ + Version: req.Version, + Code: OK, + Body: matches[1], + Headers: map[string]string{"Content-Type": "text/plain"}, + } + }, + }, + + { + regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`), + method: "GET", + handler: func(req Request, matches []string) Response { + file, notFound := readFile(matches[1]) + if notFound { + return Response{Version: req.Version, Code: NotFound} + } + return Response{ + Version: req.Version, + Code: OK, + BodyRaw: file, + Headers: map[string]string{"Content-Type": "application/octet-stream"}, + } + }, + }, + + { + regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`), + method: "POST", + handler: func(req Request, matches []string) Response { + writeFile(matches[1], []byte(req.Body)) + return Response{Version: req.Version, Code: Created} + }, + }, + }, +} diff --git a/app/server.go b/app/server.go index 0411b8f..65a576b 100644 --- a/app/server.go +++ b/app/server.go @@ -4,38 +4,48 @@ import ( "fmt" "net" "os" - "regexp" ) +// type Handler = func(req Request, res Response) +// type Middleware = func(next Handler) Handler + +// var m Middleware = func(next Handler) Handler { +// return func(req Request, res Response) { +// fmt.Println("Start") +// next(req, res) +// fmt.Println("End") +// } +// } + func handleConnection(conn net.Conn, routes Routes) { defer conn.Close() - request, ok := parseRequest(conn) + req, ok := parseRequest(conn) if !ok { - Respond(conn, Response{Version: "HTTP/1.1", Code: BadRequest}) + Respond(conn, req, Response{Version: "HTTP/1.1", Code: BadRequest}) return } - fmt.Println(request) + fmt.Println(req) for _, route := range routes.stringRoutes { - if request.Path == route.path && request.Method == route.method { - Respond(conn, route.handler(request)) + if req.Path == route.path && req.Method == route.method { + Respond(conn, req, route.handler(req)) return } } for _, route := range routes.regexpRoutes { - if request.Method != route.method { + if req.Method != route.method { continue } - if matches := route.regex.FindStringSubmatch(request.Path); len(matches) > 0 { - Respond(conn, route.handler(request, matches)) + if matches := route.regex.FindStringSubmatch(req.Path); len(matches) > 0 { + Respond(conn, req, route.handler(req, matches)) return } } - Respond(conn, Response{Version: request.Version, Code: NotFound}) + Respond(conn, req, Response{Version: req.Version, Code: NotFound}) } func main() { @@ -47,74 +57,6 @@ func main() { os.Exit(1) } - routes := Routes{ - stringRoutes: []StringRoute{ - // ROOT - {path: "/", method: "GET", handler: func(req Request) Response { - return Response{Version: req.Version, Code: OK} - }}, - - // USER AGENT - {path: "/user-agent", method: "GET", handler: func(req Request) Response { - for _, header := range req.Headers { - if header.Name != "User-Agent" { - continue - } - return Response{ - Version: req.Version, - Code: OK, - Body: header.Value, - Headers: []Header{{Name: "Content-Type", Value: "text/plain"}}, - } - } - return Response{Version: req.Version, Code: BadRequest} - }}, - }, - - regexpRoutes: []RegexRoute{ - - // PATH PARAMETER - { - regex: regexp.MustCompile(`^/echo/([A-Za-z]+)$`), - method: "GET", - handler: func(req Request, matches []string) Response { - return Response{ - Version: req.Version, - Code: OK, - Body: matches[1], - Headers: []Header{{Name: "Content-Type", Value: "text/plain"}}, - } - }, - }, - - { - regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`), - method: "GET", - handler: func(req Request, matches []string) Response { - file, notFound := readFile(matches[1]) - if notFound { - return Response{Version: req.Version, Code: NotFound} - } - return Response{ - Version: req.Version, - Code: OK, - BodyRaw: file, - Headers: []Header{{Name: "Content-Type", Value: "application/octet-stream"}}, - } - }, - }, - - { - regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`), - method: "POST", - handler: func(req Request, matches []string) Response { - writeFile(matches[1], []byte(req.Body)) - return Response{Version: req.Version, Code: Created} - }, - }, - }, - } - for { conn, err := l.Accept() if err != nil { diff --git a/app/server_test.go b/app/server_test.go new file mode 100644 index 0000000..8b74c90 --- /dev/null +++ b/app/server_test.go @@ -0,0 +1,138 @@ +package main + +import ( + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "path" + "strconv" + "strings" + "testing" +) + +func getBody(body io.ReadCloser) []byte { + defer body.Close() + bodyBytes, err := io.ReadAll(body) + if err != nil { + log.Fatalf("Failed to read res body: %v", err) + } + return bodyBytes +} + +type Expected struct { + status int + body string + headers map[string]string +} + +func checkResponse(t *testing.T, res *http.Response, expected Expected) { + if res.StatusCode != expected.status { + t.Errorf("Expected status code %d, received %d", expected.status, res.StatusCode) + } + + body := string(getBody(res.Body)) + if body != expected.body { + t.Errorf(`Expected body to be "%s" but got "%s"`, expected.body, body) + } + + for header, value := range expected.headers { + if actual := res.Header[header][0]; actual != value { + t.Errorf(`Expected "%s" header to be "%s" but got "%s"`, header, value, actual) + } + } +} + +func TestRoot(t *testing.T) { + res, _ := http.Get("http://localhost:4221") + checkResponse(t, res, Expected{status: 200}) +} + +func TestNotFound(t *testing.T) { + res, _ := http.Get("http://localhost:4221/foo") + checkResponse(t, res, Expected{status: 404}) +} + +func TestEcho(t *testing.T) { + input := "abc" + res, _ := http.Get(fmt.Sprintf("http://localhost:4221/echo/%s", input)) + checkResponse(t, res, Expected{status: 200, body: input, headers: map[string]string{ + "Content-Length": strconv.Itoa(len(input)), + "Content-Type": "text/plain", + }}) +} + +func TestUserAgent(t *testing.T) { + input := "CodeCrafters/1.0" + req, _ := http.NewRequest("GET", "http://localhost:4221/user-agent", nil) + req.Header.Set("User-Agent", input) + client := &http.Client{} + res, _ := client.Do(req) + checkResponse(t, res, Expected{status: 200, body: input, headers: map[string]string{ + "Content-Length": strconv.Itoa(len(input)), + "Content-Type": "text/plain", + }}) +} +func TestUserAgentNoHeader(t *testing.T) { + req, _ := http.NewRequest("GET", "http://localhost:4221/user-agent", nil) + req.Header.Set("User-Agent", "") + client := &http.Client{} + res, _ := client.Do(req) + checkResponse(t, res, Expected{status: 400}) +} + +func TestReadFile(t *testing.T) { + input := "Hello World" + tmp, _ := os.CreateTemp("", "read.txt") + defer os.Remove(tmp.Name()) + os.WriteFile(tmp.Name(), []byte(input), 0755) + DIR = path.Dir(tmp.Name()) + + res, _ := http.Get(fmt.Sprintf("http://localhost:4221/files/%s", path.Base(tmp.Name()))) + checkResponse(t, res, Expected{status: 200, body: input, headers: map[string]string{ + "Content-Type": "application/octet-stream", + "Content-Length": strconv.Itoa(len(input)), + }}) +} + +func TestWriteFile(t *testing.T) { + input := "Hello World" + tmp, _ := os.CreateTemp("", "write.txt") + defer os.Remove(tmp.Name()) + DIR = path.Dir(tmp.Name()) + + res, _ := http.Post(fmt.Sprintf("http://localhost:4221/files/%s", path.Base(tmp.Name())), "application/octet-stream", strings.NewReader(input)) + checkResponse(t, res, Expected{status: 201}) + + contents, _ := os.ReadFile(tmp.Name()) + if string(contents) != input { + t.Errorf("Content written to file does not match the input") + } +} + +func TestMain(m *testing.M) { + fmt.Println("Starting server") + l, err := net.Listen("tcp", "0.0.0.0:4221") + if err != nil { + fmt.Println("Failed to bind to port 4221") + os.Exit(1) + } + + go func() { + for { + conn, err := l.Accept() + if err == nil { + go handleConnection(conn, routes) + } + } + }() + + code := m.Run() + + fmt.Println("Stopping server") + l.Close() + + os.Exit(code) +}