From bafb34c3a1dffe5c22908d5e59ca183123102e65 Mon Sep 17 00:00:00 2001 From: ammar68 <ammaa@stud.ntnu.no> Date: Thu, 10 Oct 2024 13:57:03 +0200 Subject: [PATCH] fixed return values of extract and handled error in test cases --- llama/extraction/extract.go | 9 +++++---- llama/extraction/extract_test.go | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/llama/extraction/extract.go b/llama/extraction/extract.go index 199c8f1..025fcba 100644 --- a/llama/extraction/extract.go +++ b/llama/extraction/extract.go @@ -1,6 +1,7 @@ package extraction import ( + "fmt" "strings" ) @@ -20,17 +21,17 @@ var RustPrompt = "The code should be in the Rust programming language. There sho // } // Extract extracts the code snippet between ``` and removes the language identifier. -func Extract(output string) string { +func Extract(output string) (string, error) { parts := strings.Split(output, "```") if len(parts) < 2 { - return "" // Handle the case if format is incorrect: Return empty string + return "", fmt.Errorf("the string wasn't in a proper format") // Handle the case if format is incorrect: Return empty string } // Trim the language identifier like `go` or `rust` from the code code := parts[1] lines := strings.SplitN(code, "\n", 2) if len(lines) > 1 { - return "\n" + lines[1] // Return the code without the first line (language identifier) + return "\n" + lines[1], nil // Return the code without the first line (language identifier) } - return "" + return "", fmt.Errorf("the string doesn't contain any lines") } diff --git a/llama/extraction/extract_test.go b/llama/extraction/extract_test.go index ccecb65..d2d7f3f 100644 --- a/llama/extraction/extract_test.go +++ b/llama/extraction/extract_test.go @@ -71,9 +71,11 @@ var testCases = []struct { func TestExtraction(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - output := Extract(tc.input) + output, err := Extract(tc.input) if output != tc.expected { t.Errorf("Test %s failed: Expected %q, got %q", tc.name, tc.expected, output) + t.Log(err.Error()) + } }) } -- GitLab