Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Jan 25, 2025
1 parent 5b192dc commit e09f526
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
16 changes: 16 additions & 0 deletions wren-ai-service/src/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ async def execute_sql(
...


def clean_generation_result(result: str) -> str:
def _normalize_whitespace(s: str) -> str:
return re.sub(r"\s+", " ", s).strip()

return (
_normalize_whitespace(result)
.replace("\\n", " ")
.replace("```sql", "")
.replace("```json", "")
.replace('"""', "")
.replace("'''", "")
.replace("```", "")
.replace(";", "")
)


def remove_limit_statement(sql: str) -> str:
pattern = r"\s*LIMIT\s+\d+(\s*;?\s*--.*|\s*;?\s*)$"
modified_sql = re.sub(pattern, "", sql, flags=re.IGNORECASE)
Expand Down
9 changes: 8 additions & 1 deletion wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@
You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills.
Now you are given syntactically incorrect ANSI SQL query and related error message.
With given database schema, please think step by step to generate the syntactically correct ANSI SQL query without changing semantics.
With given database schema, please generate the syntactically correct ANSI SQL query without changing original semantics.
{TEXT_TO_SQL_RULES}
### FINAL ANSWER FORMAT ###
The final answer must be a corrected SQL query in JSON format:
{{
"sql": <CORRECTED_SQL_QUERY_STRING>
}}
"""

sql_correction_user_prompt_template = """
Expand Down
7 changes: 7 additions & 0 deletions wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
- Columns to be adjusted must belong to the given database schema; if no such column exists, keep sql empty string
- You can add/delete/modify columns, add/delete/modify keywords such as DISTINCT or apply aggregate functions on columns
- Consider current time from user input if user's adjustment request is related to date and time
### FINAL ANSWER FORMAT ###
The final answer must be a SQL query in JSON format:
{
"sql": <SQL_QUERY_STRING>
}
"""

sql_expansion_user_prompt_template = """
Expand Down
18 changes: 15 additions & 3 deletions wren-ai-service/src/pipelines/generation/utils/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from src.core.engine import (
Engine,
add_quotes,
clean_generation_result,
)
from src.web.v1.services import Configuration

Expand All @@ -29,7 +30,7 @@ async def run(
replies: List[str],
project_id: str | None = None,
) -> Dict[str, Any]:
cleaned_generation_result = orjson.loads(replies[0])
cleaned_generation_result = orjson.loads(clean_generation_result(replies[0]))

steps = cleaned_generation_result.get("steps", [])
if not steps:
Expand Down Expand Up @@ -119,12 +120,16 @@ async def run(
for reply in replies:
try:
cleaned_generation_result.append(
orjson.loads(reply["replies"][0])["sql"]
orjson.loads(clean_generation_result(reply["replies"][0]))[
"sql"
]
)
except Exception as e:
logger.exception(f"Error in SQLGenPostProcessor: {e}")
else:
cleaned_generation_result = orjson.loads(replies[0])["sql"]
cleaned_generation_result = orjson.loads(
clean_generation_result(replies[0])
)["sql"]

if isinstance(cleaned_generation_result, str):
cleaned_generation_result = [cleaned_generation_result]
Expand Down Expand Up @@ -336,6 +341,13 @@ async def _task(sql: str):
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
{TEXT_TO_SQL_RULES}
### FINAL ANSWER FORMAT ###
The final answer must be a SQL query in JSON format:
{{
"sql": <SQL_QUERY_STRING>
}}
"""

calculated_field_instructions = """
Expand Down

0 comments on commit e09f526

Please sign in to comment.