forked from stacklok/codegate
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add integration tests for the copilot provider
Since the copilot provider is a proxy, we add a "requester" module that, depending on the provider, makes a request either using raw Python requests as before or by setting a proxy and using a CA cert file.

To be able to add more tests, we also add more kinds of checks: in addition to the existing one, which makes sure the reply is like the expected one using cosine distance, we add checks that make sure the LLM reply contains or doesn't contain a string. We use those to add a test that ensures that the copilot provider chat works and that the copilot chat refuses to generate a code snippet with a malicious package.

To be able to run a subset of tests, we also add the ability to select a subset of tests based on a provider (`codegate_providers`) or the test name (`codegate_test_names`).

These serve as the base for further integration tests. To run them, call:

```
CODEGATE_PROVIDERS=copilot \
CA_CERT_FILE=/Users/you/devel/codegate/codegate_volume/certs/ca.crt \
ENV_COPILOT_KEY=your-openapi-key \
python tests/integration/integration_tests.py
```

Related: stacklok#402
- Loading branch information
Showing
4 changed files
with
313 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
import structlog | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
|
||
from codegate.inference.inference_engine import LlamaCppInferenceEngine | ||
|
||
logger = structlog.get_logger("codegate") | ||
|
||
|
||
class BaseCheck(ABC):
    """Common interface for checks applied to a parsed LLM response.

    Each check stores the name of the test it belongs to (``test_name``) and
    implements ``run_check`` to decide whether a response passes.
    """

    def __init__(self, test_name: str):
        """Remember the name of the test this check is attached to."""
        self.test_name = test_name

    @abstractmethod
    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Return ``True`` if ``parsed_response`` passes, ``False`` otherwise.

        Args:
            parsed_response: The reply text to evaluate.
            test_data: The test definition the check reads its expectation from.
        """
|
||
|
||
class CheckLoader:
    """Creates the checks that a test definition asks for."""

    @staticmethod
    def load(test_data: dict) -> List[BaseCheck]:
        """Instantiate one check per check key that has a truthy value.

        Args:
            test_data: Test definition; its ``"name"`` entry (if any) is passed
                to every check that gets created.

        Returns:
            The created checks, ordered distance, contains, does-not-contain.
            The list is empty when none of the check keys is set.
        """
        name = test_data.get("name")
        # Tuple order fixes the order in which the checks are returned.
        known_checks = (DistanceCheck, ContainsCheck, DoesNotContainCheck)
        return [
            check_cls(name)
            for check_cls in known_checks
            if test_data.get(check_cls.KEY)
        ]
|
||
|
||
class DistanceCheck(BaseCheck):
    """Check that a response is semantically close to an expected reply.

    Both texts are embedded with the configured model and compared using
    cosine similarity; the check passes when the similarity reaches
    ``SIMILARITY_THRESHOLD``.
    """

    # Key in the test definition that holds the expected reply text.
    KEY = "likes"
    # Minimum cosine similarity for the response to be accepted.
    SIMILARITY_THRESHOLD = 0.8

    def __init__(self, test_name: str):
        """Create the inference engine and record the embedding model path."""
        super().__init__(test_name)
        self.inference_engine = LlamaCppInferenceEngine()
        self.embedding_model = "codegate_volume/models/all-minilm-L6-v2-q5_k_m.gguf"

    async def _calculate_string_similarity(self, str1, str2) -> float:
        """Return the cosine similarity of the embeddings of two strings."""
        vector1 = await self.inference_engine.embed(self.embedding_model, [str1])
        vector2 = await self.inference_engine.embed(self.embedding_model, [str2])
        similarity = cosine_similarity(vector1, vector2)
        # cosine_similarity returns a 2-D matrix; with one string per side the
        # single score is at [0][0]. Return it as a plain float so the
        # comparison and the log output deal with a scalar, not an array row.
        return float(similarity[0][0])

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Compare ``parsed_response`` with the expected reply in ``test_data``.

        Returns:
            ``True`` when the similarity is at least ``SIMILARITY_THRESHOLD``;
            otherwise logs the details and returns ``False``.
        """
        expected = test_data[DistanceCheck.KEY]
        similarity = await self._calculate_string_similarity(parsed_response, expected)
        if similarity < self.SIMILARITY_THRESHOLD:
            logger.error(f"Test {self.test_name} failed")
            logger.error(f"Similarity: {similarity}")
            logger.error(f"Response: {parsed_response}")
            logger.error(f"Expected Response: {expected}")
            return False
        return True
|
||
|
||
class ContainsCheck(BaseCheck):
    """Check that the response contains an expected substring."""

    # Key in the test definition that holds the substring to look for.
    KEY = "contains"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Return ``True`` if the expected text occurs in ``parsed_response``.

        Surrounding whitespace is stripped from the expected text before the
        substring test. On failure the details are logged and ``False`` is
        returned.
        """
        expected = test_data[ContainsCheck.KEY]
        if expected.strip() in parsed_response:
            return True

        logger.error(f"Test {self.test_name} failed")
        logger.error(f"Response: {parsed_response}")
        logger.error(f"Expected Response to contain: '{expected}'")
        return False
|
||
|
||
class DoesNotContainCheck(BaseCheck):
    """Check that the response does not contain a forbidden substring."""

    # Key in the test definition that holds the substring that must be absent.
    KEY = "does_not_contain"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Return ``True`` if the forbidden text is absent from ``parsed_response``.

        Surrounding whitespace is stripped from the forbidden text before the
        substring test. On failure the details are logged and ``False`` is
        returned.
        """
        forbidden = test_data[DoesNotContainCheck.KEY]
        if forbidden.strip() not in parsed_response:
            return True

        logger.error(f"Test {self.test_name} failed")
        logger.error(f"Response: {parsed_response}")
        logger.error(
            f"Expected Response to not contain: '{forbidden}'"
        )
        return False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import json | ||
import os | ||
from abc import ABC, abstractmethod | ||
from typing import Optional | ||
|
||
import requests | ||
import structlog | ||
|
||
logger = structlog.get_logger("codegate") | ||
|
||
|
||
class BaseRequester(ABC):
    """Interface for objects that send a request and return the response."""

    @abstractmethod
    def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
        """Send ``data`` to ``url`` with ``headers`` and return the response.

        Args:
            url: Target URL of the request.
            headers: HTTP headers to send.
            data: Payload of the request.
        """
|
||
|
||
class StandardRequester(BaseRequester):
    """Requester that POSTs the payload directly to the target URL."""

    def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
        """POST ``data`` as a JSON body to ``url``.

        Args:
            url: Target URL of the request.
            headers: HTTP headers to send; ``Content-Type`` is always forced to
                ``application/json``. The caller's dict is not modified.
            data: Payload, serialized with ``json.dumps``.

        Returns:
            The response returned by ``requests.post``.
        """
        # Build a copy so the caller's headers dict is not mutated as a side
        # effect of making the request.
        request_headers = {**(headers or {}), "Content-Type": "application/json"}

        # Serialize explicitly and send via ``data=`` rather than ``json=``.
        payload = json.dumps(data)

        return requests.post(url, headers=request_headers, data=payload)
|
||
|
||
class CopilotRequester(BaseRequester):
    """Requester that sends the request through the proxy on localhost:8990."""

    def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
        """POST ``data`` as a JSON body to ``url`` via the local proxy.

        The CA bundle used for TLS verification is read from the
        ``CA_CERT_FILE`` environment variable, and the response is requested
        in streaming mode.

        Args:
            url: Target URL of the request.
            headers: HTTP headers to send; ``Content-Type`` is always forced to
                ``application/json``. The caller's dict is not modified.
            data: Payload, serialized with ``json.dumps``.

        Returns:
            The response returned by ``requests.post``.
        """
        # Build a copy so the caller's headers dict is not mutated as a side
        # effect of making the request.
        request_headers = {**(headers or {}), "Content-Type": "application/json"}

        # Serialize explicitly and send via ``data=`` rather than ``json=``.
        payload = json.dumps(data)

        return requests.post(
            url,
            data=payload,
            headers=request_headers,
            proxies={"https": "https://localhost:8990", "http": "http://localhost:8990"},
            # NOTE(review): when CA_CERT_FILE is unset this passes verify=None;
            # confirm that falling back to requests' default is intended.
            verify=os.environ.get("CA_CERT_FILE"),
            stream=True,
        )
|
||
|
||
class RequesterFactory:
    """Selects the requester implementation for a provider."""

    @staticmethod
    def create_requester(provider: str) -> BaseRequester:
        """Return a requester for ``provider``.

        The provider name is compared case-insensitively: ``"copilot"`` yields
        a ``CopilotRequester``; every other name yields a
        ``StandardRequester``.
        """
        is_copilot = provider.lower() == "copilot"
        requester_cls = CopilotRequester if is_copilot else StandardRequester
        return requester_cls()
Oops, something went wrong.