memberships: api init endpoint (#114)
Adds API endpoint to initialise database with first owner and groups.
affects: #108
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 4b64140..1358bb3 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -30,6 +30,8 @@
var staticResources embed.FS
type Store interface {
+ // Initializes store with admin user and their groups.
+ Init(owner string, groups []string) error
CreateGroup(owner string, group Group) error
AddChildGroup(parent, child string) error
GetGroupsOwnedBy(user string) ([]Group, error)
@@ -91,6 +93,33 @@
return &SQLiteStore{db: db}, nil
}
+func (s *SQLiteStore) Init(owner string, groups []string) error {
+ tx, err := s.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+ row := tx.QueryRow("SELECT COUNT(*) FROM groups")
+ var count int
+ if err := row.Scan(&count); err != nil {
+ return err
+ }
+ if count != 0 {
+ return fmt.Errorf("store already initialised")
+ }
+ for _, g := range groups {
+ query := `INSERT INTO groups (name, description) VALUES (?, '')`
+ if _, err := tx.Exec(query, g); err != nil {
+ return err
+ }
+ query = `INSERT INTO owners (username, group_name) VALUES (?, ?)`
+ if _, err := tx.Exec(query, owner, g); err != nil {
+ return err
+ }
+ }
+ return tx.Commit()
+}
+
func (s *SQLiteStore) queryGroups(query string, args ...interface{}) ([]Group, error) {
groups := make([]Group, 0)
rows, err := s.db.Query(query, args...)
@@ -147,10 +176,7 @@
if _, err := tx.Exec(query, owner, group.Name); err != nil {
return err
}
- if err := tx.Commit(); err != nil {
- return err
- }
- return nil
+ return tx.Commit()
}
func (s *SQLiteStore) IsGroupOwner(user, group string) (bool, error) {
@@ -445,6 +471,7 @@
router.HandleFunc("/create-group", s.createGroupHandler)
router.HandleFunc("/add-user", s.addUserHandler)
router.HandleFunc("/add-child-group", s.addChildGroupHandler)
+ router.HandleFunc("/api/init", s.apiInitHandler)
router.HandleFunc("/api/user/{username}", s.apiMemberOfHandler)
router.HandleFunc("/", s.homePageHandler)
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), router))
@@ -659,7 +686,24 @@
http.Redirect(w, r, "/group/"+parentGroup, http.StatusSeeOther)
}
-type UserInfo struct {
+type initRequest struct {
+ Owner string `json:"owner"`
+ Groups []string `json:"groups"`
+}
+
+func (s *Server) apiInitHandler(w http.ResponseWriter, r *http.Request) {
+ var req initRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ if err := s.store.Init(req.Owner, req.Groups); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+}
+
+type userInfo struct {
MemberOf []string `json:"memberOf"`
}
@@ -680,7 +724,7 @@
groupNames = append(groupNames, group.Name)
}
w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(UserInfo{MemberOf: groupNames}); err != nil {
+ if err := json.NewEncoder(w).Encode(userInfo{groupNames}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
diff --git a/core/auth/memberships/store_test.go b/core/auth/memberships/store_test.go
index eb774cc..d55b3f1 100644
--- a/core/auth/memberships/store_test.go
+++ b/core/auth/memberships/store_test.go
@@ -8,6 +8,53 @@
_ "github.com/ncruces/go-sqlite3/embed"
)
+func TestInitSuccess(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ store, err := NewSQLiteStore(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := store.Init("admin", []string{"admin", "all"}); err != nil {
+ t.Fatal(err)
+ }
+ groups, err := store.GetGroupsOwnedBy("admin")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(groups) != 2 {
+ t.Fatalf("Expected two groups, got: %s", groups)
+ }
+}
+
+func TestInitFailure(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ store, err := NewSQLiteStore(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = db.Exec(`
+ INSERT INTO groups (name, description)
+ VALUES
+ ('a', 'xxx'),
+ ('b', 'yyy');
+ `)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Init("admin", []string{"admin", "all"})
+ if err == nil {
+ t.Fatal("initialisation did not fail")
+ } else if err.Error() != "store already initialised" {
+ t.Fatalf("Expected initialisation error, got: %s", err.Error())
+ }
+}
+
func TestGetAllTransitiveGroupsForGroup(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {