Skip to content
This repository was archived by the owner on Dec 23, 2024. It is now read-only.

Commit da1a1c0

Browse files
committed
get header middleware
1 parent 759e085 commit da1a1c0

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

middleware_getheader.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package rye
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
type getHeader struct {
9+
headerName string
10+
contextKey string
11+
}
12+
13+
/*
14+
NewMiddlewareGetHeader creates a new handler to extract any header and save its value into the context.
15+
headerName: the name of the header you want to extract
16+
contextKey: the value key that you would like to store this header under in the context
17+
18+
Example usage:
19+
20+
routes.Handle("/some/route", a.Dependencies.MWHandler.Handle(
21+
[]rye.Handler{
22+
rye.NewMiddlewareGetHeader(headerName, contextKey),
23+
yourHandler,
24+
})).Methods("POST")
25+
*/
26+
func NewMiddlewareGetHeader(headerName, contextKey string) func(rw http.ResponseWriter, req *http.Request) *Response {
27+
h := getHeader{headerName: headerName, contextKey: contextKey}
28+
return h.getHeaderMiddleware
29+
}
30+
31+
func (h *getHeader) getHeaderMiddleware(rw http.ResponseWriter, r *http.Request) *Response {
32+
rID := r.Header.Get(h.headerName)
33+
if rID != "" {
34+
return &Response{
35+
Context: context.WithValue(r.Context(), h.contextKey, rID),
36+
}
37+
}
38+
39+
return nil
40+
}

middleware_getheader_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package rye
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
7+
. "github.com/onsi/ginkgo"
8+
. "github.com/onsi/gomega"
9+
)
10+
11+
var _ = Describe("Get Header Middleware", func() {
12+
var (
13+
request *http.Request
14+
response *httptest.ResponseRecorder
15+
)
16+
17+
BeforeEach(func() {
18+
response = httptest.NewRecorder()
19+
request = &http.Request{
20+
Header: make(map[string][]string, 0),
21+
}
22+
})
23+
24+
Describe("getHeaderMiddleware", func() {
25+
Context("when a valid header is passed", func() {
26+
It("should return context with value", func() {
27+
headerName := "SpecialHeader"
28+
ctxKey := "special"
29+
request.Header.Add(headerName, "secret value")
30+
resp := NewMiddlewareGetHeader(headerName, ctxKey)(response, request)
31+
Expect(resp).ToNot(BeNil())
32+
Expect(resp.Context).ToNot(BeNil())
33+
Expect(resp.Context.Value(ctxKey)).To(Equal("secret value"))
34+
})
35+
})
36+
37+
Context("when no header is passed", func() {
38+
It("should have no value in context", func() {
39+
resp := NewMiddlewareGetHeader("something", "not there")(response, request)
40+
Expect(resp).To(BeNil())
41+
})
42+
})
43+
})
44+
})

0 commit comments

Comments
 (0)