From ed82555263a57eb1da7469e3a325c556fe3f7e8a Mon Sep 17 00:00:00 2001 From: Dennis Paul Date: Thu, 16 Jan 2025 15:38:51 +0100 Subject: [PATCH 1/2] implement handleMultipartRedirect() --- mediaapi/routing/download.go | 104 ++++++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 1 deletion(-) diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 3a7e7fc9..87d4009c 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -948,7 +948,8 @@ func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSize redirect := p.Header.Get("Location") if redirect != "" { - return 0, nil, fmt.Errorf("Location header is not yet supported") + // Handle redirect + return handleMultipartRedirect(r, redirect, maxFileSizeBytes) } contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes) @@ -957,6 +958,107 @@ func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSize return contentLength, reader, err } +// handleMultipartRedirect processes a redirect URL from a multipart response +func handleMultipartRedirect(r *downloadRequest, redirectURL string, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { + const maxRedirects = 10 + redirectCount := 0 + currentURL := redirectURL + + for redirectCount < maxRedirects { + // Validate the redirect URL + parsedURL, err := url.Parse(currentURL) + if err != nil { + return 0, nil, fmt.Errorf("invalid redirect URL: %w", err) + } + + // Security check: Only allow HTTPS URLs unless it's a trusted server + if parsedURL.Scheme != "https" && !isAllowedInsecureRedirect(parsedURL.Host, r.origin) { + return 0, nil, fmt.Errorf("insecure redirect URL: HTTPS required") + } + + // Create a new request for the redirect + req, err := http.NewRequest("GET", currentURL, nil) + if err != nil { + return 0, nil, fmt.Errorf("failed to create redirect request: %w", err) + } + + var resp *http.Response + if r.fedClient != nil { + // Extract media ID from the redirect URL + parsedURL, err := url.Parse(currentURL) + if err != nil { + return 0, nil, fmt.Errorf("invalid redirect URL: %w", err) + } + + // Extract the media ID from the path + pathParts := strings.Split(parsedURL.Path, "/") + if len(pathParts) == 0 { + return 0, nil, fmt.Errorf("invalid media URL path") + } + mediaID := pathParts[len(pathParts)-1] + + // Use the federation client's DownloadMedia method + resp, err = r.fedClient.DownloadMedia(req.Context(), r.origin, spec.ServerName(parsedURL.Host), mediaID) + if err != nil { + return 0, nil, fmt.Errorf("federation client failed to download media: %w", err) + } + } else { + // For non-federation requests, use a regular client + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse // Prevent auto-redirect + }, + } + + resp, err = client.Do(req) + if err != nil { + return 0, nil, fmt.Errorf("failed to perform request: %w", err) + } + } + defer resp.Body.Close() + if err != nil { + return 0, nil, fmt.Errorf("failed to follow redirect: %w", err) + } + defer resp.Body.Close() + + // Check if we get another redirect + if resp.StatusCode == http.StatusFound || resp.StatusCode == http.StatusMovedPermanently { + nextURL := resp.Header.Get("Location") + if nextURL == "" { + return 0, nil, fmt.Errorf("redirect response without Location header") + } + currentURL = nextURL + redirectCount++ + continue + } + + // If we got a successful response, process it + if resp.StatusCode == http.StatusOK { + // Check if the response is multipart + contentType := resp.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "multipart/") { + // Handle nested multipart response + return parseMultipartResponse(r, resp, maxFileSizeBytes) + } + + // Handle regular response + return r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes) + } + + return 0, nil, fmt.Errorf("unexpected status code following redirect: %d", resp.StatusCode) + } + + return 0, nil, fmt.Errorf("too many redirects (max %d)", maxRedirects) +} + +// isAllowedInsecureRedirect checks if insecure redirects are allowed for the given host +func isAllowedInsecureRedirect(host string, origin spec.ServerName) bool { + // Implementation depends on your security requirements + // You might want to check against a whitelist of trusted servers + // or compare against the origin server + return string(origin) == host +} + // contentDispositionFor returns the Content-Disposition for a given // content type. func contentDispositionFor(contentType types.ContentType) string { From 03dd120886b58f8ea13ba7938d5ec1b6a9f8ad39 Mon Sep 17 00:00:00 2001 From: Dennis Paul Date: Thu, 16 Jan 2025 18:14:51 +0100 Subject: [PATCH 2/2] implement redirect handler --- mediaapi/routing/download.go | 94 +++++++++++++++++------------------- 1 file changed, 44 insertions(+), 50 deletions(-) diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 87d4009c..494e664c 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -9,6 +9,7 @@ package routing import ( "context" "encoding/json" + "time" "fmt" "io" "io/fs" @@ -938,7 +939,6 @@ func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSize if err = json.NewDecoder(p).Decode(&meta); err != nil { return 0, nil, err } - defer p.Close() // nolint: errcheck // Get the actual media content p, err = mr.NextPart() @@ -963,6 +963,14 @@ func handleMultipartRedirect(r *downloadRequest, redirectURL string, maxFileSize const maxRedirects = 10 redirectCount := 0 currentURL := redirectURL + var lastResponse *http.Response + + // Ensure we clean up any response body if we exit early + defer func() { + if lastResponse != nil && lastResponse.Body != nil { + lastResponse.Body.Close() + } + }() for redirectCount < maxRedirects { // Validate the redirect URL @@ -971,55 +979,31 @@ func handleMultipartRedirect(r *downloadRequest, redirectURL string, maxFileSize return 0, nil, fmt.Errorf("invalid redirect URL: %w", err) } - // Security check: Only allow HTTPS URLs unless it's a trusted server - if parsedURL.Scheme != "https" && !isAllowedInsecureRedirect(parsedURL.Host, r.origin) { - return 0, nil, fmt.Errorf("insecure redirect URL: HTTPS required") - } - // Create a new request for the redirect req, err := http.NewRequest("GET", currentURL, nil) if err != nil { return 0, nil, fmt.Errorf("failed to create redirect request: %w", err) } - var resp *http.Response - if r.fedClient != nil { - // Extract media ID from the redirect URL - parsedURL, err := url.Parse(currentURL) - if err != nil { - return 0, nil, fmt.Errorf("invalid redirect URL: %w", err) - } - - // Extract the media ID from the path - pathParts := strings.Split(parsedURL.Path, "/") - if len(pathParts) == 0 { - return 0, nil, fmt.Errorf("invalid media URL path") - } - mediaID := pathParts[len(pathParts)-1] - - // Use the federation client's DownloadMedia method - resp, err = r.fedClient.DownloadMedia(req.Context(), r.origin, spec.ServerName(parsedURL.Host), mediaID) - if err != nil { - return 0, nil, fmt.Errorf("federation client failed to download media: %w", err) - } - } else { - // For non-federation requests, use a regular client - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse // Prevent auto-redirect - }, - } + // Close the previous response body before making a new request + if lastResponse != nil { + lastResponse.Body.Close() + lastResponse = nil + } - resp, err = client.Do(req) - if err != nil { - return 0, nil, fmt.Errorf("failed to perform request: %w", err) - } + // Use a regular client for redirects, as they might point to external storage + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse // Prevent auto-redirect + }, + Timeout: 30 * time.Second, } - defer resp.Body.Close() + + resp, err := client.Do(req) if err != nil { return 0, nil, fmt.Errorf("failed to follow redirect: %w", err) } - defer resp.Body.Close() + lastResponse = resp // Check if we get another redirect if resp.StatusCode == http.StatusFound || resp.StatusCode == http.StatusMovedPermanently { @@ -1027,6 +1011,18 @@ func handleMultipartRedirect(r *downloadRequest, redirectURL string, maxFileSize if nextURL == "" { return 0, nil, fmt.Errorf("redirect response without Location header") } + + // Handle relative URLs + nextParsedURL, err := url.Parse(nextURL) + if err != nil { + return 0, nil, fmt.Errorf("invalid redirect URL: %w", err) + } + + if !nextParsedURL.IsAbs() { + nextParsedURL = parsedURL.ResolveReference(nextParsedURL) + nextURL = nextParsedURL.String() + } + currentURL = nextURL redirectCount++ continue @@ -1037,12 +1033,18 @@ func handleMultipartRedirect(r *downloadRequest, redirectURL string, maxFileSize // Check if the response is multipart contentType := resp.Header.Get("Content-Type") if strings.HasPrefix(contentType, "multipart/") { - // Handle nested multipart response + // For multipart responses, we need to keep the response body open + // The caller will be responsible for closing it + lastResponse = nil // Don't close in defer return parseMultipartResponse(r, resp, maxFileSizeBytes) } - // Handle regular response - return r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes) + // For regular responses, create a new reader that will close the response body + body := resp.Body + lastResponse = nil // Don't close in defer + reader := io.NopCloser(body) + + return r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), reader, maxFileSizeBytes) } return 0, nil, fmt.Errorf("unexpected status code following redirect: %d", resp.StatusCode) @@ -1051,14 +1053,6 @@ func handleMultipartRedirect(r *downloadRequest, redirectURL string, maxFileSize return 0, nil, fmt.Errorf("too many redirects (max %d)", maxRedirects) } -// isAllowedInsecureRedirect checks if insecure redirects are allowed for the given host -func isAllowedInsecureRedirect(host string, origin spec.ServerName) bool { - // Implementation depends on your security requirements - // You might want to check against a whitelist of trusted servers - // or compare against the origin server - return string(origin) == host -} - // contentDispositionFor returns the Content-Disposition for a given // content type. func contentDispositionFor(contentType types.ContentType) string {