Memberships: filter users by group, including transitive groups
Change-Id: I9766501e19a058b958578476b8586883655e453f
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 88729a2..8ef0848 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -55,7 +55,7 @@
RemoveFromGroupToGroup(parent, child string) error
RemoveUserFromTable(username, groupName, tableName string) error
GetAllGroups() ([]Group, error)
- GetAllUsers() ([]User, error)
+ GetUsers(username []string) ([]User, error)
GetUser(username string) (User, error)
AddSSHKeyForUser(username, sshKey string) error
RemoveSSHKeyForUser(username, sshKey string) error
@@ -558,13 +558,27 @@
return nil
}
-func (s *SQLiteStore) GetAllUsers() ([]User, error) {
- rows, err := s.db.Query(`
+func (s *SQLiteStore) GetUsers(usernames []string) ([]User, error) {
+ var rows *sql.Rows
+ var err error
+ query := `
SELECT users.username, users.email, GROUP_CONCAT(user_ssh_keys.ssh_key, ',')
FROM users
- LEFT JOIN user_ssh_keys ON users.username = user_ssh_keys.username
- GROUP BY users.username
- `)
+ LEFT JOIN user_ssh_keys ON users.username = user_ssh_keys.username`
+ var args []interface{}
+ if usernames != nil {
+ if len(usernames) == 0 {
+ return []User{}, nil
+ }
+ query += " WHERE users.username IN ("
+ placeholders := strings.Repeat("?,", len(usernames)-1) + "?"
+ query += placeholders + ") "
+ for _, username := range usernames {
+ args = append(args, username)
+ }
+ }
+ query += " GROUP BY users.username"
+ rows, err = s.db.Query(query, args...)
if err != nil {
return nil, err
}
@@ -1224,9 +1238,44 @@
if selfAddress != "" {
s.addSyncAddress(selfAddress)
}
- users, err := s.store.GetAllUsers()
+ var users []User
+ var err error
+ groups := r.FormValue("groups")
+ if groups == "" {
+ users, err = s.store.GetUsers(nil)
+ } else {
+ uniqueUsers := make(map[string]struct{})
+ g := strings.Split(groups, ",")
+ uniqueTG := make(map[string]struct{})
+ for _, group := range g {
+ uniqueTG[group] = struct{}{}
+ trGroups, err := s.store.GetAllTransitiveGroupsForGroup(group)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ for _, tg := range trGroups {
+ uniqueTG[tg.Name] = struct{}{}
+ }
+ }
+ for group := range uniqueTG {
+ u, err := s.store.GetGroupMembers(group)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ for _, user := range u {
+ uniqueUsers[user] = struct{}{}
+ }
+ }
+ usernames := make([]string, 0, len(uniqueUsers))
+ for username := range uniqueUsers {
+ usernames = append(usernames, username)
+ }
+ users, err = s.store.GetUsers(usernames)
+ }
if err != nil {
- http.Error(w, "Failed to retrieve SSH keys", http.StatusInternalServerError)
+ http.Error(w, "Failed to retrieve user infos", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")