From bafb34c3a1dffe5c22908d5e59ca183123102e65 Mon Sep 17 00:00:00 2001
From: ammar68 <ammaa@stud.ntnu.no>
Date: Thu, 10 Oct 2024 13:57:03 +0200
Subject: [PATCH] fixed return values of extract and handled error in test
 cases

---
 llama/extraction/extract.go      | 9 +++++----
 llama/extraction/extract_test.go | 4 +++-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/llama/extraction/extract.go b/llama/extraction/extract.go
index 199c8f1..025fcba 100644
--- a/llama/extraction/extract.go
+++ b/llama/extraction/extract.go
@@ -1,6 +1,7 @@
 package extraction
 
 import (
+	"fmt"
 	"strings"
 )
 
@@ -20,17 +21,17 @@ var RustPrompt = "The code should be in the Rust programming language. There sho
 // }
 
 // Extract extracts the code snippet between ``` and removes the language identifier.
-func Extract(output string) string {
+func Extract(output string) (string, error) {
 	parts := strings.Split(output, "```")
 	if len(parts) < 2 {
-		return "" // Handle the case if format is incorrect: Return empty string
+		return "", fmt.Errorf("the string wasn't in a proper format") // Handle the case if format is incorrect: Return empty string
 	}
 
 	// Trim the language identifier like `go` or `rust` from the code
 	code := parts[1]
 	lines := strings.SplitN(code, "\n", 2)
 	if len(lines) > 1 {
-		return "\n" + lines[1] // Return the code without the first line (language identifier)
+		return "\n" + lines[1], nil // Return the code without the first line (language identifier)
 	}
-	return ""
+	return "", fmt.Errorf("the string doesn't contain any lines")
 }
diff --git a/llama/extraction/extract_test.go b/llama/extraction/extract_test.go
index ccecb65..d2d7f3f 100644
--- a/llama/extraction/extract_test.go
+++ b/llama/extraction/extract_test.go
@@ -71,9 +71,11 @@ var testCases = []struct {
 func TestExtraction(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			output := Extract(tc.input)
+			output, err := Extract(tc.input)
 			if output != tc.expected {
 				t.Errorf("Test %s failed: Expected %q, got %q", tc.name, tc.expected, output)
+				t.Log(err.Error())
+
 			}
 		})
 	}
-- 
GitLab