blob: d1e1b49ee718e99e2732b2755229728e04e0de77 [file] [log] [blame]
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04001package main
2
3import (
4 "bytes"
5 "context"
6 "crypto/tls"
7 "encoding/json"
8 "flag"
9 "fmt"
10 "io"
11 "log"
12 "net/http"
13 "net/http/cookiejar"
14 "net/url"
15 "strings"
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +040016
17 "golang.org/x/exp/slices"
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040018)
19
// Command-line configuration, read by main via flag.Parse.
var (
	// port is the TCP port the proxy's HTTP server listens on.
	port = flag.Int("port", 3000, "Port to listen on")
	// whoAmIAddr is the Kratos whoami URL used to resolve a session from cookies.
	whoAmIAddr = flag.String("whoami-addr", "", "Kratos whoami endpoint address")
	// loginAddr is where unauthenticated GET requests are redirected.
	loginAddr = flag.String("login-addr", "", "Login page address")
	// membershipAddr is the base URL of the group membership API; main
	// requires it whenever groups is non-empty.
	membershipAddr = flag.String("membership-addr", "", "Group membership API endpoint")
	// groups is a comma-separated allow-list; an empty value disables the check.
	groups = flag.String("groups", "", "Comma separated list of groups. User must be part of at least one of them. If empty group membership will not be checked.")
	// upstream is the host[:port] that authenticated requests are proxied to.
	upstream = flag.String("upstream", "", "Upstream service address")
)
26
// user is the part of the Kratos whoami response this proxy needs: only
// identity.traits.username is decoded, all other fields are ignored.
type user struct {
	Identity struct {
		Traits struct {
			// Username identifies the caller; handle forwards it upstream
			// in the X-User header.
			Username string `json:"username"`
		} `json:"traits"`
	} `json:"identity"`
}
34
// authError is the error envelope decoded from non-200 whoami responses;
// queryWhoAmI only inspects error.status.
type authError struct {
	Error struct {
		// Status is compared against "Unauthorized" to detect a missing
		// or invalid session.
		Status string `json:"status"`
	} `json:"error"`
}
40
// getAddr reconstructs the absolute URL the client originally requested.
//
// The scheme and host come from the X-Forwarded-Scheme / X-Forwarded-Host
// headers. When a header is absent, getAddr falls back to what the request
// itself reports: "https" if it arrived over TLS, otherwise "http", and
// r.Host. The previous r.Header[...][0] indexing panicked in that case.
func getAddr(r *http.Request) (*url.URL, error) {
	scheme := r.Header.Get("X-Forwarded-Scheme")
	if scheme == "" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
	}
	return url.Parse(fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()))
}
48
49func handle(w http.ResponseWriter, r *http.Request) {
50 user, err := queryWhoAmI(r.Cookies())
51 if err != nil {
52 http.Error(w, err.Error(), http.StatusInternalServerError)
53 return
54 }
55 if user == nil {
56 if r.Method != http.MethodGet {
57 http.Error(w, "Unauthorized", http.StatusUnauthorized)
58 return
59 }
60 curr, err := getAddr(r)
61 if err != nil {
62 http.Error(w, err.Error(), http.StatusInternalServerError)
63 return
64 }
65 addr := fmt.Sprintf("%s?return_to=%s", *loginAddr, curr.String())
66 http.Redirect(w, r, addr, http.StatusSeeOther)
67 return
68 }
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +040069 if *groups != "" {
70 hasPermission := false
71 tg, err := getTransitiveGroups(user.Identity.Traits.Username)
72 if err != nil {
73 http.Error(w, err.Error(), http.StatusInternalServerError)
74 return
75 }
76 for _, i := range strings.Split(*groups, ",") {
77 if slices.Contains(tg, strings.TrimSpace(i)) {
78 hasPermission = true
79 break
80 }
81 }
82 if !hasPermission {
83 http.Error(w, "not authorized", http.StatusUnauthorized)
84 return
85 }
86
87 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +040088 rc := r.Clone(context.Background())
89 rc.Header.Add("X-User", user.Identity.Traits.Username)
90 ru, err := url.Parse(fmt.Sprintf("http://%s%s", *upstream, r.URL.RequestURI()))
91 if err != nil {
92 http.Error(w, err.Error(), http.StatusInternalServerError)
93 return
94 }
95 rc.URL = ru
96 rc.RequestURI = ""
97 client := &http.Client{
98 Transport: &http.Transport{
99 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
100 },
101 CheckRedirect: func(req *http.Request, via []*http.Request) error {
102 return http.ErrUseLastResponse
103 },
104 }
105 resp, err := client.Do(rc)
106 if err != nil {
107 http.Error(w, err.Error(), http.StatusInternalServerError)
108 return
109 }
110 for name, values := range resp.Header {
111 for _, value := range values {
112 w.Header().Add(name, value)
113 }
114 }
115 w.WriteHeader(resp.StatusCode)
116 if _, err := io.Copy(w, resp.Body); err != nil {
117 http.Error(w, err.Error(), http.StatusInternalServerError)
118 return
119 }
120}
121
122func queryWhoAmI(cookies []*http.Cookie) (*user, error) {
123 jar, err := cookiejar.New(nil)
124 if err != nil {
125 return nil, err
126 }
127 client := &http.Client{
128 Jar: jar,
129 Transport: &http.Transport{
130 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
131 },
132 }
133 addr, err := url.Parse(*whoAmIAddr)
134 if err != nil {
135 return nil, err
136 }
137 client.Jar.SetCookies(addr, cookies)
138 resp, err := client.Get(*whoAmIAddr)
139 if err != nil {
140 return nil, err
141 }
142 data := make(map[string]any)
143 if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
144 return nil, err
145 }
146 // TODO(gio): remove debugging
147 b, err := json.MarshalIndent(data, "", " ")
148 if err != nil {
149 return nil, err
150 }
151 fmt.Println(string(b))
152 var buf bytes.Buffer
153 if err := json.NewEncoder(&buf).Encode(data); err != nil {
154 return nil, err
155 }
156 tmp := buf.String()
157 if resp.StatusCode == http.StatusOK {
158 u := &user{}
159 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(u); err != nil {
160 return nil, err
161 }
162 return u, nil
163 }
164 e := &authError{}
165 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(e); err != nil {
166 return nil, err
167 }
168 if e.Error.Status == "Unauthorized" {
169 return nil, nil
170 }
171 return nil, fmt.Errorf("Unknown error: %s", tmp)
172}
173
// MembershipInfo is the JSON body returned by the group membership API.
type MembershipInfo struct {
	// MemberOf lists the group names the user belongs to; handle checks it
	// against the -groups allow-list.
	MemberOf []string `json:"memberOf"`
}
177
178func getTransitiveGroups(user string) ([]string, error) {
179 resp, err := http.Get(fmt.Sprintf("%s/%s", *membershipAddr, user))
180 if err != nil {
181 return nil, err
182 }
183 var info MembershipInfo
184 if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
185 return nil, err
186 }
187 return info.MemberOf, nil
188}
189
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400190func main() {
191 flag.Parse()
Giorgi Lekveishvilia09fad72024-03-21 15:24:35 +0400192 if *groups != "" && *membershipAddr == "" {
193 log.Fatal("membership-addr flag is required when groups are provided")
194 }
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +0400195 http.HandleFunc("/", handle)
196 fmt.Printf("Starting HTTP server on port: %d\n", *port)
197 log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil))
198}