blob: f8ab6204b3dc0662a9fc4c576b326c887058faad [file] [log] [blame]
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04001package main
2
3import (
4 "bytes"
5 "context"
6 "crypto/tls"
Davit Tabidze5f00a392024-08-13 18:37:02 +04007 "embed"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04008 "encoding/json"
9 "flag"
10 "fmt"
Davit Tabidze5f00a392024-08-13 18:37:02 +040011 "html/template"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040012 "io"
13 "log"
14 "net/http"
15 "net/http/cookiejar"
16 "net/url"
gio4fde4a12024-10-13 12:19:30 +040017 "regexp"
Davit Tabidze5f00a392024-08-13 18:37:02 +040018 "slices"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040019 "strings"
20)
21
// Command-line configuration of the proxy.
var (
	// port is the TCP port the HTTP server listens on.
	port = flag.Int("port", 3000, "Port to listen on")

	// whoAmIAddr is the URL queried with the caller's cookies to resolve the session.
	whoAmIAddr = flag.String("whoami-addr", "", "Kratos whoami endpoint address")

	// loginAddr is where unauthenticated GET requests are redirected.
	loginAddr = flag.String("login-addr", "", "Login page address")

	// membershipAddr is the base URL of the group membership API.
	membershipAddr = flag.String("membership-addr", "", "Group membership API endpoint")

	// membershipPublicAddr is rendered on the unauthorized page.
	membershipPublicAddr = flag.String("membership-public-addr", "", "Public address of membership service")

	// groups, when non-empty, restricts access to members of at least one listed group.
	groups = flag.String("groups", "", "Comma separated list of groups. User must be part of at least one of them. If empty group membership will not be checked.")

	// upstream is the host[:port] requests are forwarded to over plain HTTP.
	upstream = flag.String("upstream", "", "Upstream service address")

	// noAuthPathPatterns lists path regexps that bypass authentication.
	noAuthPathPatterns = flag.String("no-auth-path-patterns", "", "Path regex patterns to disable authentication for")
)
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040030
Davit Tabidze5f00a392024-08-13 18:37:02 +040031//go:embed unauthorized.html
32var unauthorizedHTML embed.FS
33
34//go:embed static/*
35var f embed.FS
36
gio4fde4a12024-10-13 12:19:30 +040037var noAuthPathRegexps []*regexp.Regexp
38
39func initPathPatterns() error {
40 for _, p := range strings.Split(*noAuthPathPatterns, ",") {
41 t := strings.TrimSpace(p)
42 if len(t) == 0 {
43 continue
44 }
45 exp, err := regexp.Compile(t)
46 if err != nil {
47 return err
48 }
49 noAuthPathRegexps = append(noAuthPathRegexps, exp)
50 }
51 return nil
52}
53
// cachingHandler decorates another handler so that every response it serves
// carries "Cache-Control: max-age=604800" (seven days). main wraps the
// embedded static file server with it.
type cachingHandler struct {
	h http.Handler
}

// ServeHTTP stamps the caching header before delegating, so the header is in
// place even if the wrapped handler writes its response straight away.
func (c cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const cacheControl = "max-age=604800"
	hdr := w.Header()
	hdr.Set("Cache-Control", cacheControl)
	c.h.ServeHTTP(w, r)
}
62
// user is the subset of a successful (HTTP 200) whoami response that the
// proxy reads.
type user struct {
	Identity whoAmIIdentity `json:"identity"`
}

// whoAmIIdentity identifies the owner of the session.
type whoAmIIdentity struct {
	// Id is forwarded upstream in the X-Forwarded-UserId header.
	Id     string       `json:"id"`
	Traits whoAmITraits `json:"traits"`
}

// whoAmITraits holds the identity traits the proxy uses.
type whoAmITraits struct {
	// Username is forwarded upstream and used for group-membership lookups.
	Username string `json:"username"`
}

// authError is the body of a non-200 whoami response.
type authError struct {
	Error authErrorDetail `json:"error"`
}

// authErrorDetail carries the error status; queryWhoAmI treats the status
// "Unauthorized" as "no session" rather than as a failure.
type authErrorDetail struct {
	Status string `json:"status"`
}
77
// getAddr reconstructs the URL of r as the client sees it in front of any
// reverse proxy, so it can be used as a post-login return target.
//
// The scheme and host come from the X-Forwarded-Scheme and X-Forwarded-Host
// headers. When a header is missing or empty, the request's own values are
// used instead: "https" for TLS connections and "http" otherwise, and r.Host.
// An error is returned only when no host can be determined at all.
func getAddr(r *http.Request) (*url.URL, error) {
	// Header.Get instead of indexing r.Header[...][0]: indexing panics when
	// the header is absent.
	scheme := r.Header.Get("X-Forwarded-Scheme")
	if scheme == "" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
	}
	if host == "" {
		return nil, fmt.Errorf("cannot determine request host: no X-Forwarded-Host header and empty Host")
	}
	return url.Parse(fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()))
}
85
// funcMap lists the helper functions made available to the
// unauthorized.html template.
var funcMap = template.FuncMap{
	// IsLast reports whether position i is the final index of items.
	"IsLast": func(i int, items []string) bool {
		return i+1 == len(items)
	},
}
91
92type UnauthorizedPageData struct {
93 MembershipPublicAddr string
94 Groups []string
95}
96
97func renderUnauthorizedPage(w http.ResponseWriter, groups []string) {
98 tmpl, err := template.New("unauthorized.html").Funcs(funcMap).ParseFS(unauthorizedHTML, "unauthorized.html")
99 if err != nil {
100 http.Error(w, "Failed to load template", http.StatusInternalServerError)
101 return
102 }
103 data := UnauthorizedPageData{
104 MembershipPublicAddr: *membershipPublicAddr,
105 Groups: groups,
106 }
107 w.Header().Set("Content-Type", "text/html")
108 w.WriteHeader(http.StatusUnauthorized)
109 if err := tmpl.Execute(w, data); err != nil {
110 http.Error(w, "Failed render template", http.StatusInternalServerError)
111 }
112}
113
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400114func handle(w http.ResponseWriter, r *http.Request) {
gio9870cc02024-10-13 09:20:11 +0400115 user, err := queryWhoAmI(r.Cookies())
116 if err != nil {
117 http.Error(w, err.Error(), http.StatusInternalServerError)
118 return
119 }
gioc81a8472024-09-24 13:06:19 +0200120 reqAuth := true
gio4fde4a12024-10-13 12:19:30 +0400121 for _, p := range noAuthPathRegexps {
122 if p.MatchString(r.URL.Path) {
gioc81a8472024-09-24 13:06:19 +0200123 reqAuth = false
124 break
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400125 }
gioc81a8472024-09-24 13:06:19 +0200126 }
gioc81a8472024-09-24 13:06:19 +0200127 if reqAuth {
gioc81a8472024-09-24 13:06:19 +0200128 if user == nil {
129 if r.Method != http.MethodGet {
130 http.Error(w, "Unauthorized", http.StatusUnauthorized)
131 return
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400132 }
gioc81a8472024-09-24 13:06:19 +0200133 curr, err := getAddr(r)
134 if err != nil {
135 http.Error(w, err.Error(), http.StatusInternalServerError)
136 return
137 }
138 addr := fmt.Sprintf("%s?return_to=%s", *loginAddr, curr.String())
139 http.Redirect(w, r, addr, http.StatusSeeOther)
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400140 return
141 }
gioc81a8472024-09-24 13:06:19 +0200142 if *groups != "" {
143 hasPermission := false
144 tg, err := getTransitiveGroups(user.Identity.Traits.Username)
145 if err != nil {
146 http.Error(w, err.Error(), http.StatusInternalServerError)
147 return
148 }
149 for _, i := range strings.Split(*groups, ",") {
150 if slices.Contains(tg, strings.TrimSpace(i)) {
151 hasPermission = true
152 break
153 }
154 }
155 if !hasPermission {
156 groupList := strings.Split(*groups, ",")
157 renderUnauthorizedPage(w, groupList)
158 return
159 }
160 }
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400161 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400162 rc := r.Clone(context.Background())
gioc8faeac2024-10-09 15:26:16 +0400163 if user != nil {
164 rc.Header.Add("X-Forwarded-User", user.Identity.Traits.Username)
165 rc.Header.Add("X-Forwarded-UserId", user.Identity.Id)
166 } else {
167 delete(rc.Header, "X-Forwarded-User")
168 delete(rc.Header, "X-Forwarded-UserId")
169 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400170 ru, err := url.Parse(fmt.Sprintf("http://%s%s", *upstream, r.URL.RequestURI()))
171 if err != nil {
172 http.Error(w, err.Error(), http.StatusInternalServerError)
173 return
174 }
175 rc.URL = ru
176 rc.RequestURI = ""
177 client := &http.Client{
178 Transport: &http.Transport{
179 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
180 },
181 CheckRedirect: func(req *http.Request, via []*http.Request) error {
182 return http.ErrUseLastResponse
183 },
184 }
185 resp, err := client.Do(rc)
186 if err != nil {
187 http.Error(w, err.Error(), http.StatusInternalServerError)
188 return
189 }
190 for name, values := range resp.Header {
191 for _, value := range values {
192 w.Header().Add(name, value)
193 }
194 }
195 w.WriteHeader(resp.StatusCode)
196 if _, err := io.Copy(w, resp.Body); err != nil {
197 http.Error(w, err.Error(), http.StatusInternalServerError)
198 return
199 }
200}
201
202func queryWhoAmI(cookies []*http.Cookie) (*user, error) {
203 jar, err := cookiejar.New(nil)
204 if err != nil {
205 return nil, err
206 }
207 client := &http.Client{
208 Jar: jar,
209 Transport: &http.Transport{
210 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
211 },
212 }
213 addr, err := url.Parse(*whoAmIAddr)
214 if err != nil {
215 return nil, err
216 }
217 client.Jar.SetCookies(addr, cookies)
218 resp, err := client.Get(*whoAmIAddr)
219 if err != nil {
220 return nil, err
221 }
222 data := make(map[string]any)
223 if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
224 return nil, err
225 }
226 // TODO(gio): remove debugging
227 b, err := json.MarshalIndent(data, "", " ")
228 if err != nil {
229 return nil, err
230 }
231 fmt.Println(string(b))
232 var buf bytes.Buffer
233 if err := json.NewEncoder(&buf).Encode(data); err != nil {
234 return nil, err
235 }
236 tmp := buf.String()
237 if resp.StatusCode == http.StatusOK {
238 u := &user{}
239 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(u); err != nil {
240 return nil, err
241 }
242 return u, nil
243 }
244 e := &authError{}
245 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(e); err != nil {
246 return nil, err
247 }
248 if e.Error.Status == "Unauthorized" {
249 return nil, nil
250 }
251 return nil, fmt.Errorf("Unknown error: %s", tmp)
252}
253
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400254type MembershipInfo struct {
255 MemberOf []string `json:"memberOf"`
256}
257
258func getTransitiveGroups(user string) ([]string, error) {
259 resp, err := http.Get(fmt.Sprintf("%s/%s", *membershipAddr, user))
260 if err != nil {
261 return nil, err
262 }
263 var info MembershipInfo
264 if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
265 return nil, err
266 }
267 return info.MemberOf, nil
268}
269
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400270func main() {
271 flag.Parse()
Davit Tabidze5f00a392024-08-13 18:37:02 +0400272 if *groups != "" && (*membershipAddr == "" || *membershipPublicAddr == "") {
273 log.Fatal("membership-addr and membership-public-addr flags are required when groups are provided")
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400274 }
gio4fde4a12024-10-13 12:19:30 +0400275 if err := initPathPatterns(); err != nil {
276 log.Fatal(err)
277 }
gioc81a8472024-09-24 13:06:19 +0200278 http.Handle("/.auth/static/", http.StripPrefix("/.auth", cachingHandler{http.FileServer(http.FS(f))}))
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400279 http.HandleFunc("/", handle)
280 fmt.Printf("Starting HTTP server on port: %d\n", *port)
281 log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil))
282}