blob: 8b3d8372bb9a93bce79a6191e1bce3460fe6d184 [file] [log] [blame]
Giorgi Lekveishvili0ba5e402024-03-20 15:56:30 +04001package main
2
3import (
4 "bytes"
5 "context"
6 "crypto/tls"
7 "encoding/json"
8 "flag"
9 "fmt"
10 "io"
11 "log"
12 "net/http"
13 "net/http/cookiejar"
14 "net/url"
15 "strings"
16)
17
// Command-line configuration for the authenticating proxy.
var (
	// port is the TCP port the proxy listens on.
	port = flag.Int("port", 3000, "Port to listen on")
	// whoAmIAddr is the full URL of the Kratos whoami endpoint consulted to
	// resolve the caller's session cookies.
	whoAmIAddr = flag.String("whoami-addr", "", "Kratos whoami endpoint address")
	// loginAddr is the login page that unauthenticated GET requests are
	// redirected to.
	loginAddr = flag.String("login-addr", "", "Login page address")
	// upstream is the host[:port] of the protected service; requests are
	// forwarded to it over plain http://.
	upstream = flag.String("upstream", "", "Upstream service address")
)
22
// Response bodies returned by the Kratos whoami endpoint. Only the fields this
// proxy reads are declared; encoding/json ignores everything else.
type (
	// user is the body of a successful (HTTP 200) whoami response; only the
	// identity's username trait is consumed.
	user struct {
		Identity struct {
			Traits struct {
				Username string `json:"username"`
			} `json:"traits"`
		} `json:"identity"`
	}

	// authError is the body of a non-200 whoami response. An error.status of
	// "Unauthorized" is treated by the proxy as "no session".
	authError struct {
		Error struct {
			Status string `json:"status"`
		} `json:"error"`
	}
)
36
// getAddr reconstructs the URL the client originally requested when this
// service sits behind a reverse proxy that sets X-Forwarded-Scheme and
// X-Forwarded-Host. Path and query come from r.URL.
//
// If either header is absent, the value observed on this connection is used
// instead ("https" when r.TLS is set, otherwise "http"; r.Host for the host).
// Previously a missing header caused an index-out-of-range panic.
func getAddr(r *http.Request) (*url.URL, error) {
	scheme := r.Header.Get("X-Forwarded-Scheme")
	if scheme == "" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
	}
	return url.Parse(fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()))
}
44
45func handle(w http.ResponseWriter, r *http.Request) {
46 user, err := queryWhoAmI(r.Cookies())
47 if err != nil {
48 http.Error(w, err.Error(), http.StatusInternalServerError)
49 return
50 }
51 if user == nil {
52 if r.Method != http.MethodGet {
53 http.Error(w, "Unauthorized", http.StatusUnauthorized)
54 return
55 }
56 curr, err := getAddr(r)
57 if err != nil {
58 http.Error(w, err.Error(), http.StatusInternalServerError)
59 return
60 }
61 addr := fmt.Sprintf("%s?return_to=%s", *loginAddr, curr.String())
62 http.Redirect(w, r, addr, http.StatusSeeOther)
63 return
64 }
65 rc := r.Clone(context.Background())
66 rc.Header.Add("X-User", user.Identity.Traits.Username)
67 ru, err := url.Parse(fmt.Sprintf("http://%s%s", *upstream, r.URL.RequestURI()))
68 if err != nil {
69 http.Error(w, err.Error(), http.StatusInternalServerError)
70 return
71 }
72 rc.URL = ru
73 rc.RequestURI = ""
74 client := &http.Client{
75 Transport: &http.Transport{
76 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
77 },
78 CheckRedirect: func(req *http.Request, via []*http.Request) error {
79 return http.ErrUseLastResponse
80 },
81 }
82 resp, err := client.Do(rc)
83 if err != nil {
84 http.Error(w, err.Error(), http.StatusInternalServerError)
85 return
86 }
87 for name, values := range resp.Header {
88 for _, value := range values {
89 w.Header().Add(name, value)
90 }
91 }
92 w.WriteHeader(resp.StatusCode)
93 if _, err := io.Copy(w, resp.Body); err != nil {
94 http.Error(w, err.Error(), http.StatusInternalServerError)
95 return
96 }
97}
98
99func queryWhoAmI(cookies []*http.Cookie) (*user, error) {
100 jar, err := cookiejar.New(nil)
101 if err != nil {
102 return nil, err
103 }
104 client := &http.Client{
105 Jar: jar,
106 Transport: &http.Transport{
107 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
108 },
109 }
110 addr, err := url.Parse(*whoAmIAddr)
111 if err != nil {
112 return nil, err
113 }
114 client.Jar.SetCookies(addr, cookies)
115 resp, err := client.Get(*whoAmIAddr)
116 if err != nil {
117 return nil, err
118 }
119 data := make(map[string]any)
120 if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
121 return nil, err
122 }
123 // TODO(gio): remove debugging
124 b, err := json.MarshalIndent(data, "", " ")
125 if err != nil {
126 return nil, err
127 }
128 fmt.Println(string(b))
129 var buf bytes.Buffer
130 if err := json.NewEncoder(&buf).Encode(data); err != nil {
131 return nil, err
132 }
133 tmp := buf.String()
134 if resp.StatusCode == http.StatusOK {
135 u := &user{}
136 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(u); err != nil {
137 return nil, err
138 }
139 return u, nil
140 }
141 e := &authError{}
142 if err := json.NewDecoder(strings.NewReader(tmp)).Decode(e); err != nil {
143 return nil, err
144 }
145 if e.Error.Status == "Unauthorized" {
146 return nil, nil
147 }
148 return nil, fmt.Errorf("Unknown error: %s", tmp)
149}
150
151func main() {
152 flag.Parse()
153 http.HandleFunc("/", handle)
154 fmt.Printf("Starting HTTP server on port: %d\n", *port)
155 log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil))
156}