diff options
| author | Petter Rasmussen | 2016-02-20 23:30:30 +0100 | 
|---|---|---|
| committer | Petter Rasmussen | 2016-02-20 23:39:12 +0100 | 
| commit | 28c4eb923fd01d892a17844328d0090830bcd229 (patch) | |
| tree | ead9f5722894f36c3b641df89a2226ce9094b4dd | |
| parent | a9e9da783481fcb8022eb52fb944cb9ee13997de (diff) | |
| download | gdrive-28c4eb923fd01d892a17844328d0090830bcd229.tar.bz2 | |
Wrap downloads in TimeoutReader
| -rw-r--r-- | drive/download.go | 7 | ||||
| -rw-r--r-- | drive/revision_download.go | 7 | ||||
| -rw-r--r-- | drive/sync_download.go | 12 | ||||
| -rw-r--r-- | drive/timeout_reader.go | 10 | 
4 files changed, 29 insertions, 7 deletions
| diff --git a/drive/download.go b/drive/download.go index 7d3dc8b..a33373f 100644 --- a/drive/download.go +++ b/drive/download.go @@ -75,7 +75,10 @@ func (self *Drive) downloadRecursive(args DownloadArgs) error {  }  func (self *Drive) downloadBinary(f *drive.File, args DownloadArgs) (int64, int64, error) { -    res, err := self.service.Files.Get(f.Id).Download() +    // Get timeout reader wrapper and context +    timeoutReaderWrapper, ctx := getTimeoutReaderWrapperContext() + +    res, err := self.service.Files.Get(f.Id).Context(ctx).Download()      if err != nil {          return 0, 0, fmt.Errorf("Failed to download file: %s", err)      } @@ -92,7 +95,7 @@ func (self *Drive) downloadBinary(f *drive.File, args DownloadArgs) (int64, int6      return self.saveFile(saveFileArgs{          out: args.Out, -        body: res.Body, +        body: timeoutReaderWrapper(res.Body),          contentLength: res.ContentLength,          fpath: fpath,          force: args.Force, diff --git a/drive/revision_download.go b/drive/revision_download.go index 9cc9d1d..039cd19 100644 --- a/drive/revision_download.go +++ b/drive/revision_download.go @@ -29,7 +29,10 @@ func (self *Drive) DownloadRevision(args DownloadRevisionArgs) (err error) {          return fmt.Errorf("Download is not supported for this file type")      } -    res, err := getRev.Download() +    // Get timeout reader wrapper and context +    timeoutReaderWrapper, ctx := getTimeoutReaderWrapperContext() + +    res, err := getRev.Context(ctx).Download()      if err != nil {          return fmt.Errorf("Failed to download file: %s", err)      } @@ -50,7 +53,7 @@ func (self *Drive) DownloadRevision(args DownloadRevisionArgs) (err error) {      bytes, rate, err := self.saveFile(saveFileArgs{          out: args.Out, -        body: res.Body, +        body: timeoutReaderWrapper(res.Body),          contentLength: res.ContentLength,          fpath: fpath,          force: args.Force, diff --git a/drive/sync_download.go b/drive/sync_download.go index fb7b3ae..5016cc1 100644 --- a/drive/sync_download.go +++ b/drive/sync_download.go @@ -187,7 +187,10 @@ func (self *Drive) downloadRemoteFile(id, fpath string, args DownloadSyncArgs, t          return nil      } -    res, err := self.service.Files.Get(id).Download() +    // Get timeout reader wrapper and context +    timeoutReaderWrapper, ctx := getTimeoutReaderWrapperContext() + +    res, err := self.service.Files.Get(id).Context(ctx).Download()      if err != nil {          if isBackendError(err) && try < MaxBackendErrorRetries {              exponentialBackoffSleep(try) @@ -202,7 +205,10 @@ func (self *Drive) downloadRemoteFile(id, fpath string, args DownloadSyncArgs, t      defer res.Body.Close()      // Wrap response body in progress reader -    srcReader := getProgressReader(res.Body, args.Progress, res.ContentLength) +    progressReader := getProgressReader(res.Body, args.Progress, res.ContentLength) + +    // Wrap reader in timeout reader +    reader := timeoutReaderWrapper(progressReader)      // Ensure any parent directories exists      if err = mkdir(fpath); err != nil { @@ -219,7 +225,7 @@ func (self *Drive) downloadRemoteFile(id, fpath string, args DownloadSyncArgs, t      }      // Save file to disk -    _, err = io.Copy(outFile, srcReader) +    _, err = io.Copy(outFile, reader)      if err != nil {          outFile.Close()          if try < MaxBackendErrorRetries { diff --git a/drive/timeout_reader.go b/drive/timeout_reader.go index ba2bb83..878911b 100644 --- a/drive/timeout_reader.go +++ b/drive/timeout_reader.go @@ -10,6 +10,16 @@ import (  const MaxIdleTimeout = time.Second * 120  const TimeoutTimerInterval = time.Second * 10 +type timeoutReaderWrapper func(io.Reader) io.Reader + +func getTimeoutReaderWrapperContext() (timeoutReaderWrapper, context.Context) { +    ctx, cancel := context.WithCancel(context.TODO()) +    wrapper := func(r io.Reader) io.Reader { +         return getTimeoutReader(r, cancel) +    } +    return wrapper, ctx +} +  func getTimeoutReaderContext(r io.Reader) (io.Reader, context.Context) {      ctx, cancel := context.WithCancel(context.TODO())      return getTimeoutReader(r, cancel), ctx | 
