membership: list child and parent groups (#107)
* Group page shows child and parent groups
* moved recursion as a helper
* reworked several funcs to return directly Group type
* added TransitiveGroups on homepage
* added circular reference detection
* improved performance of transitive groups
* minor fix
* GetAllTransitiveGroupsForGroup parameter changed to string
* cover getting transitive groups with unit test
* recursion logic needs fix
* refactor: reuse code for processing user/group transitive groups
* recursion fix
---------
Co-authored-by: Giorgi Lekveishvili <lekva@gl-mbp-m1-max.local>
diff --git a/core/auth/memberships/.gitignore b/core/auth/memberships/.gitignore
index 4042c7b..6ce208d 100644
--- a/core/auth/memberships/.gitignore
+++ b/core/auth/memberships/.gitignore
@@ -1,2 +1,2 @@
-# Exclude SQLite database file
*.db
+memberships*
diff --git a/core/auth/memberships/Makefile b/core/auth/memberships/Makefile
index d846cd4..1defe7d 100644
--- a/core/auth/memberships/Makefile
+++ b/core/auth/memberships/Makefile
@@ -9,25 +9,31 @@
clean:
rm -f memberships*
+build: clean
+ /usr/local/go/bin/go build -o memberships *.go
+
+test:
+ /usr/local/go/bin/go test ./...
+
build_arm64: export CGO_ENABLED=0
build_arm64: export GO111MODULE=on
build_arm64: export GOOS=linux
build_arm64: export GOARCH=arm64
-build_arm64:
+build_arm64: clean
/usr/local/go/bin/go build -o memberships_arm64 *.go
build_amd64: export CGO_ENABLED=0
build_amd64: export GO111MODULE=on
build_amd64: export GOOS=linux
build_amd64: export GOARCH=amd64
-build_amd64:
+build_amd64: clean
/usr/local/go/bin/go build -o memberships_amd64 *.go
-push_arm64: clean build_arm64
+push_arm64: build_arm64
$(podman) build --platform linux/arm64 --tag=$(repo_name)/memberships:arm64 .
$(podman) push $(repo_name)/memberships:arm64
-push_amd64: clean build_amd64
+push_amd64: build_amd64
$(podman) build --platform linux/amd64 --tag=$(repo_name)/memberships:amd64 .
$(podman) push $(repo_name)/memberships:amd64
diff --git a/core/auth/memberships/group.html b/core/auth/memberships/group.html
index 979777f..ac432d8 100644
--- a/core/auth/memberships/group.html
+++ b/core/auth/memberships/group.html
@@ -61,5 +61,31 @@
</tr>
{{- end }}
</table>
+ <h4>Transitive Groups</h4>
+ <table>
+ <tr>
+ <th>Group Name</th>
+ <th>Description</th>
+ </tr>
+ {{- range .TransitiveGroups }}
+ <tr>
+ <td><a href="/group/{{ .Name }}">{{ .Name }}</a></td>
+ <td>{{ .Description }}</td>
+ </tr>
+ {{- end }}
+ </table>
+ <h3>Child Groups</h3>
+ <table>
+ <tr>
+ <th>Group Name</th>
+ <th>Description</th>
+ </tr>
+ {{- range .ChildGroups }}
+ <tr>
+ <td><a href="/group/{{ .Name }}">{{ .Name }}</a></td>
+ <td>{{ .Description }}</td>
+ </tr>
+ {{- end }}
+ </table>
</body>
</html>
diff --git a/core/auth/memberships/index.html b/core/auth/memberships/index.html
index f78ae1d..03ee299 100644
--- a/core/auth/memberships/index.html
+++ b/core/auth/memberships/index.html
@@ -41,5 +41,18 @@
</tr>
{{- end -}}
</table>
+ <h4>Transitive Groups</h4>
+ <table>
+ <tr>
+ <th>Name</th>
+ <th>Description</th>
+ </tr>
+ {{- range .TransitiveGroups -}}
+ <tr>
+ <td><a href="/group/{{ .Name }}">{{ .Name }}</a></td>
+ <td>{{ .Description }}</td>
+ </tr>
+ {{- end -}}
+ </table>
</body>
</html>
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 0c1d104..4b64140 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -41,8 +41,10 @@
GetGroupMembers(group string) ([]string, error)
GetGroupDescription(group string) (string, error)
GetAvailableGroupsAsChild(group string) ([]string, error)
- GetAllTransitiveGroupsForUser(user string) ([]string, error)
- GetGroupsGroupBelongsTo(group string) ([]string, error)
+ GetAllTransitiveGroupsForUser(user string) ([]Group, error)
+ GetGroupsGroupBelongsTo(group string) ([]Group, error)
+ GetDirectChildrenGroups(group string) ([]Group, error)
+ GetAllTransitiveGroupsForGroup(group string) ([]Group, error)
}
type Server struct {
@@ -58,12 +60,8 @@
db *sql.DB
}
-func NewSQLiteStore(path string) (*SQLiteStore, error) {
- db, err := sql.Open("sqlite3", path)
- if err != nil {
- return nil, err
- }
- _, err = db.Exec(`
+func NewSQLiteStore(db *sql.DB) (*SQLiteStore, error) {
+ _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS groups (
name TEXT PRIMARY KEY,
description TEXT
@@ -270,6 +268,15 @@
}
func (s *SQLiteStore) AddChildGroup(parent, child string) error {
+ parentGroups, err := s.GetAllTransitiveGroupsForGroup(parent)
+ if err != nil {
+ return err
+ }
+ for _, group := range parentGroups {
+ if group.Name == child {
+ return fmt.Errorf("circular reference detected: group %s is already a parent of group %s", child, parent)
+ }
+ }
tx, err := s.db.Begin()
if err != nil {
return err
@@ -315,52 +322,61 @@
return availableGroups, nil
}
-func (s *SQLiteStore) GetAllTransitiveGroupsForUser(user string) ([]string, error) {
- directGroups, err := s.GetGroupsUserBelongsTo(user)
- if err != nil {
+func (s *SQLiteStore) GetAllTransitiveGroupsForUser(user string) ([]Group, error) {
+ if groups, err := s.GetGroupsUserBelongsTo(user); err != nil {
return nil, err
+ } else {
+ visited := map[string]struct{}{}
+ return s.getAllParentGroupsRecursive(groups, visited)
}
- allGroups := make(map[string]bool)
- for _, group := range directGroups {
- if err := s.getParentGroups(group.Name, allGroups); err != nil {
+}
+
+func (s *SQLiteStore) GetAllTransitiveGroupsForGroup(group string) ([]Group, error) {
+ if p, err := s.GetGroupsGroupBelongsTo(group); err != nil {
+ return nil, err
+ } else {
+ // Mark initial group as visited
+ visited := map[string]struct{}{
+ group: struct{}{},
+ }
+ return s.getAllParentGroupsRecursive(p, visited)
+ }
+}
+
+func (s *SQLiteStore) getAllParentGroupsRecursive(groups []Group, visited map[string]struct{}) ([]Group, error) {
+ var ret []Group
+ for _, g := range groups {
+ if _, ok := visited[g.Name]; ok {
+ continue
+ }
+ visited[g.Name] = struct{}{}
+ ret = append(ret, g)
+ if p, err := s.GetGroupsGroupBelongsTo(g.Name); err != nil {
return nil, err
+ } else if res, err := s.getAllParentGroupsRecursive(p, visited); err != nil {
+ return nil, err
+ } else {
+ ret = append(ret, res...)
}
}
- var result []string
- for group := range allGroups {
- result = append(result, group)
- }
- return result, nil
+ return ret, nil
}
-func (s *SQLiteStore) getParentGroups(group string, allGroups map[string]bool) error {
- if allGroups[group] {
- return nil
- }
- allGroups[group] = true
- parentGroups, err := s.GetGroupsGroupBelongsTo(group)
- if err != nil {
- return err
- }
- for _, parentGroup := range parentGroups {
- if err := s.getParentGroups(parentGroup, allGroups); err != nil {
- return err
- }
- }
- return nil
-}
-
-func (s *SQLiteStore) GetGroupsGroupBelongsTo(group string) ([]string, error) {
- query := "SELECT parent_group FROM group_to_group WHERE child_group = ?"
+func (s *SQLiteStore) GetGroupsGroupBelongsTo(group string) ([]Group, error) {
+ query := `
+ SELECT groups.name, groups.description
+ FROM groups
+ JOIN group_to_group ON groups.name = group_to_group.parent_group
+ WHERE group_to_group.child_group = ?`
rows, err := s.db.Query(query, group)
if err != nil {
return nil, err
}
defer rows.Close()
- var parentGroups []string
+ var parentGroups []Group
for rows.Next() {
- var parentGroup string
- if err := rows.Scan(&parentGroup); err != nil {
+ var parentGroup Group
+ if err := rows.Scan(&parentGroup.Name, &parentGroup.Description); err != nil {
return nil, err
}
parentGroups = append(parentGroups, parentGroup)
@@ -371,9 +387,37 @@
return parentGroups, nil
}
+func (s *SQLiteStore) GetDirectChildrenGroups(group string) ([]Group, error) {
+ query := `
+ SELECT groups.name, groups.description
+ FROM groups
+ JOIN group_to_group ON groups.name = group_to_group.child_group
+ WHERE group_to_group.parent_group = ?`
+ rows, err := s.db.Query(query, group)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var childrenGroups []Group
+ for rows.Next() {
+ var childGroup Group
+ if err := rows.Scan(&childGroup.Name, &childGroup.Description); err != nil {
+ return nil, err
+ }
+ childrenGroups = append(childrenGroups, childGroup)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return childrenGroups, nil
+}
+
func getLoggedInUser(r *http.Request) (string, error) {
- // TODO(dtabidze): should make a request to get loggedin user
- return "tabo", nil
+ if user := r.Header.Get("X-User"); user != "" {
+ return user, nil
+ } else {
+ return "", fmt.Errorf("unauthenticated")
+ }
}
type Status int
@@ -445,12 +489,19 @@
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ transitiveGroups, err := s.store.GetAllTransitiveGroupsForUser(loggedInUser)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
data := struct {
OwnerGroups []Group
MembershipGroups []Group
+ TransitiveGroups []Group
}{
OwnerGroups: ownerGroups,
MembershipGroups: membershipGroups,
+ TransitiveGroups: transitiveGroups,
}
w.Header().Set("Content-Type", "text/html")
if err := tmpl.Execute(w, data); err != nil {
@@ -484,7 +535,11 @@
}
func (s *Server) groupHandler(w http.ResponseWriter, r *http.Request) {
- // groupName := strings.TrimPrefix(r.URL.Path, "/group/")
+ _, err := getLoggedInUser(r)
+ if err != nil {
+ http.Error(w, "User Not Logged In", http.StatusUnauthorized)
+ return
+ }
vars := mux.Vars(r)
groupName := vars["group-name"]
tmpl, err := template.New("group").Parse(groupHTML)
@@ -512,18 +567,32 @@
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ transitiveGroups, err := s.store.GetAllTransitiveGroupsForGroup(groupName)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ childGroups, err := s.store.GetDirectChildrenGroups(groupName)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
data := struct {
- GroupName string
- Description string
- Owners []string
- Members []string
- AvailableGroups []string
+ GroupName string
+ Description string
+ Owners []string
+ Members []string
+ AvailableGroups []string
+ TransitiveGroups []Group
+ ChildGroups []Group
}{
- GroupName: groupName,
- Description: description,
- Owners: owners,
- Members: members,
- AvailableGroups: availableGroups,
+ GroupName: groupName,
+ Description: description,
+ Owners: owners,
+ Members: members,
+ AvailableGroups: availableGroups,
+ TransitiveGroups: transitiveGroups,
+ ChildGroups: childGroups,
}
if err := tmpl.Execute(w, data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -606,8 +675,12 @@
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ var groupNames []string
+ for _, group := range transitiveGroups {
+ groupNames = append(groupNames, group.Name)
+ }
w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(UserInfo{transitiveGroups}); err != nil {
+ if err := json.NewEncoder(w).Encode(UserInfo{MemberOf: groupNames}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -615,10 +688,14 @@
func main() {
flag.Parse()
- db, err := NewSQLiteStore(*dbPath)
+ db, err := sql.Open("sqlite3", *dbPath)
if err != nil {
panic(err)
}
- s := Server{store: db}
+ store, err := NewSQLiteStore(db)
+ if err != nil {
+ panic(err)
+ }
+ s := Server{store}
s.Start()
}
diff --git a/core/auth/memberships/store_test.go b/core/auth/memberships/store_test.go
new file mode 100644
index 0000000..eb774cc
--- /dev/null
+++ b/core/auth/memberships/store_test.go
@@ -0,0 +1,82 @@
+package main
+
+import (
+ "database/sql"
+ "testing"
+
+ _ "github.com/ncruces/go-sqlite3/driver"
+ _ "github.com/ncruces/go-sqlite3/embed"
+)
+
+func TestGetAllTransitiveGroupsForGroup(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ store, err := NewSQLiteStore(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = db.Exec(`
+ INSERT INTO groups (name, description)
+ VALUES
+ ('a', 'xxx'),
+ ('b', 'yyy');
+
+ INSERT INTO group_to_group (child_group, parent_group)
+ VALUES
+ ('a', 'b'),
+ ('b', 'a');
+ `)
+ if err != nil {
+ t.Fatal(err)
+ }
+ groups, err := store.GetAllTransitiveGroupsForGroup("a")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(groups) != 1 {
+ t.Fatalf("Expected exactly one transitive group, got: %s", groups)
+ }
+ expected := Group{"b", "yyy"}
+ if groups[0] != expected {
+ t.Fatalf("Expected %s, got: %s", expected, groups[0])
+ }
+}
+
+func TestGetAllTransitiveGroupsForUser(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ store, err := NewSQLiteStore(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = db.Exec(`
+ INSERT INTO groups (name, description)
+ VALUES
+ ('a', 'xxx'),
+ ('b', 'yyy'),
+ ('c', 'zzz');
+
+ INSERT INTO group_to_group (child_group, parent_group)
+ VALUES
+ ('a', 'c'),
+ ('b', 'c');
+ INSERT INTO user_to_group (username, group_name)
+ VALUES
+ ('u', 'a'),
+ ('u', 'b');
+ `)
+ if err != nil {
+ t.Fatal(err)
+ }
+ groups, err := store.GetAllTransitiveGroupsForUser("u")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(groups) != 3 {
+ t.Fatalf("Expected exactly one transitive group, got: %s", groups)
+ }
+}