blob: adf789ccd11722f96f12f183aaec4b9a6bdd6bc2 [file] [log] [blame]
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04001package main
2
3import (
4 "bytes"
5 "context"
6 "crypto/tls"
Davit Tabidze5f00a392024-08-13 18:37:02 +04007 "embed"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04008 "encoding/json"
9 "flag"
10 "fmt"
Davit Tabidze5f00a392024-08-13 18:37:02 +040011 "html/template"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040012 "io"
13 "log"
14 "net/http"
15 "net/http/cookiejar"
16 "net/url"
Davit Tabidze5f00a392024-08-13 18:37:02 +040017 "slices"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040018 "strings"
19)
20
21var port = flag.Int("port", 3000, "Port to listen on")
22var whoAmIAddr = flag.String("whoami-addr", "", "Kratos whoami endpoint address")
23var loginAddr = flag.String("login-addr", "", "Login page address")
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +040024var membershipAddr = flag.String("membership-addr", "", "Group membership API endpoint")
Davit Tabidze5f00a392024-08-13 18:37:02 +040025var membershipPublicAddr = flag.String("membership-public-addr", "", "Public address of membership service")
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +040026var groups = flag.String("groups", "", "Comma separated list of groups. User must be part of at least one of them. If empty group membership will not be checked.")
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040027var upstream = flag.String("upstream", "", "Upstream service address")
gioc81a8472024-09-24 13:06:19 +020028var noAuthPathPrefixes = flag.String("no-auth-path-prefixes", "", "Path prefixes to disable authentication for")
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040029
Davit Tabidze5f00a392024-08-13 18:37:02 +040030//go:embed unauthorized.html
31var unauthorizedHTML embed.FS
32
33//go:embed static/*
34var f embed.FS
35
// cachingHandler wraps another handler and marks every response it serves
// as cacheable by clients for one week (604800 seconds).
type cachingHandler struct {
	h http.Handler
}

// ServeHTTP sets Cache-Control before delegating, so the header is already
// in place when the wrapped handler starts writing its response.
func (c cachingHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	const oneWeek = "max-age=604800"
	rw.Header().Set("Cache-Control", oneWeek)
	c.h.ServeHTTP(rw, req)
}
44
// user is the subset of a successful whoami response this proxy consumes.
type user struct {
	Identity userIdentity `json:"identity"`
}

// userIdentity carries the identity ID and its traits.
type userIdentity struct {
	Id     string     `json:"id"`
	Traits userTraits `json:"traits"`
}

// userTraits holds the identity traits; only the username is read.
type userTraits struct {
	Username string `json:"username"`
}
53
// authError is the error envelope returned by the whoami endpoint on a
// non-200 response.
type authError struct {
	Error authErrorDetail `json:"error"`
}

// authErrorDetail carries the textual status compared against
// "Unauthorized" by queryWhoAmI.
type authErrorDetail struct {
	Status string `json:"status"`
}
59
// getAddr reconstructs the absolute URL the client originally requested,
// so it can be handed to the login page as a return target.
//
// Scheme and host are taken from the X-Forwarded-Scheme and X-Forwarded-Host
// request headers. If either header is absent (for example when the service
// is reached without a forwarding proxy), it falls back to the connection's
// TLS state and the request Host. The original code indexed the header slices
// directly, which panicked whenever a header was missing. An error is returned
// if no host can be determined at all.
func getAddr(r *http.Request) (*url.URL, error) {
	scheme := r.Header.Get("X-Forwarded-Scheme")
	if scheme == "" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
	}
	if host == "" {
		return nil, fmt.Errorf("cannot determine request host")
	}
	return url.Parse(fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()))
}
67
// funcMap lists the helper functions made available to the page templates.
var funcMap = template.FuncMap{
	"IsLast": isLastIndex,
}

// isLastIndex reports whether i is the index of the final element of items.
// Templates call it as IsLast.
func isLastIndex(i int, items []string) bool {
	last := len(items) - 1
	return i == last
}
73
// UnauthorizedPageData is the value passed to the unauthorized.html template.
type UnauthorizedPageData struct {
	// MembershipPublicAddr is the user-facing address of the membership
	// service, taken from the -membership-public-addr flag.
	MembershipPublicAddr string
	// Groups lists the configured groups; membership in any one of them
	// grants access.
	Groups []string
}
78
79func renderUnauthorizedPage(w http.ResponseWriter, groups []string) {
80 tmpl, err := template.New("unauthorized.html").Funcs(funcMap).ParseFS(unauthorizedHTML, "unauthorized.html")
81 if err != nil {
82 http.Error(w, "Failed to load template", http.StatusInternalServerError)
83 return
84 }
85 data := UnauthorizedPageData{
86 MembershipPublicAddr: *membershipPublicAddr,
87 Groups: groups,
88 }
89 w.Header().Set("Content-Type", "text/html")
90 w.WriteHeader(http.StatusUnauthorized)
91 if err := tmpl.Execute(w, data); err != nil {
92 http.Error(w, "Failed render template", http.StatusInternalServerError)
93 }
94}
95
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040096func handle(w http.ResponseWriter, r *http.Request) {
gioc81a8472024-09-24 13:06:19 +020097 reqAuth := true
98 for _, p := range strings.Split(*noAuthPathPrefixes, ",") {
giodd213152024-09-27 11:26:59 +020099 t := strings.TrimSpace(p)
100 if len(t) > 0 && strings.HasPrefix(r.URL.Path, t) {
gioc81a8472024-09-24 13:06:19 +0200101 reqAuth = false
102 break
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400103 }
gioc81a8472024-09-24 13:06:19 +0200104 }
105 var user *user
106 if reqAuth {
107 var err error
108 user, err = queryWhoAmI(r.Cookies())
giodd213152024-09-27 11:26:59 +0200109 fmt.Printf("--- %+v\n", user)
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400110 if err != nil {
111 http.Error(w, err.Error(), http.StatusInternalServerError)
112 return
113 }
gioc81a8472024-09-24 13:06:19 +0200114 if user == nil {
115 if r.Method != http.MethodGet {
116 http.Error(w, "Unauthorized", http.StatusUnauthorized)
117 return
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400118 }
gioc81a8472024-09-24 13:06:19 +0200119 curr, err := getAddr(r)
120 if err != nil {
121 http.Error(w, err.Error(), http.StatusInternalServerError)
122 return
123 }
124 addr := fmt.Sprintf("%s?return_to=%s", *loginAddr, curr.String())
125 http.Redirect(w, r, addr, http.StatusSeeOther)
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400126 return
127 }
gioc81a8472024-09-24 13:06:19 +0200128 if *groups != "" {
129 hasPermission := false
130 tg, err := getTransitiveGroups(user.Identity.Traits.Username)
131 if err != nil {
132 http.Error(w, err.Error(), http.StatusInternalServerError)
133 return
134 }
135 for _, i := range strings.Split(*groups, ",") {
136 if slices.Contains(tg, strings.TrimSpace(i)) {
137 hasPermission = true
138 break
139 }
140 }
141 if !hasPermission {
142 groupList := strings.Split(*groups, ",")
143 renderUnauthorizedPage(w, groupList)
144 return
145 }
146 }
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400147 }
giodd213152024-09-27 11:26:59 +0200148 fmt.Printf("%+v\n", user)
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400149 rc := r.Clone(context.Background())
gioc8faeac2024-10-09 15:26:16 +0400150 if user != nil {
151 rc.Header.Add("X-Forwarded-User", user.Identity.Traits.Username)
152 rc.Header.Add("X-Forwarded-UserId", user.Identity.Id)
153 } else {
154 delete(rc.Header, "X-Forwarded-User")
155 delete(rc.Header, "X-Forwarded-UserId")
156 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400157 ru, err := url.Parse(fmt.Sprintf("http://%s%s", *upstream, r.URL.RequestURI()))
158 if err != nil {
159 http.Error(w, err.Error(), http.StatusInternalServerError)
160 return
161 }
162 rc.URL = ru
163 rc.RequestURI = ""
164 client := &http.Client{
165 Transport: &http.Transport{
166 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
167 },
168 CheckRedirect: func(req *http.Request, via []*http.Request) error {
169 return http.ErrUseLastResponse
170 },
171 }
172 resp, err := client.Do(rc)
173 if err != nil {
174 http.Error(w, err.Error(), http.StatusInternalServerError)
175 return
176 }
177 for name, values := range resp.Header {
178 for _, value := range values {
179 w.Header().Add(name, value)
180 }
181 }
182 w.WriteHeader(resp.StatusCode)
183 if _, err := io.Copy(w, resp.Body); err != nil {
184 http.Error(w, err.Error(), http.StatusInternalServerError)
185 return
186 }
187}
188
189func queryWhoAmI(cookies []*http.Cookie) (*user, error) {
190 jar, err := cookiejar.New(nil)
191 if err != nil {
192 return nil, err
193 }
194 client := &http.Client{
195 Jar: jar,
196 Transport: &http.Transport{
197 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
198 },
199 }
200 addr, err := url.Parse(*whoAmIAddr)
201 if err != nil {
202 return nil, err
203 }
204 client.Jar.SetCookies(addr, cookies)
205 resp, err := client.Get(*whoAmIAddr)
206 if err != nil {
207 return nil, err
208 }
209 data := make(map[string]any)
210 if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
211 return nil, err
212 }
213 // TODO(gio): remove debugging
214 b, err := json.MarshalIndent(data, "", " ")
215 if err != nil {
216 return nil, err
217 }
218 fmt.Println(string(b))
219 var buf bytes.Buffer
220 if err := json.NewEncoder(&buf).Encode(data); err != nil {
221 return nil, err
222 }
223 tmp := buf.String()
224 if resp.StatusCode == http.StatusOK {
225 u := &user{}
226 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(u); err != nil {
227 return nil, err
228 }
229 return u, nil
230 }
231 e := &authError{}
232 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(e); err != nil {
233 return nil, err
234 }
235 if e.Error.Status == "Unauthorized" {
236 return nil, nil
237 }
238 return nil, fmt.Errorf("Unknown error: %s", tmp)
239}
240
// MembershipInfo is the JSON body returned by the membership API for a
// single user. Only the group list is consumed here.
type MembershipInfo struct {
	// MemberOf lists the user's groups. Presumably this includes groups
	// reached through nesting, judging by the caller's name
	// (getTransitiveGroups) — confirm against the membership service.
	MemberOf []string `json:"memberOf"`
}
244
245func getTransitiveGroups(user string) ([]string, error) {
246 resp, err := http.Get(fmt.Sprintf("%s/%s", *membershipAddr, user))
247 if err != nil {
248 return nil, err
249 }
250 var info MembershipInfo
251 if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
252 return nil, err
253 }
254 return info.MemberOf, nil
255}
256
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400257func main() {
258 flag.Parse()
Davit Tabidze5f00a392024-08-13 18:37:02 +0400259 if *groups != "" && (*membershipAddr == "" || *membershipPublicAddr == "") {
260 log.Fatal("membership-addr and membership-public-addr flags are required when groups are provided")
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400261 }
gioc81a8472024-09-24 13:06:19 +0200262 http.Handle("/.auth/static/", http.StripPrefix("/.auth", cachingHandler{http.FileServer(http.FS(f))}))
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400263 http.HandleFunc("/", handle)
264 fmt.Printf("Starting HTTP server on port: %d\n", *port)
265 log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil))
266}