Memberships: Filter all users by groups and transitive groups

Change-Id: I9766501e19a058b958578476b8586883655e453f
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 88729a2..8ef0848 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -55,7 +55,7 @@
 	RemoveFromGroupToGroup(parent, child string) error
 	RemoveUserFromTable(username, groupName, tableName string) error
 	GetAllGroups() ([]Group, error)
-	GetAllUsers() ([]User, error)
+	GetUsers(username []string) ([]User, error)
 	GetUser(username string) (User, error)
 	AddSSHKeyForUser(username, sshKey string) error
 	RemoveSSHKeyForUser(username, sshKey string) error
@@ -558,13 +558,27 @@
 	return nil
 }
 
-func (s *SQLiteStore) GetAllUsers() ([]User, error) {
-	rows, err := s.db.Query(`
+func (s *SQLiteStore) GetUsers(usernames []string) ([]User, error) {
+	var rows *sql.Rows
+	var err error
+	query := `
 		SELECT users.username, users.email, GROUP_CONCAT(user_ssh_keys.ssh_key, ',')
 		FROM users
-		LEFT JOIN user_ssh_keys ON users.username = user_ssh_keys.username
-		GROUP BY users.username
-	`)
+		LEFT JOIN user_ssh_keys ON users.username = user_ssh_keys.username`
+	var args []interface{}
+	if usernames != nil {
+		if len(usernames) == 0 {
+			return []User{}, nil
+		}
+		query += " WHERE users.username IN ("
+		placeholders := strings.Repeat("?,", len(usernames)-1) + "?"
+		query += placeholders + ") "
+		for _, username := range usernames {
+			args = append(args, username)
+		}
+	}
+	query += " GROUP BY users.username"
+	rows, err = s.db.Query(query, args...)
 	if err != nil {
 		return nil, err
 	}
@@ -1224,9 +1238,44 @@
 	if selfAddress != "" {
 		s.addSyncAddress(selfAddress)
 	}
-	users, err := s.store.GetAllUsers()
+	var users []User
+	var err error
+	groups := r.FormValue("groups")
+	if groups == "" {
+		users, err = s.store.GetUsers(nil)
+	} else {
+		uniqueUsers := make(map[string]struct{})
+		g := strings.Split(groups, ",")
+		uniqueTG := make(map[string]struct{})
+		for _, group := range g {
+			uniqueTG[group] = struct{}{}
+			trGroups, err := s.store.GetAllTransitiveGroupsForGroup(group)
+			if err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				return
+			}
+			for _, tg := range trGroups {
+				uniqueTG[tg.Name] = struct{}{}
+			}
+		}
+		for group := range uniqueTG {
+			u, err := s.store.GetGroupMembers(group)
+			if err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				return
+			}
+			for _, user := range u {
+				uniqueUsers[user] = struct{}{}
+			}
+		}
+		usernames := make([]string, 0, len(uniqueUsers))
+		for username := range uniqueUsers {
+			usernames = append(usernames, username)
+		}
+		users, err = s.store.GetUsers(usernames)
+	}
 	if err != nil {
-		http.Error(w, "Failed to retrieve SSH keys", http.StatusInternalServerError)
+		http.Error(w, "Failed to retrieve user infos", http.StatusInternalServerError)
 		return
 	}
 	w.Header().Set("Content-Type", "application/json")
diff --git a/core/auth/memberships/store_test.go b/core/auth/memberships/store_test.go
index 980c719..4e9ef19 100644
--- a/core/auth/memberships/store_test.go
+++ b/core/auth/memberships/store_test.go
@@ -2,9 +2,10 @@
 
 import (
 	"database/sql"
-	"fmt"
+	"encoding/json"
 	"net/http"
 	"net/http/httptest"
+	"reflect"
 	"sync"
 	"testing"
 
@@ -222,9 +223,177 @@
 	if status := rr.Code; status != http.StatusSeeOther {
 		t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
 	}
-	body := rr.Body.String()
-	fmt.Println("BODY: ", rr.Header().Get("Location"))
 	if rr.Header().Get("Location") != "/group/bb" {
-		t.Errorf("handler returned unexpected body: got %v want %v", body, "expected body")
+		t.Errorf("handler returned wrong Location header: got %v want %v", rr.Header().Get("Location"), "/group/bb")
+	}
+}
+
+func TestFilterUsersByGroupHandler(t *testing.T) {
+	db, err := sql.Open("sqlite3", ":memory:")
+	if err != nil {
+		t.Fatal(err)
+	}
+	store, err := NewSQLiteStore(db)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = db.Exec(`
+		CREATE TABLE IF NOT EXISTS groups (
+			name TEXT PRIMARY KEY,
+			description TEXT
+		);
+		INSERT INTO groups (name, description)
+        VALUES
+            ('a', 'a'),
+			('b', 'b'),
+			('c', 'c'),
+			('d', 'd'),
+			('e', 'e'),
+			('f', 'f');
+		CREATE TABLE IF NOT EXISTS owners (
+			username TEXT,
+			group_name TEXT,
+			FOREIGN KEY(group_name) REFERENCES groups(name),
+			UNIQUE (username, group_name)
+		);
+		INSERT INTO owners (username, group_name)
+        VALUES
+            ('testuser1', 'a'),
+			('testuser2', 'd');
+		CREATE TABLE IF NOT EXISTS group_to_group (
+			parent_group TEXT,
+			child_group TEXT
+		);
+        INSERT INTO group_to_group (parent_group, child_group)
+        VALUES
+            ('a', 'b'),
+			('b', 'c'),
+			('d', 'e'),
+			('e', 'f');
+		CREATE TABLE IF NOT EXISTS user_to_group (
+			username TEXT,
+			group_name TEXT,
+			FOREIGN KEY(group_name) REFERENCES groups(name),
+			UNIQUE (username, group_name)
+		);
+        INSERT INTO user_to_group (username, group_name)
+        VALUES
+            ('u1', 'a'),
+			('u2', 'b'),
+			('u3', 'e'),
+			('u4', 'f'),
+			('u5', 'f'),
+			('u6', 'd'),
+			('u7', 'd');
+		CREATE TABLE IF NOT EXISTS users (
+			username TEXT PRIMARY KEY,
+			email TEXT,
+			UNIQUE (email)
+		);
+		INSERT INTO users (username, email)
+		VALUES
+			('u1','u1@d.d'),
+			('u2','u2@d.d'),
+			('u3','u3@d.d'),
+			('u4','u4@d.d'),
+			('u5','u5@d.d'),
+			('u6','u6@d.d'),
+			('u7','u7@d.d');
+		CREATE TABLE IF NOT EXISTS user_ssh_keys (
+			username TEXT,
+			ssh_key TEXT,
+			UNIQUE (ssh_key),
+			FOREIGN KEY(username) REFERENCES users(username)
+		);
+		INSERT INTO user_ssh_keys (username, ssh_key)
+		VALUES
+			('u1','ssh1'),
+			('u1','ssh1-1'),
+			('u2','ssh2'),
+			('u3','ssh3'),
+			('u4','ssh4'),
+			('u5','ssh5'),
+			('u6','ssh6'),
+			('u7','ssh7');
+        `)
+	if err != nil {
+		t.Fatal(err)
+	}
+	server := &Server{
+		store:         store,
+		syncAddresses: make(map[string]struct{}),
+		mu:            sync.Mutex{},
+	}
+	router := mux.NewRouter()
+	// case when group present or exist
+	router.HandleFunc("/api/users", server.apiGetAllUsers).Methods(http.MethodGet)
+	req, err := http.NewRequest("GET", "/api/users?groups=b,e,t", nil)
+	req.Header.Set("X-User", "testuser1")
+	if err != nil {
+		t.Fatal(err)
+	}
+	rr := httptest.NewRecorder()
+	router.ServeHTTP(rr, req)
+	expected := []User{
+		{"u1", "u1@d.d", []string{"ssh1", "ssh1-1"}},
+		{"u2", "u2@d.d", []string{"ssh2"}},
+		{"u3", "u3@d.d", []string{"ssh3"}},
+		{"u6", "u6@d.d", []string{"ssh6"}},
+		{"u7", "u7@d.d", []string{"ssh7"}},
+	}
+
+	var actual []User
+	err = json.NewDecoder(rr.Body).Decode(&actual)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(actual, expected) {
+		t.Errorf("handler returned unexpected body: got %v want %v", actual, expected)
+	}
+
+	// case when no group present
+	req, err = http.NewRequest("GET", "/api/users?groups=", nil)
+	req.Header.Set("X-User", "testuser1")
+	if err != nil {
+		t.Fatal(err)
+	}
+	rr = httptest.NewRecorder()
+	router.ServeHTTP(rr, req)
+
+	expected = []User{
+		{"u1", "u1@d.d", []string{"ssh1", "ssh1-1"}},
+		{"u2", "u2@d.d", []string{"ssh2"}},
+		{"u3", "u3@d.d", []string{"ssh3"}},
+		{"u4", "u4@d.d", []string{"ssh4"}},
+		{"u5", "u5@d.d", []string{"ssh5"}},
+		{"u6", "u6@d.d", []string{"ssh6"}},
+		{"u7", "u7@d.d", []string{"ssh7"}},
+	}
+	err = json.NewDecoder(rr.Body).Decode(&actual)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(actual, expected) {
+		t.Errorf("handler returned unexpected body: got %v want %v", actual, expected)
+	}
+
+	// case when wrong groups
+	req, err = http.NewRequest("GET", "/api/users?groups=x,y", nil)
+	req.Header.Set("X-User", "testuser1")
+	if err != nil {
+		t.Fatal(err)
+	}
+	rr = httptest.NewRecorder()
+	router.ServeHTTP(rr, req)
+
+	expected = []User{}
+	err = json.NewDecoder(rr.Body).Decode(&actual)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(actual, expected) {
+		t.Errorf("handler returned unexpected body: got %v want %v", actual, expected)
 	}
 }