membership: list child and parent groups (#107)
* Group page shows child and parent groups
* moved recursion into a helper
* reworked several funcs to return the Group type directly
* added TransitiveGroups to the homepage
* added circular reference detection
* improved performance of transitive groups
* minor fix
* GetAllTransitiveGroupsForGroup parameter changed to string
* cover getting transitive groups with a unit test
* recursion logic needs fix
* refactor: reuse code for processing user/group transitive groups
* recursion fix
---------
Co-authored-by: Giorgi Lekveishvili <lekva@gl-mbp-m1-max.local>
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 0c1d104..4b64140 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -41,8 +41,10 @@
GetGroupMembers(group string) ([]string, error)
GetGroupDescription(group string) (string, error)
GetAvailableGroupsAsChild(group string) ([]string, error)
- GetAllTransitiveGroupsForUser(user string) ([]string, error)
- GetGroupsGroupBelongsTo(group string) ([]string, error)
+ GetAllTransitiveGroupsForUser(user string) ([]Group, error)
+ GetGroupsGroupBelongsTo(group string) ([]Group, error)
+ GetDirectChildrenGroups(group string) ([]Group, error)
+ GetAllTransitiveGroupsForGroup(group string) ([]Group, error)
}
type Server struct {
@@ -58,12 +60,8 @@
db *sql.DB
}
-func NewSQLiteStore(path string) (*SQLiteStore, error) {
- db, err := sql.Open("sqlite3", path)
- if err != nil {
- return nil, err
- }
- _, err = db.Exec(`
+func NewSQLiteStore(db *sql.DB) (*SQLiteStore, error) {
+ _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS groups (
name TEXT PRIMARY KEY,
description TEXT
@@ -270,6 +268,15 @@
}
func (s *SQLiteStore) AddChildGroup(parent, child string) error {
+ parentGroups, err := s.GetAllTransitiveGroupsForGroup(parent)
+ if err != nil {
+ return err
+ }
+ for _, group := range parentGroups {
+ if group.Name == child {
+ return fmt.Errorf("circular reference detected: group %s is already a parent of group %s", child, parent)
+ }
+ }
tx, err := s.db.Begin()
if err != nil {
return err
@@ -315,52 +322,61 @@
return availableGroups, nil
}
-func (s *SQLiteStore) GetAllTransitiveGroupsForUser(user string) ([]string, error) {
- directGroups, err := s.GetGroupsUserBelongsTo(user)
- if err != nil {
+func (s *SQLiteStore) GetAllTransitiveGroupsForUser(user string) ([]Group, error) {
+ if groups, err := s.GetGroupsUserBelongsTo(user); err != nil {
return nil, err
+ } else {
+ visited := map[string]struct{}{}
+ return s.getAllParentGroupsRecursive(groups, visited)
}
- allGroups := make(map[string]bool)
- for _, group := range directGroups {
- if err := s.getParentGroups(group.Name, allGroups); err != nil {
+}
+
+func (s *SQLiteStore) GetAllTransitiveGroupsForGroup(group string) ([]Group, error) {
+ if p, err := s.GetGroupsGroupBelongsTo(group); err != nil {
+ return nil, err
+ } else {
+ // Mark initial group as visited
+ visited := map[string]struct{}{
+ group: struct{}{},
+ }
+ return s.getAllParentGroupsRecursive(p, visited)
+ }
+}
+
+func (s *SQLiteStore) getAllParentGroupsRecursive(groups []Group, visited map[string]struct{}) ([]Group, error) {
+ var ret []Group
+ for _, g := range groups {
+ if _, ok := visited[g.Name]; ok {
+ continue
+ }
+ visited[g.Name] = struct{}{}
+ ret = append(ret, g)
+ if p, err := s.GetGroupsGroupBelongsTo(g.Name); err != nil {
return nil, err
+ } else if res, err := s.getAllParentGroupsRecursive(p, visited); err != nil {
+ return nil, err
+ } else {
+ ret = append(ret, res...)
}
}
- var result []string
- for group := range allGroups {
- result = append(result, group)
- }
- return result, nil
+ return ret, nil
}
-func (s *SQLiteStore) getParentGroups(group string, allGroups map[string]bool) error {
- if allGroups[group] {
- return nil
- }
- allGroups[group] = true
- parentGroups, err := s.GetGroupsGroupBelongsTo(group)
- if err != nil {
- return err
- }
- for _, parentGroup := range parentGroups {
- if err := s.getParentGroups(parentGroup, allGroups); err != nil {
- return err
- }
- }
- return nil
-}
-
-func (s *SQLiteStore) GetGroupsGroupBelongsTo(group string) ([]string, error) {
- query := "SELECT parent_group FROM group_to_group WHERE child_group = ?"
+func (s *SQLiteStore) GetGroupsGroupBelongsTo(group string) ([]Group, error) {
+ query := `
+ SELECT groups.name, groups.description
+ FROM groups
+ JOIN group_to_group ON groups.name = group_to_group.parent_group
+ WHERE group_to_group.child_group = ?`
rows, err := s.db.Query(query, group)
if err != nil {
return nil, err
}
defer rows.Close()
- var parentGroups []string
+ var parentGroups []Group
for rows.Next() {
- var parentGroup string
- if err := rows.Scan(&parentGroup); err != nil {
+ var parentGroup Group
+ if err := rows.Scan(&parentGroup.Name, &parentGroup.Description); err != nil {
return nil, err
}
parentGroups = append(parentGroups, parentGroup)
@@ -371,9 +387,37 @@
return parentGroups, nil
}
+func (s *SQLiteStore) GetDirectChildrenGroups(group string) ([]Group, error) {
+ query := `
+ SELECT groups.name, groups.description
+ FROM groups
+ JOIN group_to_group ON groups.name = group_to_group.child_group
+ WHERE group_to_group.parent_group = ?`
+ rows, err := s.db.Query(query, group)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var childrenGroups []Group
+ for rows.Next() {
+ var childGroup Group
+ if err := rows.Scan(&childGroup.Name, &childGroup.Description); err != nil {
+ return nil, err
+ }
+ childrenGroups = append(childrenGroups, childGroup)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return childrenGroups, nil
+}
+
func getLoggedInUser(r *http.Request) (string, error) {
- // TODO(dtabidze): should make a request to get loggedin user
- return "tabo", nil
+ if user := r.Header.Get("X-User"); user != "" {
+ return user, nil
+ } else {
+ return "", fmt.Errorf("unauthenticated")
+ }
}
type Status int
@@ -445,12 +489,19 @@
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ transitiveGroups, err := s.store.GetAllTransitiveGroupsForUser(loggedInUser)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
data := struct {
OwnerGroups []Group
MembershipGroups []Group
+ TransitiveGroups []Group
}{
OwnerGroups: ownerGroups,
MembershipGroups: membershipGroups,
+ TransitiveGroups: transitiveGroups,
}
w.Header().Set("Content-Type", "text/html")
if err := tmpl.Execute(w, data); err != nil {
@@ -484,7 +535,11 @@
}
func (s *Server) groupHandler(w http.ResponseWriter, r *http.Request) {
- // groupName := strings.TrimPrefix(r.URL.Path, "/group/")
+ _, err := getLoggedInUser(r)
+ if err != nil {
+ http.Error(w, "User Not Logged In", http.StatusUnauthorized)
+ return
+ }
vars := mux.Vars(r)
groupName := vars["group-name"]
tmpl, err := template.New("group").Parse(groupHTML)
@@ -512,18 +567,32 @@
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ transitiveGroups, err := s.store.GetAllTransitiveGroupsForGroup(groupName)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ childGroups, err := s.store.GetDirectChildrenGroups(groupName)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
data := struct {
- GroupName string
- Description string
- Owners []string
- Members []string
- AvailableGroups []string
+ GroupName string
+ Description string
+ Owners []string
+ Members []string
+ AvailableGroups []string
+ TransitiveGroups []Group
+ ChildGroups []Group
}{
- GroupName: groupName,
- Description: description,
- Owners: owners,
- Members: members,
- AvailableGroups: availableGroups,
+ GroupName: groupName,
+ Description: description,
+ Owners: owners,
+ Members: members,
+ AvailableGroups: availableGroups,
+ TransitiveGroups: transitiveGroups,
+ ChildGroups: childGroups,
}
if err := tmpl.Execute(w, data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -606,8 +675,12 @@
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ var groupNames []string
+ for _, group := range transitiveGroups {
+ groupNames = append(groupNames, group.Name)
+ }
w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(UserInfo{transitiveGroups}); err != nil {
+ if err := json.NewEncoder(w).Encode(UserInfo{MemberOf: groupNames}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -615,10 +688,14 @@
func main() {
flag.Parse()
- db, err := NewSQLiteStore(*dbPath)
+ db, err := sql.Open("sqlite3", *dbPath)
if err != nil {
panic(err)
}
- s := Server{store: db}
+ store, err := NewSQLiteStore(db)
+ if err != nil {
+ panic(err)
+ }
+ s := Server{store}
s.Start()
}