aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPetter Rasmussen2016-02-20 23:30:30 +0100
committerPetter Rasmussen2016-02-20 23:39:12 +0100
commit28c4eb923fd01d892a17844328d0090830bcd229 (patch)
treeead9f5722894f36c3b641df89a2226ce9094b4dd
parenta9e9da783481fcb8022eb52fb944cb9ee13997de (diff)
downloadgdrive-28c4eb923fd01d892a17844328d0090830bcd229.tar.bz2
Wrap downloads in TimeoutReader
-rw-r--r--drive/download.go7
-rw-r--r--drive/revision_download.go7
-rw-r--r--drive/sync_download.go12
-rw-r--r--drive/timeout_reader.go10
4 files changed, 29 insertions, 7 deletions
diff --git a/drive/download.go b/drive/download.go
index 7d3dc8b..a33373f 100644
--- a/drive/download.go
+++ b/drive/download.go
@@ -75,7 +75,10 @@ func (self *Drive) downloadRecursive(args DownloadArgs) error {
}
func (self *Drive) downloadBinary(f *drive.File, args DownloadArgs) (int64, int64, error) {
- res, err := self.service.Files.Get(f.Id).Download()
+ // Get timeout reader wrapper and context
+ timeoutReaderWrapper, ctx := getTimeoutReaderWrapperContext()
+
+ res, err := self.service.Files.Get(f.Id).Context(ctx).Download()
if err != nil {
return 0, 0, fmt.Errorf("Failed to download file: %s", err)
}
@@ -92,7 +95,7 @@ func (self *Drive) downloadBinary(f *drive.File, args DownloadArgs) (int64, int6
return self.saveFile(saveFileArgs{
out: args.Out,
- body: res.Body,
+ body: timeoutReaderWrapper(res.Body),
contentLength: res.ContentLength,
fpath: fpath,
force: args.Force,
diff --git a/drive/revision_download.go b/drive/revision_download.go
index 9cc9d1d..039cd19 100644
--- a/drive/revision_download.go
+++ b/drive/revision_download.go
@@ -29,7 +29,10 @@ func (self *Drive) DownloadRevision(args DownloadRevisionArgs) (err error) {
return fmt.Errorf("Download is not supported for this file type")
}
- res, err := getRev.Download()
+ // Get timeout reader wrapper and context
+ timeoutReaderWrapper, ctx := getTimeoutReaderWrapperContext()
+
+ res, err := getRev.Context(ctx).Download()
if err != nil {
return fmt.Errorf("Failed to download file: %s", err)
}
@@ -50,7 +53,7 @@ func (self *Drive) DownloadRevision(args DownloadRevisionArgs) (err error) {
bytes, rate, err := self.saveFile(saveFileArgs{
out: args.Out,
- body: res.Body,
+ body: timeoutReaderWrapper(res.Body),
contentLength: res.ContentLength,
fpath: fpath,
force: args.Force,
diff --git a/drive/sync_download.go b/drive/sync_download.go
index fb7b3ae..5016cc1 100644
--- a/drive/sync_download.go
+++ b/drive/sync_download.go
@@ -187,7 +187,10 @@ func (self *Drive) downloadRemoteFile(id, fpath string, args DownloadSyncArgs, t
return nil
}
- res, err := self.service.Files.Get(id).Download()
+ // Get timeout reader wrapper and context
+ timeoutReaderWrapper, ctx := getTimeoutReaderWrapperContext()
+
+ res, err := self.service.Files.Get(id).Context(ctx).Download()
if err != nil {
if isBackendError(err) && try < MaxBackendErrorRetries {
exponentialBackoffSleep(try)
@@ -202,7 +205,10 @@ func (self *Drive) downloadRemoteFile(id, fpath string, args DownloadSyncArgs, t
defer res.Body.Close()
// Wrap response body in progress reader
- srcReader := getProgressReader(res.Body, args.Progress, res.ContentLength)
+ progressReader := getProgressReader(res.Body, args.Progress, res.ContentLength)
+
+ // Wrap reader in timeout reader
+ reader := timeoutReaderWrapper(progressReader)
// Ensure any parent directories exists
if err = mkdir(fpath); err != nil {
@@ -219,7 +225,7 @@ func (self *Drive) downloadRemoteFile(id, fpath string, args DownloadSyncArgs, t
}
// Save file to disk
- _, err = io.Copy(outFile, srcReader)
+ _, err = io.Copy(outFile, reader)
if err != nil {
outFile.Close()
if try < MaxBackendErrorRetries {
diff --git a/drive/timeout_reader.go b/drive/timeout_reader.go
index ba2bb83..878911b 100644
--- a/drive/timeout_reader.go
+++ b/drive/timeout_reader.go
@@ -10,6 +10,16 @@ import (
const MaxIdleTimeout = time.Second * 120
const TimeoutTimerInterval = time.Second * 10
+type timeoutReaderWrapper func(io.Reader) io.Reader
+
+func getTimeoutReaderWrapperContext() (timeoutReaderWrapper, context.Context) {
+ ctx, cancel := context.WithCancel(context.TODO())
+ wrapper := func(r io.Reader) io.Reader {
+ return getTimeoutReader(r, cancel)
+ }
+ return wrapper, ctx
+}
+
func getTimeoutReaderContext(r io.Reader) (io.Reader, context.Context) {
ctx, cancel := context.WithCancel(context.TODO())
return getTimeoutReader(r, cancel), ctx