Skip to content

Commit

Permalink
[server] Enable metadata r/w for shared files (#4569)
Browse files Browse the repository at this point in the history
## Description

## Tests
Will test happy cases and update here
  • Loading branch information
ua741 authored Jan 8, 2025
2 parents a33f5b8 + db4b560 commit 8656f69
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 16 deletions.
1 change: 1 addition & 0 deletions server/pkg/controller/access/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
type Controller interface {
GetCollection(ctx *gin.Context, req *GetCollectionParams) (*GetCollectionResponse, error)
VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwnershipParams) error
CanAccessFile(ctx *gin.Context, req *CanAccessFileParams) error
}

// controllerImpl implements Controller
Expand Down
40 changes: 40 additions & 0 deletions server/pkg/controller/access/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ type VerifyFileOwnershipParams struct {
FileIDs []int64
}

type CanAccessFileParams struct {
ActorUserID int64
FileIDs []int64
}

// VerifyFileOwnership will return error if given fileIDs are not valid or don't belong to the ownerID
func (c controllerImpl) VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwnershipParams) error {
if enteArray.ContainsDuplicateInInt64Array(req.FileIDs) {
Expand All @@ -26,3 +31,38 @@ func (c controllerImpl) VerifyFileOwnership(ctx *gin.Context, req *VerifyFileOwn
})
return c.FileRepo.VerifyFileOwner(ctx, req.FileIDs, ownerID, logger)
}
func (c controllerImpl) CanAccessFile(ctx *gin.Context, req *CanAccessFileParams) error {
if enteArray.ContainsDuplicateInInt64Array(req.FileIDs) {
return stacktrace.Propagate(ente.ErrBadRequest, "duplicate fileIDs")
}

ownerToFilesMap, err := c.FileRepo.GetOwnerToFileIDsMap(ctx, req.FileIDs)
if err != nil {
return stacktrace.Propagate(err, "failed to get owner to fileIDs map")
}

// Only fetch shared collections once when needed
var sharedCollections []int64
for owner, fileIDs := range ownerToFilesMap {
if owner == req.ActorUserID {
continue
}

// Lazy load collections only when we need to check permissions
if sharedCollections == nil {
sharedCollections, err = c.CollectionRepo.GetCollectionsSharedWithOrByUser(req.ActorUserID)
if err != nil {
return stacktrace.Propagate(err, "failed to get shared collections")
}
}
if existsErr := c.CollectionRepo.DoAllFilesExistInGivenCollections(fileIDs, sharedCollections); existsErr != nil {
log.WithFields(log.Fields{
"req_id": requestid.Get(ctx),
"sharedCollections": sharedCollections,
"fileIDs": fileIDs,
}).WithError(existsErr).Error("access check failed")
return stacktrace.Propagate(ente.ErrPermissionDenied, "access denied")
}
}
return nil
}
33 changes: 20 additions & 13 deletions server/pkg/controller/filedata/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,20 @@ func (c *Controller) InsertOrUpdateMetadata(ctx *gin.Context, req *fileData.PutF
return stacktrace.Propagate(err, "validation failed")
}
userID := auth.GetUserID(ctx.Request.Header)
err := c._validatePermission(ctx, req.FileID, userID)
fileOwnerID, err := c.FileRepo.GetOwnerID(req.FileID)
if err != nil {
return stacktrace.Propagate(err, "")
}
if fileOwnerID != userID {
permErr := c._checkMetadataReadOrWritePerm(ctx, userID, []int64{req.FileID})
if permErr != nil {
return stacktrace.Propagate(permErr, "")
}
}
if req.Type != ente.MlData {
return stacktrace.Propagate(ente.NewBadRequestWithMessage("unsupported object type "+string(req.Type)), "")
}
fileOwnerID := userID

bucketID := c.S3Config.GetBucketID(req.Type)
objectKey := fileData.ObjectMetadataKey(req.FileID, fileOwnerID, req.Type, nil)
obj := fileData.S3FileMetadata{
Expand Down Expand Up @@ -123,10 +129,11 @@ func (c *Controller) InsertOrUpdateMetadata(ctx *gin.Context, req *fileData.PutF
}

func (c *Controller) GetFileData(ctx *gin.Context, req fileData.GetFileData) (*fileData.Entity, error) {
userID := auth.GetUserID(ctx.Request.Header)
if err := req.Validate(); err != nil {
return nil, stacktrace.Propagate(err, "validation failed")
}
if err := c._validatePermission(ctx, req.FileID, auth.GetUserID(ctx.Request.Header)); err != nil {
if err := c._checkMetadataReadOrWritePerm(ctx, userID, []int64{req.FileID}); err != nil {
return nil, stacktrace.Propagate(err, "")
}
doRows, err := c.Repo.GetFilesData(ctx, req.Type, []int64{req.FileID})
Expand All @@ -150,7 +157,10 @@ func (c *Controller) GetFileData(ctx *gin.Context, req fileData.GetFileData) (*f

func (c *Controller) GetFilesData(ctx *gin.Context, req fileData.GetFilesData) (*fileData.GetFilesDataResponse, error) {
userID := auth.GetUserID(ctx.Request.Header)
if err := c._validateGetFilesData(ctx, userID, req); err != nil {
if err := req.Validate(); err != nil {
return nil, stacktrace.Propagate(err, "req validation failed")
}
if err := c._checkMetadataReadOrWritePerm(ctx, userID, req.FileIDs); err != nil {
return nil, stacktrace.Propagate(err, "")
}

Expand Down Expand Up @@ -273,21 +283,18 @@ func (c *Controller) fetchS3FileMetadata(ctx context.Context, row fileData.Row,
return nil, stacktrace.Propagate(errors.New("failed to fetch object"), "")
}

func (c *Controller) _validateGetFilesData(ctx *gin.Context, userID int64, req fileData.GetFilesData) error {
if err := req.Validate(); err != nil {
return stacktrace.Propagate(err, "validation failed")
}
if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
ActorUserId: userID,
FileIDs: req.FileIDs,
func (c *Controller) _checkMetadataReadOrWritePerm(ctx *gin.Context, userID int64, fileIDs []int64) error {
if err := c.AccessCtrl.CanAccessFile(ctx, &access.CanAccessFileParams{
ActorUserID: userID,
FileIDs: fileIDs,
}); err != nil {
return stacktrace.Propagate(err, "User does not own some file(s)")
}

return nil
}

func (c *Controller) _validatePermission(ctx *gin.Context, fileID int64, actorID int64) error {
// _checkPreviewWritePerm is
func (c *Controller) _checkPreviewWritePerm(ctx *gin.Context, fileID int64, actorID int64) error {
err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
ActorUserId: actorID,
FileIDs: []int64{fileID},
Expand Down
4 changes: 2 additions & 2 deletions server/pkg/controller/filedata/preview_files.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func (c *Controller) GetPreviewUrl(ctx *gin.Context, request filedata.GetPreview
return nil, err
}
actorUser := auth.GetUserID(ctx.Request.Header)
if err := c._validatePermission(ctx, request.FileID, actorUser); err != nil {
if err := c._checkMetadataReadOrWritePerm(ctx, actorUser, []int64{request.FileID}); err != nil {
return nil, err
}
data, err := c.Repo.GetFilesData(ctx, request.Type, []int64{request.FileID})
Expand All @@ -35,7 +35,7 @@ func (c *Controller) PreviewUploadURL(ctx *gin.Context, request filedata.Preview
return nil, err
}
actorUser := auth.GetUserID(ctx.Request.Header)
if err := c._validatePermission(ctx, request.FileID, actorUser); err != nil {
if err := c._checkPreviewWritePerm(ctx, request.FileID, actorUser); err != nil {
return nil, err
}
fileOwnerID, err := c.FileRepo.GetOwnerID(request.FileID)
Expand Down
2 changes: 1 addition & 1 deletion server/pkg/controller/filedata/video.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (c *Controller) InsertVideoPreview(ctx *gin.Context, req *filedata.VidPrevi
return stacktrace.Propagate(err, "validation failed")
}
userID := auth.GetUserID(ctx.Request.Header)
err := c._validatePermission(ctx, req.FileID, userID)
err := c._checkPreviewWritePerm(ctx, req.FileID, userID)
if err != nil {
return stacktrace.Propagate(err, "")
}
Expand Down
69 changes: 69 additions & 0 deletions server/pkg/repo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,29 @@ func (repo *CollectionRepository) GetCollectionIDsSharedWithUser(userID int64) (
return cIDs, nil
}

func (repo *CollectionRepository) GetCollectionsSharedWithOrByUser(userID int64) ([]int64, error) {
rows, err := repo.DB.Query(`
SELECT collection_id
FROM collection_shares
WHERE (to_user_id = $1 OR from_user_id = $1)
AND is_deleted = $2`, userID, false)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
defer rows.Close()

cIDs := make([]int64, 0)
for rows.Next() {
var cID int64
if err := rows.Scan(&cID); err != nil {
return cIDs, stacktrace.Propagate(err, "")
}
cIDs = append(cIDs, cID)
}
return cIDs, nil

}

// GetCollectionIDsOwnedByUser returns the map of collectionID (owned by user) to collection deletion status
func (repo *CollectionRepository) GetCollectionIDsOwnedByUser(userID int64) (map[int64]bool, error) {
rows, err := repo.DB.Query(`
Expand Down Expand Up @@ -375,6 +398,52 @@ func (repo *CollectionRepository) DoesFileExistInCollections(fileID int64, cIDs
return exists, stacktrace.Propagate(err, "")
}

func (repo *CollectionRepository) DoAllFilesExistInGivenCollections(fileIDs []int64, cIDs []int64) error {
// Query to get all distinct file_ids that exist in the collections
rows, err := repo.DB.Query(`
SELECT DISTINCT file_id
FROM collection_files
WHERE file_id = ANY ($1)
AND is_deleted = false
AND collection_id = ANY ($2)`,
pq.Array(fileIDs), pq.Array(cIDs))

if err != nil {
return stacktrace.Propagate(err, "")
}
defer rows.Close()

// Create a map of input fileIDs for easy lookup
fileIDMap := make(map[int64]bool)
for _, id := range fileIDs {
fileIDMap[id] = false // false means not found yet
}
// Mark files that were found
for rows.Next() {
var fileID int64
if err := rows.Scan(&fileID); err != nil {
return stacktrace.Propagate(err, "")
}
fileIDMap[fileID] = true // mark as found
}

if err = rows.Err(); err != nil {
return stacktrace.Propagate(err, "")
}

// Collect missing files
var missingFiles []int64
for id, found := range fileIDMap {
if !found {
missingFiles = append(missingFiles, id)
}
}
if len(missingFiles) > 0 {
return stacktrace.Propagate(fmt.Errorf("missing files %v", missingFiles), "")
}
return nil
}

// VerifyAllFileIDsExistsInCollection returns error if the fileIDs don't exist in the collection
func (repo *CollectionRepository) VerifyAllFileIDsExistsInCollection(ctx context.Context, cID int64, fileIDs []int64) error {
fileIdMap := make(map[int64]bool)
Expand Down

0 comments on commit 8656f69

Please sign in to comment.