From 9631958a82c70f30421fc4e292f700ec8881805e Mon Sep 17 00:00:00 2001
From: Chongyi Zheng <harryzheng25@gmail.com>
Date: Mon, 18 Sep 2023 04:40:50 -0400
Subject: [PATCH] Refactor lfs requests (#26783)

- Refactor lfs request code
- The original code uses the `performRequest` function to create the
request, uses a callback to modify the request, and then sends the
request.
- Now it's replaced with `createRequest` that only creates the request and
`performRequest` that only sends the request.
- Reuse `createRequest` and `performRequest` in `http_client.go` and
`transferadapter.go`

---------

Co-authored-by: wxiaoguang <wxiaoguang@gmail.com>
---
 modules/lfs/filesystem_client.go |  12 ++--
 modules/lfs/http_client.go       | 104 +++++++++++++++++++----------
 modules/lfs/http_client_test.go  |  36 ++++++-----
 modules/lfs/pointer.go           |   4 +-
 modules/lfs/transferadapter.go   | 108 +++++++++----------------------
 modules/util/path.go             |   1 +
 6 files changed, 127 insertions(+), 138 deletions(-)

diff --git a/modules/lfs/filesystem_client.go b/modules/lfs/filesystem_client.go
index 835551e00c..3503a9effc 100644
--- a/modules/lfs/filesystem_client.go
+++ b/modules/lfs/filesystem_client.go
@@ -15,7 +15,7 @@ import (
 
 // FilesystemClient is used to read LFS data from a filesystem path
 type FilesystemClient struct {
-	lfsdir string
+	lfsDir string
 }
 
 // BatchSize returns the preferred size of batchs to process
@@ -25,16 +25,12 @@ func (c *FilesystemClient) BatchSize() int {
 
 func newFilesystemClient(endpoint *url.URL) *FilesystemClient {
 	path, _ := util.FileURLToPath(endpoint)
-
-	lfsdir := filepath.Join(path, "lfs", "objects")
-
-	client := &FilesystemClient{lfsdir}
-
-	return client
+	lfsDir := filepath.Join(path, "lfs", "objects")
+	return &FilesystemClient{lfsDir}
 }
 
 func (c *FilesystemClient) objectPath(oid string) string {
-	return filepath.Join(c.lfsdir, oid[0:2], oid[2:4], oid)
+	return filepath.Join(c.lfsDir, oid[0:2], oid[2:4], oid)
 }
 
 // Download reads the specific LFS object from the target path
diff --git a/modules/lfs/http_client.go b/modules/lfs/http_client.go
index ec0d6269bd..de0b1e4fed 100644
--- a/modules/lfs/http_client.go
+++ b/modules/lfs/http_client.go
@@ -8,6 +8,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 	"net/url"
 	"strings"
@@ -17,7 +18,7 @@ import (
 	"code.gitea.io/gitea/modules/proxy"
 )
 
-const batchSize = 20
+const httpBatchSize = 20
 
 // HTTPClient is used to communicate with the LFS server
 // https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md
@@ -29,7 +30,7 @@ type HTTPClient struct {
 
 // BatchSize returns the preferred size of batchs to process
 func (c *HTTPClient) BatchSize() int {
-	return batchSize
+	return httpBatchSize
 }
 
 func newHTTPClient(endpoint *url.URL, httpTransport *http.Transport) *HTTPClient {
@@ -43,28 +44,25 @@ func newHTTPClient(endpoint *url.URL, httpTransport *http.Transport) *HTTPClient
 		Transport: httpTransport,
 	}
 
-	client := &HTTPClient{
-		client:    hc,
-		endpoint:  strings.TrimSuffix(endpoint.String(), "/"),
-		transfers: make(map[string]TransferAdapter),
-	}
-
 	basic := &BasicTransferAdapter{hc}
-
-	client.transfers[basic.Name()] = basic
+	client := &HTTPClient{
+		client:   hc,
+		endpoint: strings.TrimSuffix(endpoint.String(), "/"),
+		transfers: map[string]TransferAdapter{
+			basic.Name(): basic,
+		},
+	}
 
 	return client
 }
 
 func (c *HTTPClient) transferNames() []string {
 	keys := make([]string, len(c.transfers))
-
 	i := 0
 	for k := range c.transfers {
 		keys[i] = k
 		i++
 	}
-
 	return keys
 }
 
@@ -74,7 +72,6 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin
 	url := fmt.Sprintf("%s/objects/batch", c.endpoint)
 
 	request := &BatchRequest{operation, c.transferNames(), nil, objects}
-
 	payload := new(bytes.Buffer)
 	err := json.NewEncoder(payload).Encode(request)
 	if err != nil {
@@ -82,32 +79,17 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin
 		return nil, err
 	}
 
-	log.Trace("Calling: %s", url)
-
-	req, err := http.NewRequestWithContext(ctx, "POST", url, payload)
+	req, err := createRequest(ctx, http.MethodPost, url, map[string]string{"Content-Type": MediaType}, payload)
 	if err != nil {
-		log.Error("Error creating request: %v", err)
 		return nil, err
 	}
-	req.Header.Set("Content-type", MediaType)
-	req.Header.Set("Accept", MediaType)
 
-	res, err := c.client.Do(req)
+	res, err := performRequest(ctx, c.client, req)
 	if err != nil {
-		select {
-		case <-ctx.Done():
-			return nil, ctx.Err()
-		default:
-		}
-		log.Error("Error while processing request: %v", err)
 		return nil, err
 	}
 	defer res.Body.Close()
 
-	if res.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("Unexpected server response: %s", res.Status)
-	}
-
 	var response BatchResponse
 	err = json.NewDecoder(res.Body).Decode(&response)
 	if err != nil {
@@ -177,7 +159,7 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc
 			link, ok := object.Actions["upload"]
 			if !ok {
 				log.Debug("%+v", object)
-				return errors.New("Missing action 'upload'")
+				return errors.New("missing action 'upload'")
 			}
 
 			content, err := uc(object.Pointer, nil)
@@ -187,8 +169,6 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc
 
 			err = transferAdapter.Upload(ctx, link, object.Pointer, content)
 
-			content.Close()
-
 			if err != nil {
 				return err
 			}
@@ -203,7 +183,7 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc
 			link, ok := object.Actions["download"]
 			if !ok {
 				log.Debug("%+v", object)
-				return errors.New("Missing action 'download'")
+				return errors.New("missing action 'download'")
 			}
 
 			content, err := transferAdapter.Download(ctx, link)
@@ -219,3 +199,59 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc
 
 	return nil
 }
+
+// createRequest creates a new request, sets the given headers, and always sets the Accept header to MediaType.
+func createRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader) (*http.Request, error) {
+	log.Trace("createRequest: %s", url)
+	req, err := http.NewRequestWithContext(ctx, method, url, body)
+	if err != nil {
+		log.Error("Error creating request: %v", err)
+		return nil, err
+	}
+
+	for key, value := range headers {
+		req.Header.Set(key, value)
+	}
+	req.Header.Set("Accept", MediaType)
+
+	return req, nil
+}
+
+// performRequest sends the given request with the given client and returns the response.
+// If the status code is 200, the response is returned, and it will contain a non-nil Body.
+// Otherwise, it will return an error, and the Body will be nil or closed.
+func performRequest(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
+	log.Trace("performRequest: %s", req.URL)
+	res, err := client.Do(req)
+	if err != nil {
+		select {
+		case <-ctx.Done():
+			return res, ctx.Err()
+		default:
+		}
+		log.Error("Error while processing request: %v", err)
+		return res, err
+	}
+
+	if res.StatusCode != http.StatusOK {
+		defer res.Body.Close()
+		return res, handleErrorResponse(res)
+	}
+
+	return res, nil
+}
+
+func handleErrorResponse(resp *http.Response) error {
+	var er ErrorResponse
+	err := json.NewDecoder(resp.Body).Decode(&er)
+	if err != nil {
+		if err == io.EOF {
+			return io.ErrUnexpectedEOF
+		}
+		log.Error("Error decoding json: %v", err)
+		return err
+	}
+
+	log.Trace("ErrorResponse: %v", er)
+	return errors.New(er.Message)
+}
diff --git a/modules/lfs/http_client_test.go b/modules/lfs/http_client_test.go
index cb71b9008a..7459d9c0c9 100644
--- a/modules/lfs/http_client_test.go
+++ b/modules/lfs/http_client_test.go
@@ -177,7 +177,7 @@ func TestHTTPClientDownload(t *testing.T) {
 		// case 0
 		{
 			endpoint:      "https://status-not-ok.io",
-			expectederror: "Unexpected server response: ",
+			expectederror: io.ErrUnexpectedEOF.Error(),
 		},
 		// case 1
 		{
@@ -207,7 +207,7 @@ func TestHTTPClientDownload(t *testing.T) {
 		// case 6
 		{
 			endpoint:      "https://empty-actions-map.io",
-			expectederror: "Missing action 'download'",
+			expectederror: "missing action 'download'",
 		},
 		// case 7
 		{
@@ -217,27 +217,28 @@ func TestHTTPClientDownload(t *testing.T) {
 		// case 8
 		{
 			endpoint:      "https://upload-actions-map.io",
-			expectederror: "Missing action 'download'",
+			expectederror: "missing action 'download'",
 		},
 		// case 9
 		{
 			endpoint:      "https://verify-actions-map.io",
-			expectederror: "Missing action 'download'",
+			expectederror: "missing action 'download'",
 		},
 		// case 10
 		{
 			endpoint:      "https://unknown-actions-map.io",
-			expectederror: "Missing action 'download'",
+			expectederror: "missing action 'download'",
 		},
 	}
 
 	for n, c := range cases {
 		client := &HTTPClient{
-			client:    hc,
-			endpoint:  c.endpoint,
-			transfers: make(map[string]TransferAdapter),
+			client:   hc,
+			endpoint: c.endpoint,
+			transfers: map[string]TransferAdapter{
+				"dummy": dummy,
+			},
 		}
-		client.transfers["dummy"] = dummy
 
 		err := client.Download(context.Background(), []Pointer{p}, func(p Pointer, content io.ReadCloser, objectError error) error {
 			if objectError != nil {
@@ -284,7 +285,7 @@ func TestHTTPClientUpload(t *testing.T) {
 		// case 0
 		{
 			endpoint:      "https://status-not-ok.io",
-			expectederror: "Unexpected server response: ",
+			expectederror: io.ErrUnexpectedEOF.Error(),
 		},
 		// case 1
 		{
@@ -319,7 +320,7 @@ func TestHTTPClientUpload(t *testing.T) {
 		// case 7
 		{
 			endpoint:      "https://download-actions-map.io",
-			expectederror: "Missing action 'upload'",
+			expectederror: "missing action 'upload'",
 		},
 		// case 8
 		{
@@ -329,22 +330,23 @@ func TestHTTPClientUpload(t *testing.T) {
 		// case 9
 		{
 			endpoint:      "https://verify-actions-map.io",
-			expectederror: "Missing action 'upload'",
+			expectederror: "missing action 'upload'",
 		},
 		// case 10
 		{
 			endpoint:      "https://unknown-actions-map.io",
-			expectederror: "Missing action 'upload'",
+			expectederror: "missing action 'upload'",
 		},
 	}
 
 	for n, c := range cases {
 		client := &HTTPClient{
-			client:    hc,
-			endpoint:  c.endpoint,
-			transfers: make(map[string]TransferAdapter),
+			client:   hc,
+			endpoint: c.endpoint,
+			transfers: map[string]TransferAdapter{
+				"dummy": dummy,
+			},
 		}
-		client.transfers["dummy"] = dummy
 
 		err := client.Upload(context.Background(), []Pointer{p}, func(p Pointer, objectError error) (io.ReadCloser, error) {
 			return io.NopCloser(new(bytes.Buffer)), objectError
diff --git a/modules/lfs/pointer.go b/modules/lfs/pointer.go
index d7653e836c..3e5bb8f91d 100644
--- a/modules/lfs/pointer.go
+++ b/modules/lfs/pointer.go
@@ -29,10 +29,10 @@ const (
 
 var (
 	// ErrMissingPrefix occurs if the content lacks the LFS prefix
-	ErrMissingPrefix = errors.New("Content lacks the LFS prefix")
+	ErrMissingPrefix = errors.New("content lacks the LFS prefix")
 
 	// ErrInvalidStructure occurs if the content has an invalid structure
-	ErrInvalidStructure = errors.New("Content has an invalid structure")
+	ErrInvalidStructure = errors.New("content has an invalid structure")
 
 	// ErrInvalidOIDFormat occurs if the oid has an invalid format
 	ErrInvalidOIDFormat = errors.New("OID has an invalid format")
diff --git a/modules/lfs/transferadapter.go b/modules/lfs/transferadapter.go
index 649497aabb..d425b91946 100644
--- a/modules/lfs/transferadapter.go
+++ b/modules/lfs/transferadapter.go
@@ -6,8 +6,6 @@ package lfs
 import (
 	"bytes"
 	"context"
-	"errors"
-	"fmt"
 	"io"
 	"net/http"
 
@@ -15,7 +13,7 @@ import (
 	"code.gitea.io/gitea/modules/log"
 )
 
-// TransferAdapter represents an adapter for downloading/uploading LFS objects
+// TransferAdapter represents an adapter for downloading/uploading LFS objects.
 type TransferAdapter interface {
 	Name() string
 	Download(ctx context.Context, l *Link) (io.ReadCloser, error)
@@ -23,41 +21,48 @@ type TransferAdapter interface {
 	Verify(ctx context.Context, l *Link, p Pointer) error
 }
 
-// BasicTransferAdapter implements the "basic" adapter
+// BasicTransferAdapter implements the "basic" adapter.
 type BasicTransferAdapter struct {
 	client *http.Client
 }
 
-// Name returns the name of the adapter
+// Name returns the name of the adapter.
 func (a *BasicTransferAdapter) Name() string {
 	return "basic"
 }
 
-// Download reads the download location and downloads the data
+// Download reads the download location and downloads the data.
 func (a *BasicTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) {
-	resp, err := a.performRequest(ctx, "GET", l, nil, nil)
+	req, err := createRequest(ctx, http.MethodGet, l.Href, l.Header, nil)
+	if err != nil {
+		return nil, err
+	}
+	resp, err := performRequest(ctx, a.client, req)
 	if err != nil {
 		return nil, err
 	}
 	return resp.Body, nil
 }
 
-// Upload sends the content to the LFS server
+// Upload sends the content to the LFS server.
 func (a *BasicTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error {
-	_, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) {
-		if len(req.Header.Get("Content-Type")) == 0 {
-			req.Header.Set("Content-Type", "application/octet-stream")
-		}
-
-		if req.Header.Get("Transfer-Encoding") == "chunked" {
-			req.TransferEncoding = []string{"chunked"}
-		}
-
-		req.ContentLength = p.Size
-	})
+	req, err := createRequest(ctx, http.MethodPut, l.Href, l.Header, r)
 	if err != nil {
 		return err
 	}
+	if req.Header.Get("Content-Type") == "" {
+		req.Header.Set("Content-Type", "application/octet-stream")
+	}
+	if req.Header.Get("Transfer-Encoding") == "chunked" {
+		req.TransferEncoding = []string{"chunked"}
+	}
+	req.ContentLength = p.Size
+
+	res, err := performRequest(ctx, a.client, req)
+	if err != nil {
+		return err
+	}
+	defer res.Body.Close()
 	return nil
 }
 
@@ -69,66 +74,15 @@ func (a *BasicTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) e
 		return err
 	}
 
-	_, err = a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) {
-		req.Header.Set("Content-Type", MediaType)
-	})
+	req, err := createRequest(ctx, http.MethodPost, l.Href, l.Header, bytes.NewReader(b))
 	if err != nil {
 		return err
 	}
+	req.Header.Set("Content-Type", MediaType)
+	res, err := performRequest(ctx, a.client, req)
+	if err != nil {
+		return err
+	}
+	defer res.Body.Close()
 	return nil
 }
-
-func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) {
-	log.Trace("Calling: %s %s", method, l.Href)
-
-	req, err := http.NewRequestWithContext(ctx, method, l.Href, body)
-	if err != nil {
-		log.Error("Error creating request: %v", err)
-		return nil, err
-	}
-	for key, value := range l.Header {
-		req.Header.Set(key, value)
-	}
-	req.Header.Set("Accept", MediaType)
-
-	if callback != nil {
-		callback(req)
-	}
-
-	res, err := a.client.Do(req)
-	if err != nil {
-		select {
-		case <-ctx.Done():
-			return res, ctx.Err()
-		default:
-		}
-		log.Error("Error while processing request: %v", err)
-		return res, err
-	}
-
-	if res.StatusCode != http.StatusOK {
-		return res, handleErrorResponse(res)
-	}
-
-	return res, nil
-}
-
-func handleErrorResponse(resp *http.Response) error {
-	defer resp.Body.Close()
-
-	er, err := decodeResponseError(resp.Body)
-	if err != nil {
-		return fmt.Errorf("Request failed with status %s", resp.Status)
-	}
-	log.Trace("ErrorRespone: %v", er)
-	return errors.New(er.Message)
-}
-
-func decodeResponseError(r io.Reader) (ErrorResponse, error) {
-	var er ErrorResponse
-	err := json.NewDecoder(r).Decode(&er)
-	if err != nil {
-		log.Error("Error decoding json: %v", err)
-	}
-	return er, err
-}
diff --git a/modules/util/path.go b/modules/util/path.go
index 58258560dd..e8537fb6b9 100644
--- a/modules/util/path.go
+++ b/modules/util/path.go
@@ -225,6 +225,7 @@ func isOSWindows() bool {
 var driveLetterRegexp = regexp.MustCompile("/[A-Za-z]:/")
 
 // FileURLToPath extracts the path information from a file://... url.
+// It returns an error only if the URL is not a file URL.
 func FileURLToPath(u *url.URL) (string, error) {
 	if u.Scheme != "file" {
 		return "", errors.New("URL scheme is not 'file': " + u.String())