blob: a1470490d3122bddf57ce9a6e96fb927489df7d2 [file] [log] [blame]
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04001package main
2
3import (
4 "bytes"
5 "context"
6 "crypto/tls"
Davit Tabidze5f00a392024-08-13 18:37:02 +04007 "embed"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04008 "encoding/json"
9 "flag"
10 "fmt"
Davit Tabidze5f00a392024-08-13 18:37:02 +040011 "html/template"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040012 "io"
13 "log"
14 "net/http"
15 "net/http/cookiejar"
16 "net/url"
Davit Tabidze5f00a392024-08-13 18:37:02 +040017 "slices"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040018 "strings"
19)
20
// Command-line configuration, read after flag.Parse in main.
var (
	// port is the TCP port the proxy listens on.
	port = flag.Int("port", 3000, "Port to listen on")
	// whoAmIAddr is queried on every request to resolve the caller's session.
	whoAmIAddr = flag.String("whoami-addr", "", "Kratos whoami endpoint address")
	// loginAddr receives anonymous GET requests on protected paths, with a
	// return_to query parameter pointing back at the original URL.
	loginAddr = flag.String("login-addr", "", "Login page address")
	// membershipAddr is the internal base URL used to look up a user's groups.
	membershipAddr = flag.String("membership-addr", "", "Group membership API endpoint")
	// membershipPublicAddr is handed to the unauthorized page template.
	membershipPublicAddr = flag.String("membership-public-addr", "", "Public address of membership service")
	// groups is a comma-separated allow-list; empty disables the group check.
	groups = flag.String("groups", "", "Comma separated list of groups. User must be part of at least one of them. If empty group membership will not be checked.")
	// upstream is the host[:port] that requests are proxied to over plain HTTP.
	upstream = flag.String("upstream", "", "Upstream service address")
	// noAuthPathPrefixes is a comma-separated list of path prefixes that are
	// proxied without requiring a session.
	noAuthPathPrefixes = flag.String("no-auth-path-prefixes", "", "Path prefixes to disable authentication for")
)
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040029
Davit Tabidze5f00a392024-08-13 18:37:02 +040030//go:embed unauthorized.html
31var unauthorizedHTML embed.FS
32
33//go:embed static/*
34var f embed.FS
35
// cachingHandler decorates h by setting "Cache-Control: max-age=604800"
// (one week) on every response before h runs.
type cachingHandler struct {
	h http.Handler
}

// ServeHTTP sets the Cache-Control header first and then delegates to the
// wrapped handler, which may still override it.
func (c cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const oneWeek = "max-age=604800"
	hdr := w.Header()
	hdr.Set("Cache-Control", oneWeek)
	c.h.ServeHTTP(w, r)
}
44
type (
	// user is the part of a successful (HTTP 200) whoami response that the
	// proxy uses: the identity ID and its username trait, which are forwarded
	// upstream and used for group lookups.
	user struct {
		Identity struct {
			Id     string `json:"id"`
			Traits struct {
				Username string `json:"username"`
			} `json:"traits"`
		} `json:"identity"`
	}

	// authError is the envelope decoded from a non-200 whoami response; only
	// error.status is inspected (the value "Unauthorized" means "no session").
	authError struct {
		Error struct {
			Status string `json:"status"`
		} `json:"error"`
	}
)
59
// getAddr reconstructs the URL the client originally requested, as seen in
// front of the ingress. The scheme and host come from the X-Forwarded-Scheme
// and X-Forwarded-Host headers. When a header is missing or empty, the scheme
// falls back to "https" for TLS connections ("http" otherwise) and the host
// falls back to r.Host. Previously a missing header caused an index-out-of-range
// panic.
func getAddr(r *http.Request) (*url.URL, error) {
	scheme := r.Header.Get("X-Forwarded-Scheme")
	if scheme == "" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
	}
	return url.Parse(fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()))
}
67
// funcMap lists the helper functions available to unauthorized.html.
var funcMap = template.FuncMap{
	"IsLast": isLastIndex,
}

// isLastIndex reports whether index is the final valid position in slice.
// It is exposed to the template as IsLast (presumably to decide where to put
// separators between group names; confirm against unauthorized.html).
func isLastIndex(index int, slice []string) bool {
	last := len(slice) - 1
	return last == index
}
73
// UnauthorizedPageData is the template data for unauthorized.html.
type UnauthorizedPageData struct {
	// MembershipPublicAddr is the value of -membership-public-addr.
	MembershipPublicAddr string
	// Groups are the configured groups; membership in any one grants access.
	Groups []string
}
78
79func renderUnauthorizedPage(w http.ResponseWriter, groups []string) {
80 tmpl, err := template.New("unauthorized.html").Funcs(funcMap).ParseFS(unauthorizedHTML, "unauthorized.html")
81 if err != nil {
82 http.Error(w, "Failed to load template", http.StatusInternalServerError)
83 return
84 }
85 data := UnauthorizedPageData{
86 MembershipPublicAddr: *membershipPublicAddr,
87 Groups: groups,
88 }
89 w.Header().Set("Content-Type", "text/html")
90 w.WriteHeader(http.StatusUnauthorized)
91 if err := tmpl.Execute(w, data); err != nil {
92 http.Error(w, "Failed render template", http.StatusInternalServerError)
93 }
94}
95
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040096func handle(w http.ResponseWriter, r *http.Request) {
gio9870cc02024-10-13 09:20:11 +040097 user, err := queryWhoAmI(r.Cookies())
98 if err != nil {
99 http.Error(w, err.Error(), http.StatusInternalServerError)
100 return
101 }
gioc81a8472024-09-24 13:06:19 +0200102 reqAuth := true
103 for _, p := range strings.Split(*noAuthPathPrefixes, ",") {
giodd213152024-09-27 11:26:59 +0200104 t := strings.TrimSpace(p)
105 if len(t) > 0 && strings.HasPrefix(r.URL.Path, t) {
gioc81a8472024-09-24 13:06:19 +0200106 reqAuth = false
107 break
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400108 }
gioc81a8472024-09-24 13:06:19 +0200109 }
gioc81a8472024-09-24 13:06:19 +0200110 if reqAuth {
gioc81a8472024-09-24 13:06:19 +0200111 if user == nil {
112 if r.Method != http.MethodGet {
113 http.Error(w, "Unauthorized", http.StatusUnauthorized)
114 return
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400115 }
gioc81a8472024-09-24 13:06:19 +0200116 curr, err := getAddr(r)
117 if err != nil {
118 http.Error(w, err.Error(), http.StatusInternalServerError)
119 return
120 }
121 addr := fmt.Sprintf("%s?return_to=%s", *loginAddr, curr.String())
122 http.Redirect(w, r, addr, http.StatusSeeOther)
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400123 return
124 }
gioc81a8472024-09-24 13:06:19 +0200125 if *groups != "" {
126 hasPermission := false
127 tg, err := getTransitiveGroups(user.Identity.Traits.Username)
128 if err != nil {
129 http.Error(w, err.Error(), http.StatusInternalServerError)
130 return
131 }
132 for _, i := range strings.Split(*groups, ",") {
133 if slices.Contains(tg, strings.TrimSpace(i)) {
134 hasPermission = true
135 break
136 }
137 }
138 if !hasPermission {
139 groupList := strings.Split(*groups, ",")
140 renderUnauthorizedPage(w, groupList)
141 return
142 }
143 }
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400144 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400145 rc := r.Clone(context.Background())
gioc8faeac2024-10-09 15:26:16 +0400146 if user != nil {
147 rc.Header.Add("X-Forwarded-User", user.Identity.Traits.Username)
148 rc.Header.Add("X-Forwarded-UserId", user.Identity.Id)
149 } else {
150 delete(rc.Header, "X-Forwarded-User")
151 delete(rc.Header, "X-Forwarded-UserId")
152 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400153 ru, err := url.Parse(fmt.Sprintf("http://%s%s", *upstream, r.URL.RequestURI()))
154 if err != nil {
155 http.Error(w, err.Error(), http.StatusInternalServerError)
156 return
157 }
158 rc.URL = ru
159 rc.RequestURI = ""
160 client := &http.Client{
161 Transport: &http.Transport{
162 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
163 },
164 CheckRedirect: func(req *http.Request, via []*http.Request) error {
165 return http.ErrUseLastResponse
166 },
167 }
168 resp, err := client.Do(rc)
169 if err != nil {
170 http.Error(w, err.Error(), http.StatusInternalServerError)
171 return
172 }
173 for name, values := range resp.Header {
174 for _, value := range values {
175 w.Header().Add(name, value)
176 }
177 }
178 w.WriteHeader(resp.StatusCode)
179 if _, err := io.Copy(w, resp.Body); err != nil {
180 http.Error(w, err.Error(), http.StatusInternalServerError)
181 return
182 }
183}
184
185func queryWhoAmI(cookies []*http.Cookie) (*user, error) {
186 jar, err := cookiejar.New(nil)
187 if err != nil {
188 return nil, err
189 }
190 client := &http.Client{
191 Jar: jar,
192 Transport: &http.Transport{
193 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
194 },
195 }
196 addr, err := url.Parse(*whoAmIAddr)
197 if err != nil {
198 return nil, err
199 }
200 client.Jar.SetCookies(addr, cookies)
201 resp, err := client.Get(*whoAmIAddr)
202 if err != nil {
203 return nil, err
204 }
205 data := make(map[string]any)
206 if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
207 return nil, err
208 }
209 // TODO(gio): remove debugging
210 b, err := json.MarshalIndent(data, "", " ")
211 if err != nil {
212 return nil, err
213 }
214 fmt.Println(string(b))
215 var buf bytes.Buffer
216 if err := json.NewEncoder(&buf).Encode(data); err != nil {
217 return nil, err
218 }
219 tmp := buf.String()
220 if resp.StatusCode == http.StatusOK {
221 u := &user{}
222 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(u); err != nil {
223 return nil, err
224 }
225 return u, nil
226 }
227 e := &authError{}
228 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(e); err != nil {
229 return nil, err
230 }
231 if e.Error.Status == "Unauthorized" {
232 return nil, nil
233 }
234 return nil, fmt.Errorf("Unknown error: %s", tmp)
235}
236
// MembershipInfo is the JSON body returned by the membership API
// (-membership-addr) for a single user.
type MembershipInfo struct {
	// MemberOf lists the group names reported for the user.
	MemberOf []string `json:"memberOf"`
}
240
241func getTransitiveGroups(user string) ([]string, error) {
242 resp, err := http.Get(fmt.Sprintf("%s/%s", *membershipAddr, user))
243 if err != nil {
244 return nil, err
245 }
246 var info MembershipInfo
247 if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
248 return nil, err
249 }
250 return info.MemberOf, nil
251}
252
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400253func main() {
254 flag.Parse()
Davit Tabidze5f00a392024-08-13 18:37:02 +0400255 if *groups != "" && (*membershipAddr == "" || *membershipPublicAddr == "") {
256 log.Fatal("membership-addr and membership-public-addr flags are required when groups are provided")
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400257 }
gioc81a8472024-09-24 13:06:19 +0200258 http.Handle("/.auth/static/", http.StripPrefix("/.auth", cachingHandler{http.FileServer(http.FS(f))}))
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400259 http.HandleFunc("/", handle)
260 fmt.Printf("Starting HTTP server on port: %d\n", *port)
261 log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil))
262}