Skip to content

Commit

Permalink
Merge pull request #19 from gadget-inc/close_open_query
Browse files Browse the repository at this point in the history
Close open queries when function early exits
  • Loading branch information
angelini authored Oct 28, 2022
2 parents 035a1cd + 7ba0561 commit bddb60b
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
32 changes: 17 additions & 15 deletions internal/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func filterObject(path string, objectQuery *pb.ObjectQuery, object *pb.Object) (
return nil, SKIP
}

func GetObjects(ctx context.Context, tx pgx.Tx, packManager *PackManager, project int64, vrange VersionRange, objectQuery *pb.ObjectQuery) (ObjectStream, error) {
func GetObjects(ctx context.Context, tx pgx.Tx, packManager *PackManager, project int64, vrange VersionRange, objectQuery *pb.ObjectQuery) (ObjectStream, CloseFunc, error) {
packParent := packManager.IsPathPacked(objectQuery.Path)

originalPath := objectQuery.Path
Expand All @@ -167,10 +167,11 @@ func GetObjects(ctx context.Context, tx pgx.Tx, packManager *PackManager, projec

builder := newQueryBuilder(project, vrange, objectQuery)
sql, args := builder.build()
rows, err := tx.Query(ctx, sql, args...)

rows, err := tx.Query(ctx, sql, args...)
closeFunc := func(_ context.Context) { rows.Close() }
if err != nil {
return nil, fmt.Errorf("getObjects query, project %v vrange %v: %w", project, vrange, err)
return nil, closeFunc, fmt.Errorf("getObjects query, project %v vrange %v: %w", project, vrange, err)
}

var buffer []*pb.Object
Expand Down Expand Up @@ -227,17 +228,19 @@ func GetObjects(ctx context.Context, tx pgx.Tx, packManager *PackManager, projec
Deleted: deleted,
Content: content,
})
}, nil
}, closeFunc, nil
}

type tarStream func() ([]byte, *string, error)

func GetTars(ctx context.Context, tx pgx.Tx, project int64, cacheVersions []int64, vrange VersionRange, objectQuery *pb.ObjectQuery) (tarStream, error) {
func GetTars(ctx context.Context, tx pgx.Tx, project int64, cacheVersions []int64, vrange VersionRange, objectQuery *pb.ObjectQuery) (tarStream, CloseFunc, error) {
builder := newQueryBuilder(project, vrange, objectQuery).withCacheVersions(cacheVersions).withHashes(true)
sql, args := builder.build()

rows, err := tx.Query(ctx, sql, args...)
closeFunc := func(_ context.Context) { rows.Close() }
if err != nil {
return nil, fmt.Errorf("getObjects query, project %v vrange %v: %w", project, vrange, err)
return nil, closeFunc, fmt.Errorf("getObjects query, project %v vrange %v: %w", project, vrange, err)
}

tarWriter := NewTarWriter()
Expand Down Expand Up @@ -294,12 +297,12 @@ func GetTars(ctx context.Context, tx pgx.Tx, project int64, cacheVersions []int6
}

return nil, nil, SKIP
}, nil
}, closeFunc, nil
}

type cacheTarStream func() (int64, []byte, *Hash, error)

func GetCacheTars(ctx context.Context, tx pgx.Tx) (cacheTarStream, error) {
func GetCacheTars(ctx context.Context, tx pgx.Tx) (cacheTarStream, CloseFunc, error) {
var version int64

err := tx.QueryRow(ctx, `
Expand All @@ -309,10 +312,10 @@ func GetCacheTars(ctx context.Context, tx pgx.Tx) (cacheTarStream, error) {
LIMIT 1
`).Scan(&version)
if err == pgx.ErrNoRows {
return func() (int64, []byte, *Hash, error) { return 0, nil, nil, io.EOF }, nil
return func() (int64, []byte, *Hash, error) { return 0, nil, nil, io.EOF }, func(_ context.Context) {}, nil
}
if err != nil {
return nil, fmt.Errorf("GetCacheTars latest cache version: %w", err)
return nil, func(_ context.Context) {}, fmt.Errorf("GetCacheTars latest cache version: %w", err)
}

rows, err := tx.Query(ctx, `
Expand All @@ -326,11 +329,12 @@ func GetCacheTars(ctx context.Context, tx pgx.Tx) (cacheTarStream, error) {
JOIN dl.contents c
ON h.hash = c.hash
`, version)
closeFunc := func(_ context.Context) { rows.Close() }
if err != nil {
return nil, fmt.Errorf("GetCacheTars query: %w", err)
return nil, closeFunc, fmt.Errorf("GetCacheTars query: %w", err)
}

stream := func() (int64, []byte, *Hash, error) {
return func() (int64, []byte, *Hash, error) {
if !rows.Next() {
return 0, nil, nil, io.EOF
}
Expand All @@ -344,9 +348,7 @@ func GetCacheTars(ctx context.Context, tx pgx.Tx) (cacheTarStream, error) {
}

return version, encoded, &hash, nil
}

return stream, nil
}, closeFunc, nil
}

type PackManager struct {
Expand Down
12 changes: 8 additions & 4 deletions pkg/api/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ func (f *Fs) Get(req *pb.GetRequest, stream pb.Fs_GetServer) error {
key.QueryIgnores.Field(query.Ignores),
)

objects, err := db.GetObjects(ctx, tx, packManager, req.Project, vrange, query)
objects, closeFunc, err := db.GetObjects(ctx, tx, packManager, req.Project, vrange, query)
defer closeFunc(ctx)
if err != nil {
return status.Errorf(codes.Internal, "FS get objects: %v", err)
}
Expand Down Expand Up @@ -365,7 +366,8 @@ func (f *Fs) GetCompress(req *pb.GetCompressRequest, stream pb.Fs_GetCompressSer
key.QueryIgnores.Field(query.Ignores),
)

tars, err := db.GetTars(ctx, tx, req.Project, req.AvailableCacheVersions, vrange, query)
tars, closeFunc, err := db.GetTars(ctx, tx, req.Project, req.AvailableCacheVersions, vrange, query)
defer closeFunc(ctx)
if err != nil {
return status.Errorf(codes.Internal, "FS get tars: %v", err)
}
Expand Down Expand Up @@ -588,7 +590,8 @@ func (f *Fs) Inspect(ctx context.Context, req *pb.InspectRequest) (*pb.InspectRe
Path: "",
IsPrefix: true,
}
objects, err := db.GetObjects(ctx, tx, packManager, req.Project, vrange, query)
objects, closeFunc, err := db.GetObjects(ctx, tx, packManager, req.Project, vrange, query)
defer closeFunc(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "FS get objects: %v", err)
}
Expand Down Expand Up @@ -868,7 +871,8 @@ func (f *Fs) GetCache(req *pb.GetCacheRequest, stream pb.Fs_GetCacheServer) erro

logger.Debug(ctx, "FS.GetCache[Init]")

tars, err := db.GetCacheTars(ctx, tx)
tars, closeFunc, err := db.GetCacheTars(ctx, tx)
defer closeFunc(ctx)
if err != nil {
return status.Errorf(codes.Internal, "FS get cached tars: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion test/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ func TestGetCacheWithMultipleVersions(t *testing.T) {
Path: "pack",
IsPrefix: true,
}
tars, err := db.GetTars(tc.Context(), tc.Connect(), 1, availableVersions, vrange, query)
tars, closeFunc, err := db.GetTars(tc.Context(), tc.Connect(), 1, availableVersions, vrange, query)
defer closeFunc(tc.Context())
require.NoError(t, err)

var paths []string
Expand Down

0 comments on commit bddb60b

Please sign in to comment.