diff --git a/client/clientimpl_test.go b/client/clientimpl_test.go index dea6bd6c..cf70c6a8 100644 --- a/client/clientimpl_test.go +++ b/client/clientimpl_test.go @@ -1439,27 +1439,6 @@ const packageFileURL = "/validfile.pkg" var packageFileContent = []byte("Package File Content") -func createDownloadSrv(t *testing.T) *httptest.Server { - m := http.NewServeMux() - m.HandleFunc(packageFileURL, - func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, err := w.Write(packageFileContent) - assert.NoError(t, err) - }) - - srv := httptest.NewServer(m) - - u, err := url.Parse(srv.URL) - if err != nil { - t.Fatal(err) - } - endpoint := u.Host - testhelpers.WaitForEndpoint(endpoint) - - return srv -} - func createPackageTestCase(name string, downloadSrv *httptest.Server) packageTestCase { return packageTestCase{ name: name, @@ -1505,14 +1484,43 @@ func createPackageTestCase(name string, downloadSrv *httptest.Server) packageTes } } +// Mock download server that checks headers in the request. +func createDownloadSrv(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for the Authorization header in the successWithHeaders case. + if r.Header.Get("Authorization") == "Bearer test-token" { + t.Logf("Authorization header successfully set.") + } else { + t.Errorf("Expected Authorization header to be 'Bearer test-token', but got '%s'", r.Header.Get("Authorization")) + } + + // Handle file not found scenario. + if r.URL.Path == "/notfound" { + http.Error(w, "Not found", http.StatusNotFound) + return + } + + // Simulate a successful file download for valid requests. + w.WriteHeader(http.StatusOK) + w.Write([]byte("file content")) + })) +} + func TestUpdatePackages(t *testing.T) { downloadSrv := createDownloadSrv(t) defer downloadSrv.Close() - // A success case. var tests []packageTestCase - tests = append(tests, createPackageTestCase("success", downloadSrv)) + + // Success case with headers. + successWithHeaders := createPackageTestCase("success with headers", downloadSrv) + successWithHeaders.available.Packages["packageWithHeaders"].File.Headers = &protobufs.Headers{ + Headers: []*protobufs.Header{ + {Key: "Authorization", Value: "Bearer test-token"}, + }, + } + tests = append(tests, successWithHeaders) // A case when downloading the file fails because the URL is incorrect. notFound := createPackageTestCase("downloadable file not found", downloadSrv) diff --git a/client/internal/packagessyncer.go b/client/internal/packagessyncer.go index c02828d4..a90052f4 100644 --- a/client/internal/packagessyncer.go +++ b/client/internal/packagessyncer.go @@ -280,6 +280,13 @@ func (s *packagesSyncer) downloadFile(ctx context.Context, pkgName string, file return fmt.Errorf("cannot download file from %s: %v", file.DownloadUrl, err) } + // Add optional headers if they exist + if file.Headers != nil { + for _, header := range file.Headers.Headers { + req.Header.Add(header.Key, header.Value) + } + } + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("cannot download file from %s: %v", file.DownloadUrl, err)