|
| 1 | +package storage |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "crypto/sha256" |
| 7 | + "encoding/hex" |
| 8 | + "encoding/json" |
| 9 | + "fmt" |
| 10 | + "io" |
| 11 | + "net" |
| 12 | + "net/http" |
| 13 | + "strings" |
| 14 | + "time" |
| 15 | + |
| 16 | + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" |
| 17 | +) |
| 18 | + |
| 19 | +const ( |
| 20 | + // BundleFormatTar is the tar archive format for bundle responses. |
| 21 | + BundleFormatTar = "tar" |
| 22 | + |
| 23 | + // BundleCompressionNone disables compression (default). |
| 24 | + BundleCompressionNone = "none" |
| 25 | + // BundleCompressionGzip enables gzip compression. |
| 26 | + BundleCompressionGzip = "gzip" |
| 27 | + // BundleCompressionZstd enables zstd compression. |
| 28 | + BundleCompressionZstd = "zstd" |
| 29 | + |
| 30 | + // BundleOnErrorSkip silently omits missing objects from the archive (default). |
| 31 | + BundleOnErrorSkip = "skip" |
| 32 | + // BundleOnErrorFail returns an error if any object is missing. |
| 33 | + BundleOnErrorFail = "fail" |
| 34 | +) |
| 35 | + |
| 36 | +// bundleHTTPClient is reused across calls. No overall timeout — the caller's |
| 37 | +// context controls cancellation, which avoids cutting off streaming reads. |
| 38 | +var bundleHTTPClient = &http.Client{ |
| 39 | + Transport: &http.Transport{ |
| 40 | + Proxy: http.ProxyFromEnvironment, |
| 41 | + DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, |
| 42 | + TLSHandshakeTimeout: 10 * time.Second, |
| 43 | + ResponseHeaderTimeout: 60 * time.Second, |
| 44 | + }, |
| 45 | +} |
| 46 | + |
| 47 | +// BundleObjectsInput is the input for a BundleObjects request. |
| 48 | +type BundleObjectsInput struct { |
| 49 | + // Bucket is the name of the bucket containing the objects. Required. |
| 50 | + Bucket string |
| 51 | + |
| 52 | + // Keys is the list of object keys to include in the bundle. Required. |
| 53 | + // Maximum 5,000 keys per request. |
| 54 | + Keys []string |
| 55 | + |
| 56 | + // Compression sets the compression algorithm for the response. |
| 57 | + // Valid values: "none" (default), "gzip", "zstd". |
| 58 | + Compression string |
| 59 | + |
| 60 | + // OnError controls behavior when objects are missing. |
| 61 | + // "skip" (default): omit missing objects and append __bundle_errors.json to the tar. |
| 62 | + // "fail": return an error before streaming if any object is missing. |
| 63 | + OnError string |
| 64 | +} |
| 65 | + |
| 66 | +// BundleObjectsOutput is the response from a BundleObjects request. |
| 67 | +// |
| 68 | +// The Body contains a streaming tar archive. Callers are responsible for closing Body. |
| 69 | +// Use archive/tar to iterate entries: |
| 70 | +// |
| 71 | +// tr := tar.NewReader(output.Body) |
| 72 | +// for { |
| 73 | +// hdr, err := tr.Next() |
| 74 | +// if err == io.EOF { break } |
| 75 | +// // process hdr.Name, tr |
| 76 | +// } |
| 77 | +// |
| 78 | +// If compression was requested, wrap Body with the appropriate decompressor first: |
| 79 | +// |
| 80 | +// gz, _ := gzip.NewReader(output.Body) |
| 81 | +// tr := tar.NewReader(gz) |
| 82 | +type BundleObjectsOutput struct { |
| 83 | + // Body is the streaming tar archive. Must be closed by the caller. |
| 84 | + Body io.ReadCloser |
| 85 | + |
| 86 | + // ContentType is the response Content-Type (e.g. "application/x-tar", "application/gzip"). |
| 87 | + ContentType string |
| 88 | + |
| 89 | + // StatusCode is the HTTP status code of the response. |
| 90 | + StatusCode int |
| 91 | +} |
| 92 | + |
| 93 | +type bundleRequestBody struct { |
| 94 | + Keys []string `json:"keys"` |
| 95 | +} |
| 96 | + |
| 97 | +// BundleObjects fetches multiple objects from a bucket as a streaming tar archive |
| 98 | +// in a single HTTP request. |
| 99 | +// |
| 100 | +// This is a Tigris extension to the S3 API, designed for ML training workloads |
| 101 | +// that need to fetch thousands of objects per batch without per-object HTTP overhead. |
| 102 | +// |
| 103 | +// The caller is responsible for closing the returned Body. |
| 104 | +func (c *Client) BundleObjects(ctx context.Context, in *BundleObjectsInput) (*BundleObjectsOutput, error) { |
| 105 | + if in.Bucket == "" { |
| 106 | + return nil, fmt.Errorf("storage: BundleObjects: bucket is required") |
| 107 | + } |
| 108 | + if len(in.Keys) == 0 { |
| 109 | + return nil, fmt.Errorf("storage: BundleObjects: at least one key is required") |
| 110 | + } |
| 111 | + |
| 112 | + compression := in.Compression |
| 113 | + if compression == "" { |
| 114 | + compression = BundleCompressionNone |
| 115 | + } |
| 116 | + |
| 117 | + onError := in.OnError |
| 118 | + if onError == "" { |
| 119 | + onError = BundleOnErrorSkip |
| 120 | + } |
| 121 | + |
| 122 | + opts := c.Client.Options() |
| 123 | + |
| 124 | + endpoint := GlobalEndpoint |
| 125 | + if opts.BaseEndpoint != nil { |
| 126 | + endpoint = *opts.BaseEndpoint |
| 127 | + } |
| 128 | + endpoint = strings.TrimRight(endpoint, "/") |
| 129 | + |
| 130 | + reqURL := fmt.Sprintf("%s/%s?bundle", endpoint, in.Bucket) |
| 131 | + |
| 132 | + body, err := json.Marshal(bundleRequestBody{Keys: in.Keys}) |
| 133 | + if err != nil { |
| 134 | + return nil, fmt.Errorf("storage: BundleObjects: failed to marshal keys: %w", err) |
| 135 | + } |
| 136 | + |
| 137 | + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(body)) |
| 138 | + if err != nil { |
| 139 | + return nil, fmt.Errorf("storage: BundleObjects: failed to create request: %w", err) |
| 140 | + } |
| 141 | + |
| 142 | + req.Header.Set("Content-Type", "application/json") |
| 143 | + req.Header.Set("X-Tigris-Bundle-Format", BundleFormatTar) |
| 144 | + req.Header.Set("X-Tigris-Bundle-Compression", compression) |
| 145 | + req.Header.Set("X-Tigris-Bundle-On-Error", onError) |
| 146 | + |
| 147 | + // Sign request with SigV4. |
| 148 | + if opts.Credentials != nil { |
| 149 | + creds, err := opts.Credentials.Retrieve(ctx) |
| 150 | + if err != nil { |
| 151 | + return nil, fmt.Errorf("storage: BundleObjects: failed to retrieve credentials: %w", err) |
| 152 | + } |
| 153 | + |
| 154 | + payloadHash := sha256Hex(body) |
| 155 | + req.Header.Set("X-Amz-Content-Sha256", payloadHash) |
| 156 | + |
| 157 | + signer := v4.NewSigner() |
| 158 | + region := opts.Region |
| 159 | + if region == "" { |
| 160 | + region = "auto" |
| 161 | + } |
| 162 | + |
| 163 | + err = signer.SignHTTP(ctx, creds, req, payloadHash, "s3", region, time.Now()) |
| 164 | + if err != nil { |
| 165 | + return nil, fmt.Errorf("storage: BundleObjects: failed to sign request: %w", err) |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + resp, err := bundleHTTPClient.Do(req) |
| 170 | + if err != nil { |
| 171 | + return nil, fmt.Errorf("storage: BundleObjects: request failed: %w", err) |
| 172 | + } |
| 173 | + |
| 174 | + if resp.StatusCode >= 400 { |
| 175 | + defer resp.Body.Close() |
| 176 | + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) |
| 177 | + return nil, fmt.Errorf("storage: BundleObjects: HTTP %d: %s", resp.StatusCode, string(errBody)) |
| 178 | + } |
| 179 | + |
| 180 | + return &BundleObjectsOutput{ |
| 181 | + Body: resp.Body, |
| 182 | + ContentType: resp.Header.Get("Content-Type"), |
| 183 | + StatusCode: resp.StatusCode, |
| 184 | + }, nil |
| 185 | +} |
| 186 | + |
| 187 | +func sha256Hex(data []byte) string { |
| 188 | + h := sha256.Sum256(data) |
| 189 | + return hex.EncodeToString(h[:]) |
| 190 | +} |
0 commit comments