blob: 2c10258c08a5f09536017bc153610c080b3f8e23 [file] [log] [blame]
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04001package main
2
3import (
4 "bytes"
5 "context"
6 "crypto/tls"
Davit Tabidze5f00a392024-08-13 18:37:02 +04007 "embed"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04008 "encoding/json"
9 "flag"
10 "fmt"
Davit Tabidze5f00a392024-08-13 18:37:02 +040011 "html/template"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040012 "io"
13 "log"
14 "net/http"
15 "net/http/cookiejar"
16 "net/url"
Davit Tabidze5f00a392024-08-13 18:37:02 +040017 "slices"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040018 "strings"
19)
20
// Command-line configuration. All values are read after flag.Parse() in main.
var (
	// port is the local TCP port the proxy listens on.
	port = flag.Int("port", 3000, "Port to listen on")
	// whoAmIAddr is the Kratos whoami endpoint queried with the caller's cookies.
	whoAmIAddr = flag.String("whoami-addr", "", "Kratos whoami endpoint address")
	// loginAddr is where unauthenticated GET requests are redirected.
	loginAddr = flag.String("login-addr", "", "Login page address")
	// membershipAddr is the base URL of the group-membership API.
	membershipAddr = flag.String("membership-addr", "", "Group membership API endpoint")
	// membershipPublicAddr is shown to users on the unauthorized page.
	membershipPublicAddr = flag.String("membership-public-addr", "", "Public address of membership service")
	// groups is a comma separated allow-list; empty disables the membership check.
	groups = flag.String("groups", "", "Comma separated list of groups. User must be part of at least one of them. If empty group membership will not be checked.")
	// upstream is the host[:port] that authorized requests are proxied to.
	upstream = flag.String("upstream", "", "Upstream service address")
	// noAuthPathPrefixes lists path prefixes that bypass authentication.
	noAuthPathPrefixes = flag.String("no-auth-path-prefixes", "", "Path prefixes to disable authentication for")
)
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040029
// unauthorizedHTML holds the template rendered by renderUnauthorizedPage when
// an authenticated user is not a member of any of the required groups.
//
//go:embed unauthorized.html
var unauthorizedHTML embed.FS

// f holds the static/ directory; main serves it under /.auth/static/ with a
// one-week Cache-Control header.
//
//go:embed static/*
var f embed.FS
35
// cachingHandler decorates another handler and marks every response it
// produces as cacheable for one week (604800 seconds).
type cachingHandler struct {
	h http.Handler
}

// ServeHTTP stamps the Cache-Control header and then delegates to the wrapped
// handler, so the header is already in place when the inner handler writes.
func (c cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const cacheControl = "max-age=604800"
	header := w.Header()
	header.Set("Cache-Control", cacheControl)
	c.h.ServeHTTP(w, r)
}
44
// userTraits holds the identity traits this proxy reads from a whoami response.
type userTraits struct {
	Username string `json:"username"`
}

// userIdentity is the "identity" object of a whoami response.
type userIdentity struct {
	Traits userTraits `json:"traits"`
}

// user is the subset of a successful whoami response used by this proxy;
// the username ends up in the X-User header sent upstream.
type user struct {
	Identity userIdentity `json:"identity"`
}

// authErrorDetail is the "error" object of a failed whoami response.
type authErrorDetail struct {
	Status string `json:"status"`
}

// authError is the body of a non-200 whoami response; queryWhoAmI treats
// Status == "Unauthorized" as "no session".
type authError struct {
	Error authErrorDetail `json:"error"`
}
58
// getAddr reconstructs the externally visible URL of r, used as the login
// page's return_to target.
//
// Scheme and host come from the X-Forwarded-Scheme / X-Forwarded-Host headers
// set by the fronting proxy. The original indexed those header slices
// directly and panicked when a header was missing. This version instead falls
// back to the connection itself: https if it arrived over TLS, otherwise
// http, and r.Host as the host.
func getAddr(r *http.Request) (*url.URL, error) {
	scheme := r.Header.Get("X-Forwarded-Scheme")
	if scheme == "" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
	}
	return url.Parse(fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()))
}
66
// funcMap lists the helper functions made available to the unauthorized
// page template.
var funcMap = template.FuncMap{
	"IsLast": isLastIndex,
}

// isLastIndex reports whether i is the position of the final element of items.
func isLastIndex(i int, items []string) bool {
	return i+1 == len(items)
}
72
73type UnauthorizedPageData struct {
74 MembershipPublicAddr string
75 Groups []string
76}
77
78func renderUnauthorizedPage(w http.ResponseWriter, groups []string) {
79 tmpl, err := template.New("unauthorized.html").Funcs(funcMap).ParseFS(unauthorizedHTML, "unauthorized.html")
80 if err != nil {
81 http.Error(w, "Failed to load template", http.StatusInternalServerError)
82 return
83 }
84 data := UnauthorizedPageData{
85 MembershipPublicAddr: *membershipPublicAddr,
86 Groups: groups,
87 }
88 w.Header().Set("Content-Type", "text/html")
89 w.WriteHeader(http.StatusUnauthorized)
90 if err := tmpl.Execute(w, data); err != nil {
91 http.Error(w, "Failed render template", http.StatusInternalServerError)
92 }
93}
94
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040095func handle(w http.ResponseWriter, r *http.Request) {
gioc81a8472024-09-24 13:06:19 +020096 reqAuth := true
97 for _, p := range strings.Split(*noAuthPathPrefixes, ",") {
98 if strings.HasPrefix(r.URL.Path, p) {
99 reqAuth = false
100 break
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400101 }
gioc81a8472024-09-24 13:06:19 +0200102 }
103 var user *user
104 if reqAuth {
105 var err error
106 user, err = queryWhoAmI(r.Cookies())
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400107 if err != nil {
108 http.Error(w, err.Error(), http.StatusInternalServerError)
109 return
110 }
gioc81a8472024-09-24 13:06:19 +0200111 if user == nil {
112 if r.Method != http.MethodGet {
113 http.Error(w, "Unauthorized", http.StatusUnauthorized)
114 return
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400115 }
gioc81a8472024-09-24 13:06:19 +0200116 curr, err := getAddr(r)
117 if err != nil {
118 http.Error(w, err.Error(), http.StatusInternalServerError)
119 return
120 }
121 addr := fmt.Sprintf("%s?return_to=%s", *loginAddr, curr.String())
122 http.Redirect(w, r, addr, http.StatusSeeOther)
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400123 return
124 }
gioc81a8472024-09-24 13:06:19 +0200125 if *groups != "" {
126 hasPermission := false
127 tg, err := getTransitiveGroups(user.Identity.Traits.Username)
128 if err != nil {
129 http.Error(w, err.Error(), http.StatusInternalServerError)
130 return
131 }
132 for _, i := range strings.Split(*groups, ",") {
133 if slices.Contains(tg, strings.TrimSpace(i)) {
134 hasPermission = true
135 break
136 }
137 }
138 if !hasPermission {
139 groupList := strings.Split(*groups, ",")
140 renderUnauthorizedPage(w, groupList)
141 return
142 }
143 }
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400144 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400145 rc := r.Clone(context.Background())
gioc81a8472024-09-24 13:06:19 +0200146 if user != nil {
147 // TODO(gio): Rename to X-Forwarded-User
148 rc.Header.Add("X-User", user.Identity.Traits.Username)
149 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400150 ru, err := url.Parse(fmt.Sprintf("http://%s%s", *upstream, r.URL.RequestURI()))
151 if err != nil {
152 http.Error(w, err.Error(), http.StatusInternalServerError)
153 return
154 }
155 rc.URL = ru
156 rc.RequestURI = ""
157 client := &http.Client{
158 Transport: &http.Transport{
159 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
160 },
161 CheckRedirect: func(req *http.Request, via []*http.Request) error {
162 return http.ErrUseLastResponse
163 },
164 }
165 resp, err := client.Do(rc)
166 if err != nil {
167 http.Error(w, err.Error(), http.StatusInternalServerError)
168 return
169 }
170 for name, values := range resp.Header {
171 for _, value := range values {
172 w.Header().Add(name, value)
173 }
174 }
175 w.WriteHeader(resp.StatusCode)
176 if _, err := io.Copy(w, resp.Body); err != nil {
177 http.Error(w, err.Error(), http.StatusInternalServerError)
178 return
179 }
180}
181
182func queryWhoAmI(cookies []*http.Cookie) (*user, error) {
183 jar, err := cookiejar.New(nil)
184 if err != nil {
185 return nil, err
186 }
187 client := &http.Client{
188 Jar: jar,
189 Transport: &http.Transport{
190 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
191 },
192 }
193 addr, err := url.Parse(*whoAmIAddr)
194 if err != nil {
195 return nil, err
196 }
197 client.Jar.SetCookies(addr, cookies)
198 resp, err := client.Get(*whoAmIAddr)
199 if err != nil {
200 return nil, err
201 }
202 data := make(map[string]any)
203 if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
204 return nil, err
205 }
206 // TODO(gio): remove debugging
207 b, err := json.MarshalIndent(data, "", " ")
208 if err != nil {
209 return nil, err
210 }
211 fmt.Println(string(b))
212 var buf bytes.Buffer
213 if err := json.NewEncoder(&buf).Encode(data); err != nil {
214 return nil, err
215 }
216 tmp := buf.String()
217 if resp.StatusCode == http.StatusOK {
218 u := &user{}
219 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(u); err != nil {
220 return nil, err
221 }
222 return u, nil
223 }
224 e := &authError{}
225 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(e); err != nil {
226 return nil, err
227 }
228 if e.Error.Status == "Unauthorized" {
229 return nil, nil
230 }
231 return nil, fmt.Errorf("Unknown error: %s", tmp)
232}
233
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400234type MembershipInfo struct {
235 MemberOf []string `json:"memberOf"`
236}
237
238func getTransitiveGroups(user string) ([]string, error) {
239 resp, err := http.Get(fmt.Sprintf("%s/%s", *membershipAddr, user))
240 if err != nil {
241 return nil, err
242 }
243 var info MembershipInfo
244 if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
245 return nil, err
246 }
247 return info.MemberOf, nil
248}
249
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400250func main() {
251 flag.Parse()
Davit Tabidze5f00a392024-08-13 18:37:02 +0400252 if *groups != "" && (*membershipAddr == "" || *membershipPublicAddr == "") {
253 log.Fatal("membership-addr and membership-public-addr flags are required when groups are provided")
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400254 }
gioc81a8472024-09-24 13:06:19 +0200255 http.Handle("/.auth/static/", http.StripPrefix("/.auth", cachingHandler{http.FileServer(http.FS(f))}))
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400256 http.HandleFunc("/", handle)
257 fmt.Printf("Starting HTTP server on port: %d\n", *port)
258 log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil))
259}