-
Notifications
You must be signed in to change notification settings - Fork 283
Expand file tree
/
Copy pathmain.py
More file actions
145 lines (104 loc) · 4.34 KB
/
main.py
File metadata and controls
145 lines (104 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from groq import Groq
import json
import duckdb
import sqlparse
def chat_with_groq(client, prompt, model, response_format):
    """
    Send one user prompt to a Groq chat model and return the reply text.

    The conversation contains a single message with role "user"; no system
    prompt or prior history is included.

    Parameters:
        client (Groq): The Groq API client.
        prompt (str): Text sent as the user message.
        model (str): Name of the model that should answer.
        response_format (dict | None): Passed straight through to the API.
            Use {"type": "json_object"} to request JSON mode; None leaves
            JSON mode off.

    Returns:
        str: The message content of the first choice in the completion.
    """
    user_message = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(
        model=model,
        messages=[user_message],
        response_format=response_format,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def execute_duckdb_query(query, data_dir='data'):
    """
    Execute a SQL query on a fresh in-memory DuckDB database and return the result.

    The process working directory is temporarily switched to ``data_dir`` so
    that relative file paths used inside the query resolve against that
    directory. It is restored afterwards, even if the query fails. Because
    ``os.chdir`` is process-wide, this function is not safe to call from
    several threads at once.

    Parameters:
        query (str): The SQL query to execute.
        data_dir (str): Directory to run the query from. Defaults to 'data'.

    Returns:
        DataFrame: The result of the query as a pandas DataFrame with a
        default 0..n-1 index.

    Raises:
        Whatever DuckDB raises for an invalid query, and OSError if
        ``data_dir`` cannot be entered.
    """
    original_cwd = os.getcwd()
    os.chdir(data_dir)
    try:
        conn = duckdb.connect(database=':memory:', read_only=False)
        try:
            query_result = conn.execute(query).fetchdf().reset_index(drop=True)
        finally:
            # Release the connection deterministically instead of relying on GC.
            conn.close()
    finally:
        os.chdir(original_cwd)
    return query_result
def get_summarization(client, user_question, df, model):
    """
    Ask the model for a short prose summary of a query result.

    The user's question and the result frame are embedded in a fixed prompt,
    which is sent through ``chat_with_groq`` without JSON mode.

    Parameters:
        client (Groq): The Groq API client.
        user_question (str): The question the user originally asked.
        df (DataFrame): The DataFrame resulting from the SQL query. It is
            embedded using its default string form, which pandas may
            abbreviate for large frames.
        model (str): The AI model to use for the response.

    Returns:
        str: The model's summary text.
    """
    summary_prompt = f'''
A user asked the following question pertaining to local database tables:
{user_question}
To answer the question, a dataframe was returned:
Dataframe:
{df}
In a few sentences, summarize the data in the table as it pertains to the original user question. Avoid qualifiers like "based on the data" and do not comment on the structure or metadata of the table itself
'''
    # No response format: a free-text answer is wanted, not JSON.
    return chat_with_groq(client, summary_prompt, model, response_format=None)
def main():
    """
    Run the interactive command-line question -> SQL -> answer loop.

    Each question is inserted into the base prompt read from
    ``prompts/base_prompt.txt`` and sent to the model in JSON mode.
    - If the reply has an ``sql`` key, the query is run with DuckDB. The
      formatted SQL, the result table, and a model-written summary are then
      printed.
    - If the reply has an ``error`` key, the model's error text is printed.

    Problems that affect only one question (unparseable model output, SQL
    that DuckDB rejects, a reply with neither key) are reported and the loop
    continues. The loop ends on end-of-input (Ctrl-D) or Ctrl-C.
    """
    model = "llama3-70b-8192"

    # Get the Groq API key and create a Groq client
    groq_api_key = os.getenv('GROQ_API_KEY')
    client = Groq(
        api_key=groq_api_key
    )

    print("Welcome to the DuckDB Query Generator!")
    print("You can ask questions about the data in the 'employees.csv' and 'purchases.csv' files.")

    # Load the base prompt; it is filled in with str.format(user_question=...)
    with open('prompts/base_prompt.txt', 'r', encoding='utf-8') as file:
        base_prompt = file.read()

    while True:
        try:
            user_question = input("Ask a question: ")
        except (EOFError, KeyboardInterrupt):
            # End of input / Ctrl-C: leave the loop quietly instead of with a traceback.
            print()
            break
        if not user_question:
            continue

        # Generate the full prompt for the AI
        full_prompt = base_prompt.format(user_question=user_question)
        # Get the AI's response. Call with '{"type": "json_object"}' to use JSON mode
        llm_response = chat_with_groq(client, full_prompt, model, {"type": "json_object"})
        try:
            result_json = json.loads(llm_response)
        except (json.JSONDecodeError, TypeError) as err:
            # TypeError covers a missing (None) message content.
            print("ERROR:", 'Could not parse the model response as JSON')
            print(err)
            continue

        if 'sql' in result_json:
            sql_query = result_json['sql']
            try:
                results_df = execute_duckdb_query(sql_query)
            except duckdb.Error as err:
                # Model-generated SQL may be invalid; report it and ask again.
                print("ERROR:", 'The generated SQL failed to execute')
                print(sql_query)
                print(err)
                continue
            formatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
            print("```sql\n" + formatted_sql_query + "\n```")
            # NOTE: DataFrame.to_markdown needs the optional 'tabulate' package.
            print(results_df.to_markdown(index=False))
            summarization = get_summarization(client, user_question, results_df, model)
            # '$' is escaped, presumably so Markdown/LaTeX renderers do not treat
            # it as math. NOTE(review): unnecessary for plain terminal output.
            print(summarization.replace('$', '\\$'))
        elif 'error' in result_json:
            print("ERROR:", 'Could not generate valid SQL for this question')
            print(result_json['error'])
        else:
            # Previously this case was silently ignored.
            print("ERROR:", 'Unexpected response from the model')
            print(llm_response)


if __name__ == "__main__":
    main()