aboutsummaryrefslogtreecommitdiffstats
path: root/drive/download.go
blob: 3ed73df2e1c3de12a05f31dffe8cbe943185bd8a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
package drive

import (
    "fmt"
    "io"
    "os"
    "time"
    "path/filepath"
    "google.golang.org/api/drive/v3"
    "google.golang.org/api/googleapi"
)

// DownloadArgs holds the options accepted by Drive.Download.
type DownloadArgs struct {
    // Out receives status messages, or the file content itself when Stdout
    // is set; Progress is handed to the progress reader wrapping the body.
    Out, Progress            io.Writer
    // Id is the Drive file/folder id to fetch; Path is the local directory
    // the item is saved under (the item's name is joined onto it).
    Id, Path                 string
    // Force overwrites existing local files, Recursive allows folders to be
    // downloaded, and Stdout writes content to Out instead of to disk.
    Force, Recursive, Stdout bool
}

// Download fetches a single file from Drive and saves it under args.Path, or
// writes its content to args.Out when args.Stdout is set. When args.Recursive
// is set the whole request is delegated to downloadRecursive, which also
// handles folders.
//
// Without Recursive, folders are rejected with an error, and so are files
// that have no md5 checksum (Google documents, which must be exported with
// the export command instead). After a successful download a summary line
// with the average rate and total size is written to args.Out, unless the
// content itself was sent there (Stdout mode).
func (self *Drive) Download(args DownloadArgs) error {
    if args.Recursive {
        return self.downloadRecursive(args)
    }

    f, err := self.service.Files.Get(args.Id).Fields("id", "name", "size", "mimeType", "md5Checksum").Do()
    if err != nil {
        return fmt.Errorf("Failed to get file: %s", err)
    }

    if isDir(f) {
        return fmt.Errorf("'%s' is a directory, use --recursive to download directories", f.Name)
    }

    if !isBinary(f) {
        return fmt.Errorf("'%s' is a google document and must be exported, see the export command", f.Name)
    }

    bytes, rate, err := self.downloadBinary(f, args)
    if err != nil {
        // Bail out before printing the summary: a failed download must not
        // be reported as "Downloaded ... total 0 B".
        return err
    }

    if !args.Stdout {
        fmt.Fprintf(args.Out, "Downloaded %s at %s/s, total %s\n", f.Id, formatSize(rate, false), formatSize(bytes, false))
    }
    return nil
}

// downloadRecursive downloads the item identified by args.Id. Folders are
// walked with downloadDirectory and files with an md5 checksum are saved with
// downloadBinary. Anything else (files without a checksum, i.e. Google
// documents) is skipped without an error.
func (self *Drive) downloadRecursive(args DownloadArgs) error {
    item, err := self.service.Files.Get(args.Id).Fields("id", "name", "size", "mimeType", "md5Checksum").Do()
    if err != nil {
        return fmt.Errorf("Failed to get file: %s", err)
    }

    switch {
    case isDir(item):
        return self.downloadDirectory(item, args)
    case isBinary(item):
        // The byte count and rate are only used for Download's summary line.
        _, _, dlErr := self.downloadBinary(item, args)
        return dlErr
    default:
        return nil
    }
}

// downloadBinary requests the content of f and passes the response body to
// saveFile. saveFile writes it either to args.Out (Stdout mode) or to
// <args.Path>/<f.Name>. The byte count and rate returned are whatever
// saveFile reports (both 0 in Stdout mode).
func (self *Drive) downloadBinary(f *drive.File, args DownloadArgs) (int64, int64, error) {
    res, err := self.service.Files.Get(f.Id).Download()
    if err != nil {
        return 0, 0, fmt.Errorf("Failed to download file: %s", err)
    }
    body := res.Body
    defer body.Close()

    dest := filepath.Join(args.Path, f.Name)
    // Status lines would corrupt the content when it is written to stdout.
    if !args.Stdout {
        fmt.Fprintf(args.Out, "Downloading %s -> %s\n", f.Name, dest)
    }

    save := saveFileArgs{
        out:           args.Out,
        progress:      args.Progress,
        body:          body,
        contentLength: res.ContentLength,
        fpath:         dest,
        force:         args.Force,
        stdout:        args.Stdout,
    }
    return self.saveFile(save)
}

// saveFileArgs bundles the inputs of Drive.saveFile.
type saveFileArgs struct {
    // out receives the content in stdout mode.
    out           io.Writer
    // body is the content to save.
    body          io.Reader
    // contentLength is passed, together with progress, to the progress reader.
    contentLength int64
    // fpath is the final on-disk path of the file.
    fpath         string
    // force allows overwriting an existing fpath; stdout sends body to out.
    force, stdout bool
    progress      io.Writer
}

// saveFile copies args.body, wrapped in a progress reader, either to args.out
// (when args.stdout is set) or to the file args.fpath.
//
// In file mode, an existing args.fpath is an error unless args.force is set.
// The content is first written to "<fpath>.incomplete" and is renamed to
// fpath only after both the copy and the Close have succeeded, so a failed
// download never replaces fpath with a truncated file.
//
// It returns the number of bytes written and the average rate computed by
// calcRate. Both are 0 in stdout mode and whenever an error is returned.
func (self *Drive) saveFile(args saveFileArgs) (int64, int64, error) {
    // Wrap response body in progress reader
    srcReader := getProgressReader(args.body, args.progress, args.contentLength)

    if args.stdout {
        // Write file content to stdout
        _, err := io.Copy(args.out, srcReader)
        return 0, 0, err
    }

    // Refuse to clobber an existing file unless forced
    if !args.force && fileExists(args.fpath) {
        return 0, 0, fmt.Errorf("File '%s' already exists, use --force to overwrite", args.fpath)
    }

    // Ensure any parent directories exist
    if err := mkdir(args.fpath); err != nil {
        return 0, 0, err
    }

    // Download to tmp file
    tmpPath := args.fpath + ".incomplete"

    outFile, err := os.Create(tmpPath)
    if err != nil {
        return 0, 0, fmt.Errorf("Unable to create new file: %s", err)
    }

    started := time.Now()

    // Save file to disk
    bytes, err := io.Copy(outFile, srcReader)
    if err != nil {
        outFile.Close()
        os.Remove(tmpPath)
        return 0, 0, fmt.Errorf("Failed saving file: %s", err)
    }

    // Calculate average download rate
    rate := calcRate(bytes, started, time.Now())

    // Close can report write errors that were deferred by the OS (e.g. a full
    // disk). Until it succeeds the temp file must not be published as fpath.
    if err := outFile.Close(); err != nil {
        os.Remove(tmpPath)
        return 0, 0, fmt.Errorf("Failed saving file: %s", err)
    }

    // Rename tmp file to proper filename
    if err := os.Rename(tmpPath, args.fpath); err != nil {
        return 0, 0, err
    }
    return bytes, rate, nil
}

// downloadDirectory lists the direct children of the folder parent and
// downloads each one into <args.Path>/<parent.Name> via downloadRecursive,
// which recurses into sub-folders. Stdout mode is switched off for the
// children, so their content is saved to disk rather than written to
// args.Out. It stops at, and returns, the first child error.
func (self *Drive) downloadDirectory(parent *drive.File, args DownloadArgs) error {
    children, err := self.listAllFiles(listAllFilesArgs{
        query:  fmt.Sprintf("'%s' in parents", parent.Id),
        fields: []googleapi.Field{"nextPageToken", "files(id,name)"},
    })
    if err != nil {
        return fmt.Errorf("Failed listing files: %s", err)
    }

    // Only Id differs per child; downloadRecursive receives a copy each call.
    childArgs := args
    childArgs.Path = filepath.Join(args.Path, parent.Name)
    childArgs.Stdout = false

    for _, child := range children {
        childArgs.Id = child.Id
        if err := self.downloadRecursive(childArgs); err != nil {
            return err
        }
    }
    return nil
}

// isDir reports whether f is a Drive folder, i.e. whether its MIME type is
// DirectoryMimeType.
func isDir(f *drive.File) bool {
    switch f.MimeType {
    case DirectoryMimeType:
        return true
    }
    return false
}

// isBinary reports whether f has downloadable content, judged by a non-empty
// md5Checksum. Callers treat files without one as Google documents, which
// must be exported rather than downloaded.
func isBinary(f *drive.File) bool {
    hasChecksum := len(f.Md5Checksum) > 0
    return hasChecksum
}