Memberships: Filter all users by groups and transitive groups
Change-Id: I9766501e19a058b958578476b8586883655e453f
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 88729a2..8ef0848 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -55,7 +55,7 @@
RemoveFromGroupToGroup(parent, child string) error
RemoveUserFromTable(username, groupName, tableName string) error
GetAllGroups() ([]Group, error)
- GetAllUsers() ([]User, error)
+ GetUsers(username []string) ([]User, error)
GetUser(username string) (User, error)
AddSSHKeyForUser(username, sshKey string) error
RemoveSSHKeyForUser(username, sshKey string) error
@@ -558,13 +558,27 @@
return nil
}
-func (s *SQLiteStore) GetAllUsers() ([]User, error) {
- rows, err := s.db.Query(`
+func (s *SQLiteStore) GetUsers(usernames []string) ([]User, error) {
+ var rows *sql.Rows
+ var err error
+ query := `
SELECT users.username, users.email, GROUP_CONCAT(user_ssh_keys.ssh_key, ',')
FROM users
- LEFT JOIN user_ssh_keys ON users.username = user_ssh_keys.username
- GROUP BY users.username
- `)
+ LEFT JOIN user_ssh_keys ON users.username = user_ssh_keys.username`
+ var args []interface{}
+ if usernames != nil {
+ if len(usernames) == 0 {
+ return []User{}, nil
+ }
+ query += " WHERE users.username IN ("
+ placeholders := strings.Repeat("?,", len(usernames)-1) + "?"
+ query += placeholders + ") "
+ for _, username := range usernames {
+ args = append(args, username)
+ }
+ }
+ query += " GROUP BY users.username"
+ rows, err = s.db.Query(query, args...)
if err != nil {
return nil, err
}
@@ -1224,9 +1238,44 @@
if selfAddress != "" {
s.addSyncAddress(selfAddress)
}
- users, err := s.store.GetAllUsers()
+ var users []User
+ var err error
+ groups := r.FormValue("groups")
+ if groups == "" {
+ users, err = s.store.GetUsers(nil)
+ } else {
+ uniqueUsers := make(map[string]struct{})
+ g := strings.Split(groups, ",")
+ uniqueTG := make(map[string]struct{})
+ for _, group := range g {
+ uniqueTG[group] = struct{}{}
+ trGroups, err := s.store.GetAllTransitiveGroupsForGroup(group)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ for _, tg := range trGroups {
+ uniqueTG[tg.Name] = struct{}{}
+ }
+ }
+ for group := range uniqueTG {
+ u, err := s.store.GetGroupMembers(group)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ for _, user := range u {
+ uniqueUsers[user] = struct{}{}
+ }
+ }
+ usernames := make([]string, 0, len(uniqueUsers))
+ for username := range uniqueUsers {
+ usernames = append(usernames, username)
+ }
+ users, err = s.store.GetUsers(usernames)
+ }
if err != nil {
- http.Error(w, "Failed to retrieve SSH keys", http.StatusInternalServerError)
+ http.Error(w, "Failed to retrieve user infos", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
diff --git a/core/auth/memberships/store_test.go b/core/auth/memberships/store_test.go
index 980c719..4e9ef19 100644
--- a/core/auth/memberships/store_test.go
+++ b/core/auth/memberships/store_test.go
@@ -2,9 +2,10 @@
import (
"database/sql"
- "fmt"
+ "encoding/json"
"net/http"
"net/http/httptest"
+ "reflect"
"sync"
"testing"
@@ -222,9 +223,177 @@
if status := rr.Code; status != http.StatusSeeOther {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
}
- body := rr.Body.String()
- fmt.Println("BODY: ", rr.Header().Get("Location"))
if rr.Header().Get("Location") != "/group/bb" {
- t.Errorf("handler returned unexpected body: got %v want %v", body, "expected body")
+ t.Errorf("handler returned wrong Location header: got %v want %v", rr.Header().Get("Location"), "/group/bb")
+ }
+}
+
+func TestFilterUsersByGroupHandler(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ store, err := NewSQLiteStore(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = db.Exec(`
+ CREATE TABLE IF NOT EXISTS groups (
+ name TEXT PRIMARY KEY,
+ description TEXT
+ );
+ INSERT INTO groups (name, description)
+ VALUES
+ ('a', 'a'),
+ ('b', 'b'),
+ ('c', 'c'),
+ ('d', 'd'),
+ ('e', 'e'),
+ ('f', 'f');
+ CREATE TABLE IF NOT EXISTS owners (
+ username TEXT,
+ group_name TEXT,
+ FOREIGN KEY(group_name) REFERENCES groups(name),
+ UNIQUE (username, group_name)
+ );
+ INSERT INTO owners (username, group_name)
+ VALUES
+ ('testuser1', 'a'),
+ ('testuser2', 'd');
+ CREATE TABLE IF NOT EXISTS group_to_group (
+ parent_group TEXT,
+ child_group TEXT
+ );
+ INSERT INTO group_to_group (parent_group, child_group)
+ VALUES
+ ('a', 'b'),
+ ('b', 'c'),
+ ('d', 'e'),
+ ('e', 'f');
+ CREATE TABLE IF NOT EXISTS user_to_group (
+ username TEXT,
+ group_name TEXT,
+ FOREIGN KEY(group_name) REFERENCES groups(name),
+ UNIQUE (username, group_name)
+ );
+ INSERT INTO user_to_group (username, group_name)
+ VALUES
+ ('u1', 'a'),
+ ('u2', 'b'),
+ ('u3', 'e'),
+ ('u4', 'f'),
+ ('u5', 'f'),
+ ('u6', 'd'),
+ ('u7', 'd');
+ CREATE TABLE IF NOT EXISTS users (
+ username TEXT PRIMARY KEY,
+ email TEXT,
+ UNIQUE (email)
+ );
+ INSERT INTO users (username, email)
+ VALUES
+ ('u1','u1@d.d'),
+ ('u2','u2@d.d'),
+ ('u3','u3@d.d'),
+ ('u4','u4@d.d'),
+ ('u5','u5@d.d'),
+ ('u6','u6@d.d'),
+ ('u7','u7@d.d');
+ CREATE TABLE IF NOT EXISTS user_ssh_keys (
+ username TEXT,
+ ssh_key TEXT,
+ UNIQUE (ssh_key),
+ FOREIGN KEY(username) REFERENCES users(username)
+ );
+ INSERT INTO user_ssh_keys (username, ssh_key)
+ VALUES
+ ('u1','ssh1'),
+ ('u1','ssh1-1'),
+ ('u2','ssh2'),
+ ('u3','ssh3'),
+ ('u4','ssh4'),
+ ('u5','ssh5'),
+ ('u6','ssh6'),
+ ('u7','ssh7');
+ `)
+ if err != nil {
+ t.Fatal(err)
+ }
+ server := &Server{
+ store: store,
+ syncAddresses: make(map[string]struct{}),
+ mu: sync.Mutex{},
+ }
+ router := mux.NewRouter()
+ // case when group present or exist
+ router.HandleFunc("/api/users", server.apiGetAllUsers).Methods(http.MethodGet)
+ req, err := http.NewRequest("GET", "/api/users?groups=b,e,t", nil)
+ req.Header.Set("X-User", "testuser1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rr := httptest.NewRecorder()
+ router.ServeHTTP(rr, req)
+ expected := []User{
+ {"u1", "u1@d.d", []string{"ssh1", "ssh1-1"}},
+ {"u2", "u2@d.d", []string{"ssh2"}},
+ {"u3", "u3@d.d", []string{"ssh3"}},
+ {"u6", "u6@d.d", []string{"ssh6"}},
+ {"u7", "u7@d.d", []string{"ssh7"}},
+ }
+
+ var actual []User
+ err = json.NewDecoder(rr.Body).Decode(&actual)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(actual, expected) {
+ t.Errorf("handler returned unexpected body: got %v want %v", actual, expected)
+ }
+
+ // case when no group present
+ req, err = http.NewRequest("GET", "/api/users?groups=", nil)
+ req.Header.Set("X-User", "testuser1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rr = httptest.NewRecorder()
+ router.ServeHTTP(rr, req)
+
+ expected = []User{
+ {"u1", "u1@d.d", []string{"ssh1", "ssh1-1"}},
+ {"u2", "u2@d.d", []string{"ssh2"}},
+ {"u3", "u3@d.d", []string{"ssh3"}},
+ {"u4", "u4@d.d", []string{"ssh4"}},
+ {"u5", "u5@d.d", []string{"ssh5"}},
+ {"u6", "u6@d.d", []string{"ssh6"}},
+ {"u7", "u7@d.d", []string{"ssh7"}},
+ }
+ err = json.NewDecoder(rr.Body).Decode(&actual)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(actual, expected) {
+ t.Errorf("handler returned unexpected body: got %v want %v", actual, expected)
+ }
+
+ // case when wrong groups
+ req, err = http.NewRequest("GET", "/api/users?groups=x,y", nil)
+ req.Header.Set("X-User", "testuser1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rr = httptest.NewRecorder()
+ router.ServeHTTP(rr, req)
+
+ expected = []User{}
+ err = json.NewDecoder(rr.Body).Decode(&actual)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(actual, expected) {
+ t.Errorf("handler returned unexpected body: got %v want %v", actual, expected)
}
}