membership: list child and parent groups (#107)

* Group page shows child and parent groups

* moved recursion as a helper

* reworked several funcs to return the Group type directly

* added TransitiveGroups on homepage

* added circular reference detection

* improved performance of transitive groups

* minor fix

* GetAllTransitiveGroupsForGroup parameter changed to string

* cover getting transitive groups with unit test

* recursion logic needs fix

* refactor: reuse code for processing user/group transitive groups

* recursion fix

---------

Co-authored-by: Giorgi Lekveishvili <lekva@gl-mbp-m1-max.local>
diff --git a/core/auth/memberships/.gitignore b/core/auth/memberships/.gitignore
index 4042c7b..6ce208d 100644
--- a/core/auth/memberships/.gitignore
+++ b/core/auth/memberships/.gitignore
@@ -1,2 +1,2 @@
-# Exclude SQLite database file
 *.db
+memberships*
diff --git a/core/auth/memberships/Makefile b/core/auth/memberships/Makefile
index d846cd4..1defe7d 100644
--- a/core/auth/memberships/Makefile
+++ b/core/auth/memberships/Makefile
@@ -9,25 +9,31 @@
 clean:
 	rm -f memberships*
 
+build: clean
+	/usr/local/go/bin/go build -o memberships *.go
+
+test:
+	/usr/local/go/bin/go test ./...
+
 build_arm64: export CGO_ENABLED=0
 build_arm64: export GO111MODULE=on
 build_arm64: export GOOS=linux
 build_arm64: export GOARCH=arm64
-build_arm64:
+build_arm64: clean
 	/usr/local/go/bin/go build -o memberships_arm64 *.go
 
 build_amd64: export CGO_ENABLED=0
 build_amd64: export GO111MODULE=on
 build_amd64: export GOOS=linux
 build_amd64: export GOARCH=amd64
-build_amd64:
+build_amd64: clean
 	/usr/local/go/bin/go build -o memberships_amd64 *.go
 
-push_arm64: clean build_arm64
+push_arm64: build_arm64
 	$(podman) build --platform linux/arm64 --tag=$(repo_name)/memberships:arm64 .
 	$(podman) push $(repo_name)/memberships:arm64
 
-push_amd64: clean build_amd64
+push_amd64: build_amd64
 	$(podman) build --platform linux/amd64 --tag=$(repo_name)/memberships:amd64 .
 	$(podman) push $(repo_name)/memberships:amd64
 
diff --git a/core/auth/memberships/group.html b/core/auth/memberships/group.html
index 979777f..ac432d8 100644
--- a/core/auth/memberships/group.html
+++ b/core/auth/memberships/group.html
@@ -61,5 +61,31 @@
         </tr>
         {{- end }}
     </table>
+    <h4>Transitive Groups</h4>
+    <table>
+        <tr>
+            <th>Group Name</th>
+            <th>Description</th>
+        </tr>
+        {{- range .TransitiveGroups }}
+        <tr>
+            <td><a href="/group/{{ .Name }}">{{ .Name }}</a></td>
+            <td>{{ .Description }}</td>
+        </tr>
+        {{- end }}
+    </table>
+    <h3>Child Groups</h3>
+    <table>
+        <tr>
+            <th>Group Name</th>
+            <th>Description</th>
+        </tr>
+        {{- range .ChildGroups }}
+        <tr>
+            <td><a href="/group/{{ .Name }}">{{ .Name }}</a></td>
+            <td>{{ .Description }}</td>
+        </tr>
+        {{- end }}
+    </table>
 </body>
 </html>
diff --git a/core/auth/memberships/index.html b/core/auth/memberships/index.html
index f78ae1d..03ee299 100644
--- a/core/auth/memberships/index.html
+++ b/core/auth/memberships/index.html
@@ -41,5 +41,18 @@
         </tr>
         {{- end -}}
     </table>
+    <h4>Transitive Groups</h4>
+    <table>
+        <tr>
+            <th>Name</th>
+            <th>Description</th>
+        </tr>
+        {{- range .TransitiveGroups -}}
+        <tr>
+            <td><a href="/group/{{ .Name }}">{{ .Name }}</a></td>
+            <td>{{ .Description }}</td>
+        </tr>
+        {{- end -}}
+    </table>
 </body>
 </html>
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 0c1d104..4b64140 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -41,8 +41,10 @@
 	GetGroupMembers(group string) ([]string, error)
 	GetGroupDescription(group string) (string, error)
 	GetAvailableGroupsAsChild(group string) ([]string, error)
-	GetAllTransitiveGroupsForUser(user string) ([]string, error)
-	GetGroupsGroupBelongsTo(group string) ([]string, error)
+	GetAllTransitiveGroupsForUser(user string) ([]Group, error)
+	GetGroupsGroupBelongsTo(group string) ([]Group, error)
+	GetDirectChildrenGroups(group string) ([]Group, error)
+	GetAllTransitiveGroupsForGroup(group string) ([]Group, error)
 }
 
 type Server struct {
@@ -58,12 +60,8 @@
 	db *sql.DB
 }
 
-func NewSQLiteStore(path string) (*SQLiteStore, error) {
-	db, err := sql.Open("sqlite3", path)
-	if err != nil {
-		return nil, err
-	}
-	_, err = db.Exec(`
+func NewSQLiteStore(db *sql.DB) (*SQLiteStore, error) {
+	_, err := db.Exec(`
         CREATE TABLE IF NOT EXISTS groups (
             name TEXT PRIMARY KEY,
             description TEXT
@@ -270,6 +268,15 @@
 }
 
 func (s *SQLiteStore) AddChildGroup(parent, child string) error {
+	parentGroups, err := s.GetAllTransitiveGroupsForGroup(parent)
+	if err != nil {
+		return err
+	}
+	for _, group := range parentGroups {
+		if group.Name == child {
+			return fmt.Errorf("circular reference detected: group %s is already a parent of group %s", child, parent)
+		}
+	}
 	tx, err := s.db.Begin()
 	if err != nil {
 		return err
@@ -315,52 +322,61 @@
 	return availableGroups, nil
 }
 
-func (s *SQLiteStore) GetAllTransitiveGroupsForUser(user string) ([]string, error) {
-	directGroups, err := s.GetGroupsUserBelongsTo(user)
-	if err != nil {
+func (s *SQLiteStore) GetAllTransitiveGroupsForUser(user string) ([]Group, error) {
+	if groups, err := s.GetGroupsUserBelongsTo(user); err != nil {
 		return nil, err
+	} else {
+		visited := map[string]struct{}{}
+		return s.getAllParentGroupsRecursive(groups, visited)
 	}
-	allGroups := make(map[string]bool)
-	for _, group := range directGroups {
-		if err := s.getParentGroups(group.Name, allGroups); err != nil {
+}
+
+func (s *SQLiteStore) GetAllTransitiveGroupsForGroup(group string) ([]Group, error) {
+	if p, err := s.GetGroupsGroupBelongsTo(group); err != nil {
+		return nil, err
+	} else {
+		// Mark initial group as visited
+		visited := map[string]struct{}{
+			group: struct{}{},
+		}
+		return s.getAllParentGroupsRecursive(p, visited)
+	}
+}
+
+func (s *SQLiteStore) getAllParentGroupsRecursive(groups []Group, visited map[string]struct{}) ([]Group, error) {
+	var ret []Group
+	for _, g := range groups {
+		if _, ok := visited[g.Name]; ok {
+			continue
+		}
+		visited[g.Name] = struct{}{}
+		ret = append(ret, g)
+		if p, err := s.GetGroupsGroupBelongsTo(g.Name); err != nil {
 			return nil, err
+		} else if res, err := s.getAllParentGroupsRecursive(p, visited); err != nil {
+			return nil, err
+		} else {
+			ret = append(ret, res...)
 		}
 	}
-	var result []string
-	for group := range allGroups {
-		result = append(result, group)
-	}
-	return result, nil
+	return ret, nil
 }
 
-func (s *SQLiteStore) getParentGroups(group string, allGroups map[string]bool) error {
-	if allGroups[group] {
-		return nil
-	}
-	allGroups[group] = true
-	parentGroups, err := s.GetGroupsGroupBelongsTo(group)
-	if err != nil {
-		return err
-	}
-	for _, parentGroup := range parentGroups {
-		if err := s.getParentGroups(parentGroup, allGroups); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (s *SQLiteStore) GetGroupsGroupBelongsTo(group string) ([]string, error) {
-	query := "SELECT parent_group FROM group_to_group WHERE child_group = ?"
+func (s *SQLiteStore) GetGroupsGroupBelongsTo(group string) ([]Group, error) {
+	query := `
+        SELECT groups.name, groups.description
+        FROM groups
+        JOIN group_to_group ON groups.name = group_to_group.parent_group
+        WHERE group_to_group.child_group = ?`
 	rows, err := s.db.Query(query, group)
 	if err != nil {
 		return nil, err
 	}
 	defer rows.Close()
-	var parentGroups []string
+	var parentGroups []Group
 	for rows.Next() {
-		var parentGroup string
-		if err := rows.Scan(&parentGroup); err != nil {
+		var parentGroup Group
+		if err := rows.Scan(&parentGroup.Name, &parentGroup.Description); err != nil {
 			return nil, err
 		}
 		parentGroups = append(parentGroups, parentGroup)
@@ -371,9 +387,37 @@
 	return parentGroups, nil
 }
 
+func (s *SQLiteStore) GetDirectChildrenGroups(group string) ([]Group, error) {
+	query := `
+        SELECT groups.name, groups.description
+        FROM groups
+        JOIN group_to_group ON groups.name = group_to_group.child_group
+        WHERE group_to_group.parent_group = ?`
+	rows, err := s.db.Query(query, group)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var childrenGroups []Group
+	for rows.Next() {
+		var childGroup Group
+		if err := rows.Scan(&childGroup.Name, &childGroup.Description); err != nil {
+			return nil, err
+		}
+		childrenGroups = append(childrenGroups, childGroup)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return childrenGroups, nil
+}
+
 func getLoggedInUser(r *http.Request) (string, error) {
-	// TODO(dtabidze): should make a request to get loggedin user
-	return "tabo", nil
+	if user := r.Header.Get("X-User"); user != "" {
+		return user, nil
+	} else {
+		return "", fmt.Errorf("unauthenticated")
+	}
 }
 
 type Status int
@@ -445,12 +489,19 @@
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
+	transitiveGroups, err := s.store.GetAllTransitiveGroupsForUser(loggedInUser)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
 	data := struct {
 		OwnerGroups      []Group
 		MembershipGroups []Group
+		TransitiveGroups []Group
 	}{
 		OwnerGroups:      ownerGroups,
 		MembershipGroups: membershipGroups,
+		TransitiveGroups: transitiveGroups,
 	}
 	w.Header().Set("Content-Type", "text/html")
 	if err := tmpl.Execute(w, data); err != nil {
@@ -484,7 +535,11 @@
 }
 
 func (s *Server) groupHandler(w http.ResponseWriter, r *http.Request) {
-	// groupName := strings.TrimPrefix(r.URL.Path, "/group/")
+	_, err := getLoggedInUser(r)
+	if err != nil {
+		http.Error(w, "User Not Logged In", http.StatusUnauthorized)
+		return
+	}
 	vars := mux.Vars(r)
 	groupName := vars["group-name"]
 	tmpl, err := template.New("group").Parse(groupHTML)
@@ -512,18 +567,32 @@
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
+	transitiveGroups, err := s.store.GetAllTransitiveGroupsForGroup(groupName)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	childGroups, err := s.store.GetDirectChildrenGroups(groupName)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
 	data := struct {
-		GroupName       string
-		Description     string
-		Owners          []string
-		Members         []string
-		AvailableGroups []string
+		GroupName        string
+		Description      string
+		Owners           []string
+		Members          []string
+		AvailableGroups  []string
+		TransitiveGroups []Group
+		ChildGroups      []Group
 	}{
-		GroupName:       groupName,
-		Description:     description,
-		Owners:          owners,
-		Members:         members,
-		AvailableGroups: availableGroups,
+		GroupName:        groupName,
+		Description:      description,
+		Owners:           owners,
+		Members:          members,
+		AvailableGroups:  availableGroups,
+		TransitiveGroups: transitiveGroups,
+		ChildGroups:      childGroups,
 	}
 	if err := tmpl.Execute(w, data); err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -606,8 +675,12 @@
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
+	var groupNames []string
+	for _, group := range transitiveGroups {
+		groupNames = append(groupNames, group.Name)
+	}
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(UserInfo{transitiveGroups}); err != nil {
+	if err := json.NewEncoder(w).Encode(UserInfo{MemberOf: groupNames}); err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
@@ -615,10 +688,14 @@
 
 func main() {
 	flag.Parse()
-	db, err := NewSQLiteStore(*dbPath)
+	db, err := sql.Open("sqlite3", *dbPath)
 	if err != nil {
 		panic(err)
 	}
-	s := Server{store: db}
+	store, err := NewSQLiteStore(db)
+	if err != nil {
+		panic(err)
+	}
+	s := Server{store}
 	s.Start()
 }
diff --git a/core/auth/memberships/store_test.go b/core/auth/memberships/store_test.go
new file mode 100644
index 0000000..eb774cc
--- /dev/null
+++ b/core/auth/memberships/store_test.go
@@ -0,0 +1,82 @@
+package main
+
+import (
+	"database/sql"
+	"testing"
+
+	_ "github.com/ncruces/go-sqlite3/driver"
+	_ "github.com/ncruces/go-sqlite3/embed"
+)
+
+func TestGetAllTransitiveGroupsForGroup(t *testing.T) {
+	db, err := sql.Open("sqlite3", ":memory:")
+	if err != nil {
+		t.Fatal(err)
+	}
+	store, err := NewSQLiteStore(db)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = db.Exec(`
+        INSERT INTO groups (name, description)
+        VALUES
+            ('a', 'xxx'),
+            ('b', 'yyy');
+
+        INSERT INTO group_to_group (child_group, parent_group)
+        VALUES
+            ('a', 'b'),
+            ('b', 'a');
+        `)
+	if err != nil {
+		t.Fatal(err)
+	}
+	groups, err := store.GetAllTransitiveGroupsForGroup("a")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(groups) != 1 {
+		t.Fatalf("Expected exactly one transitive group, got: %s", groups)
+	}
+	expected := Group{"b", "yyy"}
+	if groups[0] != expected {
+		t.Fatalf("Expected %s, got: %s", expected, groups[0])
+	}
+}
+
+func TestGetAllTransitiveGroupsForUser(t *testing.T) {
+	db, err := sql.Open("sqlite3", ":memory:")
+	if err != nil {
+		t.Fatal(err)
+	}
+	store, err := NewSQLiteStore(db)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = db.Exec(`
+        INSERT INTO groups (name, description)
+        VALUES
+            ('a', 'xxx'),
+            ('b', 'yyy'),
+            ('c', 'zzz');
+
+        INSERT INTO group_to_group (child_group, parent_group)
+        VALUES
+            ('a', 'c'),
+            ('b', 'c');
+        INSERT INTO user_to_group (username, group_name)
+        VALUES
+            ('u', 'a'),
+            ('u', 'b');
+        `)
+	if err != nil {
+		t.Fatal(err)
+	}
+	groups, err := store.GetAllTransitiveGroupsForUser("u")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(groups) != 3 {
+		t.Fatalf("Expected exactly three transitive groups, got: %s", groups)
+	}
+}