Add integration tests for the copilot provider
Since the copilot provider is a proxy, we add a "requester" module that,
depending on the provider, makes a request either with raw Python `requests`
as before, or by setting a proxy and using a CA cert file.

To be able to add more tests, we also add more kinds of checks. In addition
to the existing check, which verifies that the reply is similar to the
expected one using cosine distance, we add checks that verify the LLM reply
contains or does not contain a given string.
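
As a rough illustration of how the checks plug together, here is a minimal
sketch. The testcase dict is made up; only the key names mirror what
`CheckLoader` and the check classes read, and it assumes the script runs from
`tests/integration` so the `checks` module is importable:

```
import asyncio

from checks import CheckLoader

# Hypothetical testcase entry -- the keys match what the checks read,
# the values are purely illustrative.
test_data = {
    "name": "copilot chat malicious package",
    "provider": "copilot",
    "contains": "malicious",
    "does_not_contain": "import invokehttp",
}


async def run_checks(parsed_response: str) -> bool:
    # CheckLoader picks ContainsCheck and DoesNotContainCheck based on the keys
    checks = CheckLoader.load(test_data)
    results = [await check.run_check(parsed_response, test_data) for check in checks]
    return all(results)


print(asyncio.run(run_checks("That package is malicious, I will not generate code for it.")))
```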

We use those to add a test that ensures that the copilot provider chat works
and that the copilot chat refuses to generate a code snippet with a malicious
package.

To be able to run a subset of tests, we also add the ability to select tests
by provider (`CODEGATE_PROVIDERS`) or by test name (`CODEGATE_TEST_NAMES`).

These serve as the base for further integration tests.

To run them, call:
```
CODEGATE_PROVIDERS=copilot \
CA_CERT_FILE=/Users/you/devel/codegate/codegate_volume/certs/ca.crt \
ENV_COPILOT_KEY=your-openapi-key \
python tests/integration/integration_tests.py
```
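
To filter by test name instead (or in addition), set `CODEGATE_TEST_NAMES` to a
comma-separated list of test names; the names below are placeholders for
whatever is defined in `testcases.yaml`:
```
CODEGATE_PROVIDERS=copilot \
CODEGATE_TEST_NAMES="some test name,another test name" \
CA_CERT_FILE=/Users/you/devel/codegate/codegate_volume/certs/ca.crt \
ENV_COPILOT_KEY=your-openapi-key \
python tests/integration/integration_tests.py
```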

Related: #402
jhrozek committed Jan 9, 2025
1 parent 19ffa83 commit ff4a3a7
Showing 4 changed files with 313 additions and 52 deletions.
86 changes: 86 additions & 0 deletions tests/integration/checks.py
@@ -0,0 +1,86 @@
from abc import ABC, abstractmethod
from typing import List

import structlog
from sklearn.metrics.pairwise import cosine_similarity

from codegate.inference.inference_engine import LlamaCppInferenceEngine

logger = structlog.get_logger("codegate")


class BaseCheck(ABC):
def __init__(self, test_name: str):
self.test_name = test_name

@abstractmethod
async def run_check(self, parsed_response: str, test_data: dict) -> bool:
pass


class CheckLoader:
@staticmethod
def load(test_data: dict) -> List[BaseCheck]:
test_name = test_data.get("name")
checks = []
if test_data.get(DistanceCheck.KEY):
checks.append(DistanceCheck(test_name))
if test_data.get(ContainsCheck.KEY):
checks.append(ContainsCheck(test_name))
if test_data.get(DoesNotContainCheck.KEY):
checks.append(DoesNotContainCheck(test_name))

return checks


class DistanceCheck(BaseCheck):
KEY = "likes"

def __init__(self, test_name: str):
super().__init__(test_name)
self.inference_engine = LlamaCppInferenceEngine()
self.embedding_model = "codegate_volume/models/all-minilm-L6-v2-q5_k_m.gguf"

async def _calculate_string_similarity(self, str1, str2):
vector1 = await self.inference_engine.embed(self.embedding_model, [str1])
vector2 = await self.inference_engine.embed(self.embedding_model, [str2])
similarity = cosine_similarity(vector1, vector2)
return similarity[0]

async def run_check(self, parsed_response: str, test_data: dict) -> bool:
similarity = await self._calculate_string_similarity(
parsed_response, test_data[DistanceCheck.KEY]
)
if similarity < 0.8:
logger.error(f"Test {self.test_name} failed")
logger.error(f"Similarity: {similarity}")
logger.error(f"Response: {parsed_response}")
logger.error(f"Expected Response: {test_data[DistanceCheck.KEY]}")
return False
return True


class ContainsCheck(BaseCheck):
KEY = "contains"

async def run_check(self, parsed_response: str, test_data: dict) -> bool:
if test_data[ContainsCheck.KEY].strip() not in parsed_response:
logger.error(f"Test {self.test_name} failed")
logger.error(f"Response: {parsed_response}")
logger.error(f"Expected Response to contain: '{test_data[ContainsCheck.KEY]}'")
return False
return True


class DoesNotContainCheck(BaseCheck):
KEY = "does_not_contain"

async def run_check(self, parsed_response: str, test_data: dict) -> bool:
if test_data[DoesNotContainCheck.KEY].strip() in parsed_response:
logger.error(f"Test {self.test_name} failed")
logger.error(f"Response: {parsed_response}")
logger.error(
f"Expected Response to not contain: '{test_data[DoesNotContainCheck.KEY]}'"
)
return False
return True
164 changes: 122 additions & 42 deletions tests/integration/integration_tests.py
@@ -2,30 +2,50 @@
import json
import os
import re
from typing import Optional

import requests
import structlog
import yaml
from checks import CheckLoader
from dotenv import find_dotenv, load_dotenv
from sklearn.metrics.pairwise import cosine_similarity

from codegate.inference.inference_engine import LlamaCppInferenceEngine
from requesters import RequesterFactory

logger = structlog.get_logger("codegate")


class CodegateTestRunner:
def __init__(self):
self.inference_engine = LlamaCppInferenceEngine()
self.embedding_model = "codegate_volume/models/all-minilm-L6-v2-q5_k_m.gguf"
self.requester_factory = RequesterFactory()

def call_codegate(
self, url: str, headers: dict, data: dict, provider: str
) -> Optional[requests.Response]:
logger.debug(f"Creating requester for provider: {provider}")
requester = self.requester_factory.create_requester(provider)
logger.debug(f"Using requester type: {requester.__class__.__name__}")

logger.debug(f"Making request to URL: {url}")
logger.debug(f"Headers: {headers}")
logger.debug(f"Data: {data}")

response = requester.make_request(url, headers, data)

# Enhanced response logging
if response is not None:

if response.status_code != 200:
logger.debug(f"Response error status: {response.status_code}")
logger.debug(f"Response error headers: {dict(response.headers)}")
try:
error_content = response.json()
logger.error(f"Request error as JSON: {error_content}")
except ValueError:
# If not JSON, try to get raw text
logger.error(f"Raw request error: {response.text}")
        else:
            logger.error("No response received")

        return response

@staticmethod
def call_codegate(url, headers, data):
response = None
try:
response = requests.post(url, headers=headers, json=data)
except Exception as e:
logger.exception("An error occurred: %s", e)
return response

@staticmethod
@@ -50,6 +70,8 @@ def parse_response_message(response, streaming=True):

message_content = None
if "choices" in json_line:
if "finish_reason" in json_line["choices"][0]:
break
if "delta" in json_line["choices"][0]:
message_content = json_line["choices"][0]["delta"].get("content", "")
elif "text" in json_line["choices"][0]:
@@ -75,12 +97,6 @@ def parse_response_message(response, streaming=True):

return response_message

async def calculate_string_similarity(self, str1, str2):
vector1 = await self.inference_engine.embed(self.embedding_model, [str1])
vector2 = await self.inference_engine.embed(self.embedding_model, [str2])
similarity = cosine_similarity(vector1, vector2)
return similarity[0]

@staticmethod
def replace_env_variables(input_string, env):
"""
@@ -103,51 +119,115 @@ def replacement(match):
pattern = r"ENV\w*"
return re.sub(pattern, replacement, input_string)

async def run_test(self, test, test_headers):
async def run_test(self, test: dict, test_headers: dict) -> None:
test_name = test["name"]
url = test["url"]
data = json.loads(test["data"])
streaming = data.get("stream", False)
response = CodegateTestRunner.call_codegate(url, test_headers, data)
expected_response = test["expected"]
provider = test["provider"]

response = self.call_codegate(url, test_headers, data, provider)
if not response:
logger.error(f"Test {test_name} failed: No response received")
return

# Debug response info
logger.debug(f"Response status: {response.status_code}")
logger.debug(f"Response headers: {dict(response.headers)}")

try:
parsed_response = CodegateTestRunner.parse_response_message(
response, streaming=streaming
)
similarity = await self.calculate_string_similarity(parsed_response, expected_response)
if similarity < 0.8:
logger.error(f"Test {test_name} failed")
logger.error(f"Similarity: {similarity}")
logger.error(f"Response: {parsed_response}")
logger.error(f"Expected Response: {expected_response}")
else:
logger.info(f"Test {test['name']} passed")
parsed_response = self.parse_response_message(response, streaming=streaming)

# Load appropriate checks for this test
checks = CheckLoader.load(test)

# Run all checks
passed = True
for check in checks:
passed_check = await check.run_check(parsed_response, test)
if not passed_check:
passed = False
logger.info(f"Test {test_name} passed" if passed else f"Test {test_name} failed")

except Exception as e:
logger.exception("Could not parse response: %s", e)

async def run_tests(self, testcases_file):
async def run_tests(
self,
testcases_file: str,
providers: Optional[list[str]] = None,
test_names: Optional[list[str]] = None,
) -> None:
with open(testcases_file, "r") as f:
tests = yaml.safe_load(f)

headers = tests["headers"]
for _, header_val in headers.items():
if header_val is None:
continue
for key, val in header_val.items():
header_val[key] = CodegateTestRunner.replace_env_variables(val, os.environ)
testcases = tests["testcases"]

test_count = len(tests["testcases"])
if providers or test_names:
filtered_testcases = {}

logger.info(f"Running {test_count} tests")
for _, test_data in tests["testcases"].items():
for test_id, test_data in testcases.items():
if providers:
if test_data.get("provider", "").lower() not in [p.lower() for p in providers]:
continue

if test_names:
if test_data.get("name", "").lower() not in [t.lower() for t in test_names]:
continue

filtered_testcases[test_id] = test_data

testcases = filtered_testcases

if not testcases:
filter_msg = []
if providers:
filter_msg.append(f"providers: {', '.join(providers)}")
if test_names:
filter_msg.append(f"test names: {', '.join(test_names)}")
logger.warning(f"No tests found for {' and '.join(filter_msg)}")
return

test_count = len(testcases)
filter_msg = []
if providers:
filter_msg.append(f"providers: {', '.join(providers)}")
if test_names:
filter_msg.append(f"test names: {', '.join(test_names)}")

logger.info(
f"Running {test_count} tests"
+ (f" for {' and '.join(filter_msg)}" if filter_msg else "")
)

for test_id, test_data in testcases.items():
test_headers = headers.get(test_data["provider"], {})
test_headers = {
k: self.replace_env_variables(v, os.environ) for k, v in test_headers.items()
}
await self.run_test(test_data, test_headers)


async def main():
load_dotenv(find_dotenv())
test_runner = CodegateTestRunner()
await test_runner.run_tests("./tests/integration/testcases.yaml")

# Get providers and test names from environment variables
providers_env = os.environ.get("CODEGATE_PROVIDERS")
test_names_env = os.environ.get("CODEGATE_TEST_NAMES")

providers = None
if providers_env:
providers = [p.strip() for p in providers_env.split(",") if p.strip()]

test_names = None
if test_names_env:
test_names = [t.strip() for t in test_names_env.split(",") if t.strip()]

await test_runner.run_tests(
"./tests/integration/testcases.yaml", providers=providers, test_names=test_names
)


if __name__ == "__main__":
54 changes: 54 additions & 0 deletions tests/integration/requesters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
import os
from abc import ABC, abstractmethod
from typing import Optional

import requests
import structlog

logger = structlog.get_logger("codegate")


class BaseRequester(ABC):
@abstractmethod
def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
pass


class StandardRequester(BaseRequester):
def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
# Ensure Content-Type is always set correctly
headers["Content-Type"] = "application/json"

# Explicitly serialize to JSON string
json_data = json.dumps(data)

return requests.post(
url, headers=headers, data=json_data # Use data instead of json parameter
)


class CopilotRequester(BaseRequester):
def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
# Ensure Content-Type is always set correctly
headers["Content-Type"] = "application/json"

# Explicitly serialize to JSON string
json_data = json.dumps(data)

return requests.post(
url,
data=json_data, # Use data instead of json parameter
headers=headers,
proxies={"https": "https://localhost:8990", "http": "http://localhost:8990"},
verify=os.environ.get("CA_CERT_FILE"),
stream=True,
)


class RequesterFactory:
@staticmethod
def create_requester(provider: str) -> BaseRequester:
if provider.lower() == "copilot":
return CopilotRequester()
return StandardRequester()