blob: 211412744884a6a7a198db322af8384a4855358e [file] [log] [blame]
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04001package main
2
3import (
4 "bytes"
5 "context"
6 "crypto/tls"
Davit Tabidze5f00a392024-08-13 18:37:02 +04007 "embed"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04008 "encoding/json"
9 "flag"
10 "fmt"
Davit Tabidze5f00a392024-08-13 18:37:02 +040011 "html/template"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040012 "io"
13 "log"
14 "net/http"
15 "net/http/cookiejar"
16 "net/url"
Davit Tabidze5f00a392024-08-13 18:37:02 +040017 "slices"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040018 "strings"
19)
20
// Command-line configuration of the authenticating proxy.
var (
	// port is the local TCP port the proxy listens on.
	port = flag.Int("port", 3000, "Port to listen on")

	// whoAmIAddr is the Kratos endpoint queried to resolve a session.
	whoAmIAddr = flag.String("whoami-addr", "", "Kratos whoami endpoint address")

	// loginAddr is where unauthenticated GET requests are redirected.
	loginAddr = flag.String("login-addr", "", "Login page address")

	// membershipAddr is the base URL of the group membership API.
	membershipAddr = flag.String("membership-addr", "", "Group membership API endpoint")

	// membershipPublicAddr is handed to the unauthorized page template.
	membershipPublicAddr = flag.String("membership-public-addr", "", "Public address of membership service")

	// groups is the comma separated allow-list of groups; empty disables the check.
	groups = flag.String("groups", "", "Comma separated list of groups. User must be part of at least one of them. If empty group membership will not be checked.")

	// upstream is the host[:port] that authenticated requests are proxied to.
	upstream = flag.String("upstream", "", "Upstream service address")
)
28
Davit Tabidze5f00a392024-08-13 18:37:02 +040029//go:embed unauthorized.html
30var unauthorizedHTML embed.FS
31
32//go:embed static/*
33var f embed.FS
34
// cachingHandler decorates another handler so that every response it produces
// carries a one-week (604800 s) Cache-Control max-age.
type cachingHandler struct {
	h http.Handler
}

// ServeHTTP stamps the Cache-Control header and then hands the request to the
// wrapped handler.
func (c cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const oneWeek = "max-age=604800"
	headers := w.Header()
	headers.Set("Cache-Control", oneWeek)
	next := c.h
	next.ServeHTTP(w, r)
}
43
// user is the part of a Kratos whoami response this proxy reads: the username
// trait of the session's identity.
type user struct {
	Identity userIdentity `json:"identity"`
}

// userIdentity mirrors the "identity" object of the whoami response.
type userIdentity struct {
	Traits userTraits `json:"traits"`
}

// userTraits mirrors the identity's "traits" object.
type userTraits struct {
	Username string `json:"username"`
}
51
// authError is the part of a Kratos error response this proxy reads: the
// textual status, compared against "Unauthorized" by queryWhoAmI.
type authError struct {
	Error authErrorDetail `json:"error"`
}

// authErrorDetail mirrors the "error" object of the error response.
type authErrorDetail struct {
	Status string `json:"status"`
}
57
// getAddr reconstructs the URL the client originally asked for from the
// X-Forwarded-Scheme and X-Forwarded-Host request headers (presumably set by a
// fronting reverse proxy — confirm deployment) plus the request's path and
// query.
//
// It returns an error when either header is missing or empty; previously the
// missing-header case panicked with an index out of range.
func getAddr(r *http.Request) (*url.URL, error) {
	scheme := r.Header.Get("X-Forwarded-Scheme")
	if scheme == "" {
		return nil, fmt.Errorf("request is missing the X-Forwarded-Scheme header")
	}
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		return nil, fmt.Errorf("request is missing the X-Forwarded-Host header")
	}
	return url.Parse(fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()))
}
65
// funcMap lists the helper functions available inside the unauthorized page
// template.
var funcMap = template.FuncMap{
	// IsLast reports whether index is the position of the final element of
	// slice; it is always false for an empty slice.
	"IsLast": func(index int, slice []string) bool {
		lastIndex := len(slice) - 1
		return lastIndex == index
	},
}
71
// UnauthorizedPageData is the value passed to the unauthorized.html template.
type UnauthorizedPageData struct {
	// Groups are the group names shown to the user; membership in any one of
	// them grants access.
	Groups []string

	// MembershipPublicAddr is the value of the -membership-public-addr flag.
	MembershipPublicAddr string
}
76
77func renderUnauthorizedPage(w http.ResponseWriter, groups []string) {
78 tmpl, err := template.New("unauthorized.html").Funcs(funcMap).ParseFS(unauthorizedHTML, "unauthorized.html")
79 if err != nil {
80 http.Error(w, "Failed to load template", http.StatusInternalServerError)
81 return
82 }
83 data := UnauthorizedPageData{
84 MembershipPublicAddr: *membershipPublicAddr,
85 Groups: groups,
86 }
87 w.Header().Set("Content-Type", "text/html")
88 w.WriteHeader(http.StatusUnauthorized)
89 if err := tmpl.Execute(w, data); err != nil {
90 http.Error(w, "Failed render template", http.StatusInternalServerError)
91 }
92}
93
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040094func handle(w http.ResponseWriter, r *http.Request) {
95 user, err := queryWhoAmI(r.Cookies())
96 if err != nil {
97 http.Error(w, err.Error(), http.StatusInternalServerError)
98 return
99 }
100 if user == nil {
101 if r.Method != http.MethodGet {
102 http.Error(w, "Unauthorized", http.StatusUnauthorized)
103 return
104 }
105 curr, err := getAddr(r)
106 if err != nil {
107 http.Error(w, err.Error(), http.StatusInternalServerError)
108 return
109 }
110 addr := fmt.Sprintf("%s?return_to=%s", *loginAddr, curr.String())
111 http.Redirect(w, r, addr, http.StatusSeeOther)
112 return
113 }
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400114 if *groups != "" {
115 hasPermission := false
116 tg, err := getTransitiveGroups(user.Identity.Traits.Username)
117 if err != nil {
118 http.Error(w, err.Error(), http.StatusInternalServerError)
119 return
120 }
121 for _, i := range strings.Split(*groups, ",") {
122 if slices.Contains(tg, strings.TrimSpace(i)) {
123 hasPermission = true
124 break
125 }
126 }
127 if !hasPermission {
Davit Tabidze5f00a392024-08-13 18:37:02 +0400128 groupList := strings.Split(*groups, ",")
129 renderUnauthorizedPage(w, groupList)
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400130 return
131 }
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400132 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400133 rc := r.Clone(context.Background())
134 rc.Header.Add("X-User", user.Identity.Traits.Username)
135 ru, err := url.Parse(fmt.Sprintf("http://%s%s", *upstream, r.URL.RequestURI()))
136 if err != nil {
137 http.Error(w, err.Error(), http.StatusInternalServerError)
138 return
139 }
140 rc.URL = ru
141 rc.RequestURI = ""
142 client := &http.Client{
143 Transport: &http.Transport{
144 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
145 },
146 CheckRedirect: func(req *http.Request, via []*http.Request) error {
147 return http.ErrUseLastResponse
148 },
149 }
150 resp, err := client.Do(rc)
151 if err != nil {
152 http.Error(w, err.Error(), http.StatusInternalServerError)
153 return
154 }
155 for name, values := range resp.Header {
156 for _, value := range values {
157 w.Header().Add(name, value)
158 }
159 }
160 w.WriteHeader(resp.StatusCode)
161 if _, err := io.Copy(w, resp.Body); err != nil {
162 http.Error(w, err.Error(), http.StatusInternalServerError)
163 return
164 }
165}
166
167func queryWhoAmI(cookies []*http.Cookie) (*user, error) {
168 jar, err := cookiejar.New(nil)
169 if err != nil {
170 return nil, err
171 }
172 client := &http.Client{
173 Jar: jar,
174 Transport: &http.Transport{
175 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
176 },
177 }
178 addr, err := url.Parse(*whoAmIAddr)
179 if err != nil {
180 return nil, err
181 }
182 client.Jar.SetCookies(addr, cookies)
183 resp, err := client.Get(*whoAmIAddr)
184 if err != nil {
185 return nil, err
186 }
187 data := make(map[string]any)
188 if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
189 return nil, err
190 }
191 // TODO(gio): remove debugging
192 b, err := json.MarshalIndent(data, "", " ")
193 if err != nil {
194 return nil, err
195 }
196 fmt.Println(string(b))
197 var buf bytes.Buffer
198 if err := json.NewEncoder(&buf).Encode(data); err != nil {
199 return nil, err
200 }
201 tmp := buf.String()
202 if resp.StatusCode == http.StatusOK {
203 u := &user{}
204 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(u); err != nil {
205 return nil, err
206 }
207 return u, nil
208 }
209 e := &authError{}
210 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(e); err != nil {
211 return nil, err
212 }
213 if e.Error.Status == "Unauthorized" {
214 return nil, nil
215 }
216 return nil, fmt.Errorf("Unknown error: %s", tmp)
217}
218
// MembershipInfo is the JSON document the membership service returns for a
// user.
type MembershipInfo struct {
	// MemberOf lists the names of the groups reported for the user.
	MemberOf []string `json:"memberOf"`
}
222
223func getTransitiveGroups(user string) ([]string, error) {
224 resp, err := http.Get(fmt.Sprintf("%s/%s", *membershipAddr, user))
225 if err != nil {
226 return nil, err
227 }
228 var info MembershipInfo
229 if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
230 return nil, err
231 }
232 return info.MemberOf, nil
233}
234
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400235func main() {
236 flag.Parse()
Davit Tabidze5f00a392024-08-13 18:37:02 +0400237 if *groups != "" && (*membershipAddr == "" || *membershipPublicAddr == "") {
238 log.Fatal("membership-addr and membership-public-addr flags are required when groups are provided")
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400239 }
Davit Tabidze5f00a392024-08-13 18:37:02 +0400240 http.Handle("/static/", cachingHandler{http.FileServer(http.FS(f))})
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400241 http.HandleFunc("/", handle)
242 fmt.Printf("Starting HTTP server on port: %d\n", *port)
243 log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil))
244}