add tests

This commit is contained in:
Niccolo Borgioli 2024-05-25 18:27:53 +02:00
parent 393027f06c
commit 4669b2602c
No known key found for this signature in database
GPG Key ID: 4897ACD13A65977C
5 changed files with 262 additions and 105 deletions

View File

@ -6,15 +6,21 @@ import (
"path/filepath" "path/filepath"
) )
var DIR string = ""
func getFilepath(filename string) string { func getFilepath(filename string) string {
if len(os.Args) != 3 { if DIR == "" {
log.Fatal("Not enough args")
if len(os.Args) != 3 {
log.Fatal("Not enough args")
}
dir, err := filepath.Abs(os.Args[2])
if err != nil {
log.Fatal(err)
}
DIR = dir
} }
dir, err := filepath.Abs(os.Args[2]) return filepath.Join(DIR, filename)
if err != nil {
log.Fatal(err)
}
return filepath.Join(dir, filename)
} }
func readFile(filename string) ([]byte, bool) { func readFile(filename string) ([]byte, bool) {

View File

@ -13,17 +13,17 @@ const (
HTTPDelimiter = "\r\n" HTTPDelimiter = "\r\n"
) )
type Header struct { // type Header struct {
Name string // Name string
Value string // Value string
} // }
type Request struct { type Request struct {
Method string Method string
Path string Path string
Version string Version string
Body string Body string
BodyRaw []byte BodyRaw []byte
Headers []Header Headers map[string]string
} }
type HttpCode struct { type HttpCode struct {
@ -43,7 +43,7 @@ type Response struct {
Version string Version string
Body string Body string
BodyRaw []byte BodyRaw []byte
Headers []Header Headers map[string]string
} }
type StringRoute struct { type StringRoute struct {
@ -63,27 +63,27 @@ type Routes struct {
regexpRoutes []RegexRoute regexpRoutes []RegexRoute
} }
func Respond(conn net.Conn, response Response) { func Respond(conn net.Conn, req Request, res Response) {
fmt.Fprintf(conn, "%s %d %s%s", response.Version, response.Code.Code, response.Code.Message, HTTPDelimiter) fmt.Fprintf(conn, "%s %d %s%s", res.Version, res.Code.Code, res.Code.Message, HTTPDelimiter)
bodySize := 0 bodySize := 0
if response.Body != "" { if res.Body != "" {
bodySize = len(response.Body) bodySize = len(res.Body)
} else { } else {
bodySize = len(response.BodyRaw) bodySize = len(res.BodyRaw)
} }
if bodySize > 0 { if bodySize > 0 {
response.Headers = append(response.Headers, Header{Name: "Content-Length", Value: strconv.Itoa(bodySize)}) res.Headers["Content-Length"] = strconv.Itoa(bodySize)
} }
for _, header := range response.Headers { for header, value := range res.Headers {
fmt.Fprintf(conn, "%s: %s%s", header.Name, header.Value, HTTPDelimiter) fmt.Fprintf(conn, "%s: %s%s", header, value, HTTPDelimiter)
} }
fmt.Fprint(conn, HTTPDelimiter) fmt.Fprint(conn, HTTPDelimiter)
if bodySize > 0 { if bodySize > 0 {
if response.Body != "" { if res.Body != "" {
fmt.Fprint(conn, response.Body) fmt.Fprint(conn, res.Body)
} else { } else {
conn.Write(response.BodyRaw) conn.Write(res.BodyRaw)
} }
} }
} }
@ -97,7 +97,7 @@ func parseRequest(conn net.Conn) (Request, bool) {
contents := string(buffer[:n]) contents := string(buffer[:n])
parts := strings.Split(contents, HTTPDelimiter) parts := strings.Split(contents, HTTPDelimiter)
request := Request{} request := Request{Headers: map[string]string{}}
isBody := false isBody := false
for i, part := range parts { for i, part := range parts {
if i == 0 { if i == 0 {
@ -122,8 +122,7 @@ func parseRequest(conn net.Conn) (Request, bool) {
continue continue
} }
h := strings.SplitN(part, ": ", 2) h := strings.SplitN(part, ": ", 2)
header := Header{Name: h[0], Value: h[1]} request.Headers[h[0]] = h[1]
request.Headers = append(request.Headers, header)
} }
return request, true return request, true

72
app/routes.go Normal file
View File

@ -0,0 +1,72 @@
package main
import (
"regexp"
)
var routes = Routes{
stringRoutes: []StringRoute{
// ROOT
{path: "/", method: "GET", handler: func(req Request) Response {
return Response{Version: req.Version, Code: OK}
}},
// USER AGENT
{path: "/user-agent", method: "GET", handler: func(req Request) Response {
userAgent := req.Headers["User-Agent"]
if userAgent == "" {
return Response{Version: req.Version, Code: BadRequest}
}
return Response{
Version: req.Version,
Code: OK,
Body: userAgent,
Headers: map[string]string{"Content-Type": "text/plain"},
}
}},
},
regexpRoutes: []RegexRoute{
// PATH PARAMETER
{
regex: regexp.MustCompile(`^/echo/([A-Za-z]+)$`),
method: "GET",
handler: func(req Request, matches []string) Response {
return Response{
Version: req.Version,
Code: OK,
Body: matches[1],
Headers: map[string]string{"Content-Type": "text/plain"},
}
},
},
{
regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`),
method: "GET",
handler: func(req Request, matches []string) Response {
file, notFound := readFile(matches[1])
if notFound {
return Response{Version: req.Version, Code: NotFound}
}
return Response{
Version: req.Version,
Code: OK,
BodyRaw: file,
Headers: map[string]string{"Content-Type": "application/octet-stream"},
}
},
},
{
regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`),
method: "POST",
handler: func(req Request, matches []string) Response {
writeFile(matches[1], []byte(req.Body))
return Response{Version: req.Version, Code: Created}
},
},
},
}

View File

@ -4,38 +4,48 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"regexp"
) )
// type Handler = func(req Request, res Response)
// type Middleware = func(next Handler) Handler
// var m Middleware = func(next Handler) Handler {
// return func(req Request, res Response) {
// fmt.Println("Start")
// next(req, res)
// fmt.Println("End")
// }
// }
func handleConnection(conn net.Conn, routes Routes) { func handleConnection(conn net.Conn, routes Routes) {
defer conn.Close() defer conn.Close()
request, ok := parseRequest(conn) req, ok := parseRequest(conn)
if !ok { if !ok {
Respond(conn, Response{Version: "HTTP/1.1", Code: BadRequest}) Respond(conn, req, Response{Version: "HTTP/1.1", Code: BadRequest})
return return
} }
fmt.Println(request) fmt.Println(req)
for _, route := range routes.stringRoutes { for _, route := range routes.stringRoutes {
if request.Path == route.path && request.Method == route.method { if req.Path == route.path && req.Method == route.method {
Respond(conn, route.handler(request)) Respond(conn, req, route.handler(req))
return return
} }
} }
for _, route := range routes.regexpRoutes { for _, route := range routes.regexpRoutes {
if request.Method != route.method { if req.Method != route.method {
continue continue
} }
if matches := route.regex.FindStringSubmatch(request.Path); len(matches) > 0 { if matches := route.regex.FindStringSubmatch(req.Path); len(matches) > 0 {
Respond(conn, route.handler(request, matches)) Respond(conn, req, route.handler(req, matches))
return return
} }
} }
Respond(conn, Response{Version: request.Version, Code: NotFound}) Respond(conn, req, Response{Version: req.Version, Code: NotFound})
} }
func main() { func main() {
@ -47,74 +57,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
routes := Routes{
stringRoutes: []StringRoute{
// ROOT
{path: "/", method: "GET", handler: func(req Request) Response {
return Response{Version: req.Version, Code: OK}
}},
// USER AGENT
{path: "/user-agent", method: "GET", handler: func(req Request) Response {
for _, header := range req.Headers {
if header.Name != "User-Agent" {
continue
}
return Response{
Version: req.Version,
Code: OK,
Body: header.Value,
Headers: []Header{{Name: "Content-Type", Value: "text/plain"}},
}
}
return Response{Version: req.Version, Code: BadRequest}
}},
},
regexpRoutes: []RegexRoute{
// PATH PARAMETER
{
regex: regexp.MustCompile(`^/echo/([A-Za-z]+)$`),
method: "GET",
handler: func(req Request, matches []string) Response {
return Response{
Version: req.Version,
Code: OK,
Body: matches[1],
Headers: []Header{{Name: "Content-Type", Value: "text/plain"}},
}
},
},
{
regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`),
method: "GET",
handler: func(req Request, matches []string) Response {
file, notFound := readFile(matches[1])
if notFound {
return Response{Version: req.Version, Code: NotFound}
}
return Response{
Version: req.Version,
Code: OK,
BodyRaw: file,
Headers: []Header{{Name: "Content-Type", Value: "application/octet-stream"}},
}
},
},
{
regex: regexp.MustCompile(`^/files/([A-Za-z0-9_\-.]+)`),
method: "POST",
handler: func(req Request, matches []string) Response {
writeFile(matches[1], []byte(req.Body))
return Response{Version: req.Version, Code: Created}
},
},
},
}
for { for {
conn, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {

138
app/server_test.go Normal file
View File

@ -0,0 +1,138 @@
package main
import (
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"path"
"strconv"
"strings"
"testing"
)
func getBody(body io.ReadCloser) []byte {
defer body.Close()
bodyBytes, err := io.ReadAll(body)
if err != nil {
log.Fatalf("Failed to read res body: %v", err)
}
return bodyBytes
}
type Expected struct {
status int
body string
headers map[string]string
}
func checkResponse(t *testing.T, res *http.Response, expected Expected) {
if res.StatusCode != expected.status {
t.Errorf("Expected status code %d, received %d", expected.status, res.StatusCode)
}
body := string(getBody(res.Body))
if body != expected.body {
t.Errorf(`Expected body to be "%s" but got "%s"`, expected.body, body)
}
for header, value := range expected.headers {
if actual := res.Header[header][0]; actual != value {
t.Errorf(`Expected "%s" header to be "%s" but got "%s"`, header, value, actual)
}
}
}
func TestRoot(t *testing.T) {
res, _ := http.Get("http://localhost:4221")
checkResponse(t, res, Expected{status: 200})
}
func TestNotFound(t *testing.T) {
res, _ := http.Get("http://localhost:4221/foo")
checkResponse(t, res, Expected{status: 404})
}
func TestEcho(t *testing.T) {
input := "abc"
res, _ := http.Get(fmt.Sprintf("http://localhost:4221/echo/%s", input))
checkResponse(t, res, Expected{status: 200, body: input, headers: map[string]string{
"Content-Length": strconv.Itoa(len(input)),
"Content-Type": "text/plain",
}})
}
func TestUserAgent(t *testing.T) {
input := "CodeCrafters/1.0"
req, _ := http.NewRequest("GET", "http://localhost:4221/user-agent", nil)
req.Header.Set("User-Agent", input)
client := &http.Client{}
res, _ := client.Do(req)
checkResponse(t, res, Expected{status: 200, body: input, headers: map[string]string{
"Content-Length": strconv.Itoa(len(input)),
"Content-Type": "text/plain",
}})
}
func TestUserAgentNoHeader(t *testing.T) {
req, _ := http.NewRequest("GET", "http://localhost:4221/user-agent", nil)
req.Header.Set("User-Agent", "")
client := &http.Client{}
res, _ := client.Do(req)
checkResponse(t, res, Expected{status: 400})
}
func TestReadFile(t *testing.T) {
input := "Hello World"
tmp, _ := os.CreateTemp("", "read.txt")
defer os.Remove(tmp.Name())
os.WriteFile(tmp.Name(), []byte(input), 0755)
DIR = path.Dir(tmp.Name())
res, _ := http.Get(fmt.Sprintf("http://localhost:4221/files/%s", path.Base(tmp.Name())))
checkResponse(t, res, Expected{status: 200, body: input, headers: map[string]string{
"Content-Type": "application/octet-stream",
"Content-Length": strconv.Itoa(len(input)),
}})
}
func TestWriteFile(t *testing.T) {
input := "Hello World"
tmp, _ := os.CreateTemp("", "write.txt")
defer os.Remove(tmp.Name())
DIR = path.Dir(tmp.Name())
res, _ := http.Post(fmt.Sprintf("http://localhost:4221/files/%s", path.Base(tmp.Name())), "application/octet-stream", strings.NewReader(input))
checkResponse(t, res, Expected{status: 201})
contents, _ := os.ReadFile(tmp.Name())
if string(contents) != input {
t.Errorf("Content written to file does not match the input")
}
}
func TestMain(m *testing.M) {
fmt.Println("Starting server")
l, err := net.Listen("tcp", "0.0.0.0:4221")
if err != nil {
fmt.Println("Failed to bind to port 4221")
os.Exit(1)
}
go func() {
for {
conn, err := l.Accept()
if err == nil {
go handleConnection(conn, routes)
}
}
}()
code := m.Run()
fmt.Println("Stopping server")
l.Close()
os.Exit(code)
}