Add multipart download (#4)
* Add multipart download

* actually exit fast on non-retryable errors

* Add multipart progress reporting
gartnera authored Jan 6, 2025
1 parent b7033e1 commit dd55028
Showing 6 changed files with 158 additions and 46 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,19 @@
+name: ci
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Go
+        uses: actions/setup-go@v5
+
+      - name: Run tests
+        run: go test -v -race ./...
8 changes: 8 additions & 0 deletions README.md
@@ -5,6 +5,14 @@
 dl-pipe https://example.invalid/my-file.tar | tar x
 ```
 
+You may also pass the parts of a multipart tar file, and they will be reassembled:
+
+```
+dl-pipe https://example.invalid/my-file.tar.part1 https://example.invalid/my-file.tar.part2 https://example.invalid/my-file.tar.part3 | tar x
+```
+
+We use this to work around the 5 TB object size limit of most object storage providers.
+
 You can also provide an expected hash via the `-hash` option to verify that the downloaded content is correct. Make sure you set `set -eo pipefail` so that your script stops on errors.
 
 Install with `go install github.com/zeta-chain/dl-pipe@latest`.
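
The README flow maps directly onto the Go API introduced later in this diff. A minimal library-side sketch, assuming the module path from the install line above; the URLs and digest are placeholders:

```go
package main

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"os"

	dlpipe "github.com/zeta-chain/dl-pipe"
)

func main() {
	// Parts are fetched sequentially and streamed to stdout in order,
	// so the output can be piped onward exactly like the CLI example.
	urls := []string{
		"https://example.invalid/my-file.tar.part1",
		"https://example.invalid/my-file.tar.part2",
	}

	// Placeholder value: SHA-256 of the fully reassembled file.
	expected, err := hex.DecodeString(
		"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
	if err != nil {
		panic(err)
	}

	if err := dlpipe.DownloadURLMultipart(
		context.Background(),
		urls,
		os.Stdout,
		dlpipe.WithExpectedHash(sha256.New(), expected),
	); err != nil {
		os.Exit(1)
	}
}
```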
19 changes: 12 additions & 7 deletions cmd/dl-pipe/main.go
@@ -71,7 +71,7 @@ const progressFuncInterval = time.Second * 10
 
 func getProgressFunc() dlpipe.ProgressFunc {
     prevLength := uint64(0)
-    return func(currentLength uint64, totalLength uint64) {
+    return func(currentLength uint64, totalLength uint64, currentPart int, totalParts int) {
         currentLengthStr := humanize.Bytes(currentLength)
         totalLengthStr := humanize.Bytes(totalLength)
 
@@ -81,7 +81,12 @@ func getProgressFunc() dlpipe.ProgressFunc {
 
         percent := float64(currentLength) / float64(totalLength) * 100
 
-        fmt.Fprintf(os.Stderr, "Downloaded %s of %s (%.1f%%) at %s/s\n", currentLengthStr, totalLengthStr, percent, rateStr)
+        partStr := ""
+        if totalParts > 1 {
+            partStr = fmt.Sprintf(" (part %d of %d)", currentPart+1, totalParts)
+        }
+
+        fmt.Fprintf(os.Stderr, "Downloaded %s of %s (%.1f%%) at %s/s%s\n", currentLengthStr, totalLengthStr, percent, rateStr, partStr)
     }
 }

@@ -101,9 +106,9 @@ func main() {
     flag.BoolVar(&progress, "progress", false, "Show download progress")
     flag.Parse()
 
-    url := flag.Arg(0)
-    if url == "" {
-        fmt.Fprintf(os.Stderr, ("URL is required"))
+    urls := flag.Args()
+    if len(urls) == 0 {
+        fmt.Fprintf(os.Stderr, ("URL(s) are required"))
         os.Exit(1)
     }

@@ -119,9 +124,9 @@ func main() {
         headerMap[parts[0]] = parts[1]
     }
 
-    err := dlpipe.DownloadURL(
+    err := dlpipe.DownloadURLMultipart(
         ctx,
-        url,
+        urls,
         os.Stdout,
         dlpipe.WithHeaders(headerMap),
         getHashOpt(hash),
59 changes: 45 additions & 14 deletions download.go
@@ -3,11 +3,13 @@ package dlpipe
 
 import (
     "bytes"
     "context"
+    "errors"
     "fmt"
     "hash"
     "io"
     "net/http"
     "strings"
+    "sync"
     "time"
 
     "github.com/miolini/datacounter"
@@ -57,7 +59,7 @@ func WithHeaders(headers map[string]string) DownloadOpt {
     }
 }
 
-type ProgressFunc func(currentLength uint64, totalLength uint64)
+type ProgressFunc func(currentLength, totalLength uint64, currentPart, totalParts int)
 
 func WithProgressFunc(progressFunc ProgressFunc, interval time.Duration) DownloadOpt {
     return func(d *downloader) {
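
Downstream callers wire a reporter in through this option. A minimal sketch of a callback compatible with the new four-argument signature, assuming the same module path as above; the 5-second interval and log format are arbitrary choices:

```go
package main

import (
	"log"
	"time"

	dlpipe "github.com/zeta-chain/dl-pipe"
)

// progressOpt builds a DownloadOpt that logs per-part progress.
// currentLength/totalLength describe the part currently being fetched;
// currentPart is zero-based, so it is offset by one for display.
func progressOpt() dlpipe.DownloadOpt {
	callback := func(currentLength, totalLength uint64, currentPart, totalParts int) {
		log.Printf("part %d of %d: %d/%d bytes",
			currentPart+1, totalParts, currentLength, totalLength)
	}
	return dlpipe.WithProgressFunc(callback, 5*time.Second)
}
```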
@@ -115,7 +117,7 @@ func DefaultRetryParameters() RetryParameters {
 
 type downloader struct {
     // these fields are set once
-    url             string
+    urls            []string
     writer          *datacounter.WriterCounter
     httpClient      *http.Client
     retryParameters RetryParameters
@@ -129,6 +131,9 @@ type downloader struct {
 
     // these fields are updated at runtime
     contentLength int64
+    urlsPosition  int
+
+    sync.RWMutex
 }
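
The mutex is embedded rather than held in a named field, which is what lets the methods below call d.Lock() and d.RLock() directly: embedding promotes the mutex's methods onto the struct. A stripped-down sketch of the pattern:

```go
package main

import (
	"fmt"
	"sync"
)

// counter embeds sync.RWMutex the same way downloader does, so
// Lock/Unlock/RLock/RUnlock become methods of counter itself.
type counter struct {
	n int
	sync.RWMutex
}

func main() {
	c := &counter{}

	c.Lock() // writer side: guard the field update
	c.n++
	c.Unlock()

	c.RLock() // reader side: many readers may hold this concurrently
	fmt.Println(c.n)
	c.RUnlock()
}
```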

func (d *downloader) progressReportLoop(ctx context.Context) {
Expand All @@ -137,15 +142,19 @@ func (d *downloader) progressReportLoop(ctx context.Context) {
for {
select {
case <-t.C:
d.progressFunc(d.writer.Count(), uint64(d.contentLength))
d.RLock()
d.progressFunc(d.writer.Count(), uint64(d.contentLength), d.urlsPosition, d.totalPartCount())
d.RUnlock()
case <-ctx.Done():
return
}
}
}

func (d *downloader) runInner(ctx context.Context) (io.ReadCloser, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil)
d.RLock()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.urls[d.urlsPosition], nil)
d.RUnlock()
if err != nil {
return nil, NonRetryableWrapf("create request: %w", err)
}
@@ -176,7 +185,9 @@ func (d *downloader) runInner(ctx context.Context) (io.ReadCloser, error) {
     }
 
     if resp.StatusCode != http.StatusPartialContent {
-        return nil, NonRetryableWrapf("unexpected status code on subsequent read: %d", resp.StatusCode)
+        // this error should be retried since cloudflare r2 sometimes ignores the range request and
+        // returns 200
+        return nil, fmt.Errorf("unexpected status code on subsequent read: %d", resp.StatusCode)
     }
 
     // Validate we are receiving the right portion of partial content
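
For context on the comment above: a resumed read sends an HTTP Range header and normally expects 206 Partial Content back; a server that ignores the header (as Cloudflare R2 sometimes does here) answers 200 with the whole object from byte zero, which is why this is now a retryable error rather than a fatal one. A minimal sketch of that handshake, with a placeholder URL and offset:

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	req, err := http.NewRequest(http.MethodGet, "https://example.invalid/my-file.tar", nil)
	if err != nil {
		panic(err)
	}
	// Ask the server to resume from byte 1024 onward.
	req.Header.Set("Range", "bytes=1024-")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusPartialContent: // 206: the server honored the range
		fmt.Println("resuming from the requested offset")
	case http.StatusOK: // 200: range ignored; body restarts at byte 0
		fmt.Println("server ignored Range; retry rather than fail")
	}
}
```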
@@ -212,15 +223,23 @@ func (d *downloader) run(ctx context.Context) error {
     if d.progressFunc != nil {
         go d.progressReportLoop(ctx)
     }
-    for {
+    d.resetWriterPosition()
+
+    for d.urlsPosition < d.totalPartCount() {
         body, err := d.runInner(ctx)
-        if err != nil {
-            return err
-        }
-        defer body.Close()
-        _, err = io.Copy(d.writer, body)
-        if err == nil {
-            break
+        if err == nil {
+            defer body.Close()
+            _, err = io.Copy(d.writer, body)
+            if err == nil {
+                d.Lock()
+                d.urlsPosition++
+                d.resetWriterPosition()
+                d.Unlock()
+                continue
+            }
         }
+        if errors.Is(err, ErrNonRetryable{}) {
+            return err
+        }
         err = d.retryParameters.Wait(ctx, d.writer.Count())
         if err != nil {
@@ -236,9 +255,22 @@ func (d *downloader) run(ctx context.Context) error {
     return nil
 }
 
+func (d *downloader) resetWriterPosition() {
+    d.writer = datacounter.NewWriterCounter(d.tmpWriter)
+    d.contentLength = 0
+}
+
+func (d *downloader) totalPartCount() int {
+    return len(d.urls)
+}
+
 func DownloadURL(ctx context.Context, url string, writer io.Writer, opts ...DownloadOpt) error {
+    return DownloadURLMultipart(ctx, []string{url}, writer, opts...)
+}
+
+func DownloadURLMultipart(ctx context.Context, urls []string, writer io.Writer, opts ...DownloadOpt) error {
     d := &downloader{
-        url:       url,
+        urls:      urls,
         tmpWriter: writer,
         httpClient: &http.Client{
             Transport: &http.Transport{
@@ -254,6 +286,5 @@ func DownloadURLMultipart(ctx context.Context, urls []string, writer io.Writer, opts ...DownloadOpt) error {
         }
         opt(d)
     }
-    d.writer = datacounter.NewWriterCounter(d.tmpWriter)
     return d.run(ctx)
 }
94 changes: 69 additions & 25 deletions download_test.go
@@ -4,12 +4,14 @@ import (
     "context"
     "crypto/rand"
     "crypto/sha256"
+    "errors"
     "fmt"
     "io"
     "log"
     "net/http"
     "net/http/httptest"
     "os"
+    "path/filepath"
    "sync"
    "testing"

@@ -26,12 +28,12 @@ func TestUninterruptedDownload(t *testing.T) {
     r := require.New(t)
     ctx := context.Background()
 
-    serverURL, expectedHash, cleanup := serveInterruptedTestFile(t, fileSize, 0)
+    serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 1)
     defer cleanup()
 
     hasher := sha256.New()
 
-    err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, expectedHash))
+    err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
     r.NoError(err)
 
     givenHash := hasher.Sum(nil)
@@ -42,54 +44,96 @@ func TestUninterruptedMismatch(t *testing.T) {
     r := require.New(t)
     ctx := context.Background()
 
-    serverURL, _, cleanup := serveInterruptedTestFile(t, fileSize, 0)
+    serverURLs, _, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 1)
     defer cleanup()
 
     hasher := sha256.New()
 
-    err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, []byte{}))
+    err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, []byte{}))
     r.Error(err)
 }
 
 func TestInterruptedDownload(t *testing.T) {
     r := require.New(t)
     ctx := context.Background()
 
-    serverURL, expectedHash, cleanup := serveInterruptedTestFile(t, fileSize, interruptAt)
+    serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, interruptAt, 1)
     defer cleanup()
 
     hasher := sha256.New()
 
-    err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, expectedHash))
+    err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
     r.NoError(err)
 }
 
-// derrived from https://github.com/vansante/go-dl-stream/blob/e29aef86498f37d3506126bc258193f1c913ea55/download_test.go#L166
-func serveInterruptedTestFile(t *testing.T, fileSize, interruptAt int64) (serverURL string, sha256Hash []byte, cleanup func()) {
-    rndFile, err := os.CreateTemp(os.TempDir(), "random_file_*.rnd")
-    assert.NoError(t, err)
-    filePath := rndFile.Name()
+func TestDownloadMultipart(t *testing.T) {
+    r := require.New(t)
+    ctx := context.Background()
+
+    serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 10)
+    defer cleanup()
 
-    _, err = io.Copy(io.MultiWriter(hasher, rndFile), io.LimitReader(rand.Reader, fileSize))
-    assert.NoError(t, err)
-    assert.NoError(t, rndFile.Close())
-
-    mux := http.NewServeMux()
-    mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
-        log.Printf("Serving random interrupted file (size: %d, interuptAt: %d), Range: %s", fileSize, interruptAt, request.Header.Get(rangeHeader))
+    hasher := sha256.New()
+
+    err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
+    r.NoError(err)
+}
+
+func TestDownloadMultipartInterrupted(t *testing.T) {
+    r := require.New(t)
+    ctx := context.Background()
+
+    serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, interruptAt, 10)
+    defer cleanup()
+
+    hasher := sha256.New()
 
-    http.ServeFile(&interruptibleHTTPWriter{
-        ResponseWriter: writer,
-        writer:         writer,
-        interruptAt:    interruptAt,
-    }, request, filePath)
+    err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
+    r.NoError(err)
+}
+
+func TestErrNonRetryable(t *testing.T) {
+    err := NonRetryableWrapf("test")
+    require.True(t, errors.Is(err, ErrNonRetryable{}))
+}
 
-    })
+// derrived from https://github.com/vansante/go-dl-stream/blob/e29aef86498f37d3506126bc258193f1c913ea55/download_test.go#L166
+func serveInterruptedTestFiles(t *testing.T, fileSize, interruptAt int64, parts int) ([]string, []byte, func()) {
+    mux := http.NewServeMux()
+    server := httptest.NewServer(mux)
+    hasher := sha256.New()
+    filePaths := []string{}
+    urls := []string{}
+
+    for i := 0; i < parts; i++ {
+        rndFile, err := os.CreateTemp(os.TempDir(), "random_file_*.rnd")
+        assert.NoError(t, err)
+        filePath := rndFile.Name()
+        filePaths = append(filePaths, filePath)
+        filePathBase := filepath.Base(filePath)
 
-    return server.URL, hasher.Sum(nil), func() {
-        _ = os.Remove(filePath)
+        _, err = io.Copy(io.MultiWriter(hasher, rndFile), io.LimitReader(rand.Reader, fileSize))
+        assert.NoError(t, err)
+        assert.NoError(t, rndFile.Close())
+
+        mux.HandleFunc(filePath, func(writer http.ResponseWriter, request *http.Request) {
+            log.Printf("Serving random interrupted file %s (size: %d, interuptAt: %d), Range: %s", filePathBase, fileSize, interruptAt, request.Header.Get(rangeHeader))
+
+            http.ServeFile(&interruptibleHTTPWriter{
+                ResponseWriter: writer,
+                writer:         writer,
+                interruptAt:    interruptAt,
+            }, request, filePath)
+
+        })
+        urls = append(urls, server.URL+filePath)
+    }
+
+    return urls, hasher.Sum(nil), func() {
+        for _, path := range filePaths {
+            _ = os.Remove(path)
+        }
     }
 }

5 changes: 5 additions & 0 deletions errors.go
@@ -28,6 +28,11 @@ func (e ErrNonRetryable) Unwrap() error {
     return e.inner
 }
 
+func (e ErrNonRetryable) Is(target error) bool {
+    _, ok := target.(ErrNonRetryable)
+    return ok
+}
+
 func NonRetryableWrap(err error) error {
     return ErrNonRetryable{inner: err}
 }
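
Without this method, errors.Is would fall back to comparing values, and a wrapped ErrNonRetryable{inner: err} never equals the zero-value ErrNonRetryable{} used as the target in download.go; the custom Is makes matching purely type-based. A self-contained sketch of the behavior, re-creating the type locally to stay runnable:

```go
package main

import (
	"errors"
	"fmt"
)

type errNonRetryable struct{ inner error }

func (e errNonRetryable) Error() string { return "non-retryable: " + e.inner.Error() }
func (e errNonRetryable) Unwrap() error { return e.inner }

// Is matches on type alone, ignoring the wrapped inner error.
func (e errNonRetryable) Is(target error) bool {
	_, ok := target.(errNonRetryable)
	return ok
}

func main() {
	err := fmt.Errorf("request: %w", errNonRetryable{inner: errors.New("bad URL")})

	// Prints true: errors.Is unwraps err to the errNonRetryable value and
	// calls its Is method with the zero-value target, which matches by type.
	fmt.Println(errors.Is(err, errNonRetryable{}))
}
```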
