membership: groupname and username validations (#116)
* validate group names and lowercase usernames
* fixed changes validations
* removed username validation
* lowercasing username
diff --git a/core/auth/memberships/main.go b/core/auth/memberships/main.go
index 4dae1a9..0a2bda4 100644
--- a/core/auth/memberships/main.go
+++ b/core/auth/memberships/main.go
@@ -9,6 +9,8 @@
"html/template"
"log"
"net/http"
+ "regexp"
+ "strings"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
@@ -34,6 +36,7 @@
Init(owner string, groups []string) error
CreateGroup(owner string, group Group) error
AddChildGroup(parent, child string) error
+ DoesGroupExist(group string) (bool, error)
GetGroupsOwnedBy(user string) ([]Group, error)
GetGroupsUserBelongsTo(user string) ([]Group, error)
IsGroupOwner(user, group string) (bool, error)
@@ -297,7 +300,25 @@
return exists, nil
}
+// DoesGroupExist reports whether the groups table holds a row whose name
+// equals group. If the query or the scan fails, it returns false together
+// with that error.
+func (s *SQLiteStore) DoesGroupExist(group string) (bool, error) {
+	row := s.db.QueryRow(`SELECT EXISTS (SELECT 1 FROM groups WHERE name = ?)`, group)
+	found := false
+	if err := row.Scan(&found); err != nil {
+		return false, err
+	}
+	return found, nil
+}
+
func (s *SQLiteStore) AddChildGroup(parent, child string) error {
+	if parent == child {
+		return fmt.Errorf("parent and child groups can not have same name")
+	}
+	// Both groups must already exist. DoesGroupExist returns a nil error
+	// for a missing group; only the boolean says whether it was found, so
+	// lookup failures and "not found" are reported separately.
+	parentExists, err := s.DoesGroupExist(parent)
+	if err != nil {
+		return fmt.Errorf("checking parent group %s: %w", parent, err)
+	}
+	if !parentExists {
+		return fmt.Errorf("parent group name %s does not exist", parent)
+	}
+	childExists, err := s.DoesGroupExist(child)
+	if err != nil {
+		return fmt.Errorf("checking child group %s: %w", child, err)
+	}
+	if !childExists {
+		return fmt.Errorf("child group name %s does not exist", child)
+	}
parentGroups, err := s.GetAllTransitiveGroupsForGroup(parent)
if err != nil {
return err
@@ -334,8 +355,7 @@
SELECT name FROM groups
WHERE name != ? AND name NOT IN (
SELECT child_group FROM group_to_group WHERE parent_group = ?
- )
- `
+ )`
rows, err := s.db.Query(query, group, group)
if err != nil {
return nil, err
@@ -457,17 +477,6 @@
Member
)
-func convertStatus(status string) (Status, error) {
- switch status {
- case "Owner":
- return Owner, nil
- case "Member":
- return Member, nil
- default:
- return Owner, fmt.Errorf("invalid status: %s", status)
- }
-}
-
func (s *Server) Start() {
router := mux.NewRouter()
router.PathPrefix("/static/").Handler(http.FileServer(http.FS(staticResources)))
@@ -557,6 +566,10 @@
}
var group Group
group.Name = r.PostFormValue("group-name")
+ if err := isValidGroupName(group.Name); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
group.Description = r.PostFormValue("description")
if err := s.store.CreateGroup(loggedInUser, group); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -573,6 +586,16 @@
}
vars := mux.Vars(r)
groupName := vars["group-name"]
+ exists, err := s.store.DoesGroupExist(groupName)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ if !exists {
+ errorMsg := fmt.Sprintf("group with the name '%s' not found", groupName)
+ http.Error(w, errorMsg, http.StatusNotFound)
+ return
+ }
tmpl, err := template.New("group").Parse(groupHTML)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -642,7 +665,15 @@
return
}
groupName := r.FormValue("group")
- username := r.FormValue("username")
+ if err := isValidGroupName(groupName); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ username := strings.ToLower(r.FormValue("username"))
+ if username == "" {
+ http.Error(w, "Username parameter is required", http.StatusBadRequest)
+ return
+ }
status, err := convertStatus(r.FormValue("status"))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -679,7 +710,15 @@
return
}
parentGroup := r.FormValue("parent-group")
+ if err := isValidGroupName(parentGroup); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
childGroup := r.FormValue("child-group")
+ if err := isValidGroupName(childGroup); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
if _, err := s.checkIsOwner(w, loggedInUser, parentGroup); err != nil {
return
}
@@ -714,10 +753,11 @@
func (s *Server) apiMemberOfHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
user, ok := vars["username"]
- if !ok {
+ if !ok || user == "" {
http.Error(w, "Username parameter is required", http.StatusBadRequest)
return
}
+ user = strings.ToLower(user)
transitiveGroups, err := s.store.GetAllTransitiveGroupsForUser(user)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -734,6 +774,28 @@
}
}
+// convertStatus maps the exact (case-sensitive) form value "Owner" or
+// "Member" to the matching Status. Any other input returns Owner together
+// with an "invalid status" error.
+func convertStatus(status string) (Status, error) {
+	if status == "Member" {
+		return Member, nil
+	}
+	if status != "Owner" {
+		return Owner, fmt.Errorf("invalid status: %s", status)
+	}
+	return Owner, nil
+}
+
+// validGroupNameRe accepts names made only of lowercase ASCII letters,
+// digits, and the characters - _ : . / and space. It is compiled once here
+// rather than inside isValidGroupName, which runs on every request.
+var validGroupNameRe = regexp.MustCompile(`^[a-z0-9\-_:.\/ ]+$`)
+
+// isValidGroupName returns nil if group is an acceptable group name.
+// Otherwise it returns an error when group is empty or whitespace-only, or
+// when it contains a character outside the allowed set (lowercase letters,
+// digits, spaces, -, _, :, ., /).
+func isValidGroupName(group string) error {
+	if strings.TrimSpace(group) == "" {
+		return fmt.Errorf("group name can't be empty or contain only whitespaces")
+	}
+	if !validGroupNameRe.MatchString(group) {
+		// The message lists spaces too, because the pattern allows them.
+		return fmt.Errorf("group name should contain only lowercase letters, digits, spaces, -, _, :, ., /")
+	}
+	return nil
+}
+
func main() {
flag.Parse()
db, err := sql.Open("sqlite3", *dbPath)