import copy
import json
import os
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
client = OpenAI()
def traverse_directory(path):
for dirpath, dirnames, filenames in os.walk(path):
for file in filenames:
print(os.path.join(dirpath, file))
def code2prompt(filename):
"""Translate source code into prompt.
Returns:
string: Prompt that can be used by GPT.
"""
"""
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
# 任务描述
instruction = """
你的任务是解释代码。我们只会提供 C/C++ 语言、Python 语言、Java 语言让你分析。
并且在这个项目中,有可能多种语言是混用的。
"""
# 输出格式定义各种约束
output_format = """
以 JSON 格式输出。
1. 在 json 中 main 函数始终作为一级目录;
2. 如果有多个 main 函数,应该有多个一级目录;
3. 被 main 函数调用的其他函数是二级目录,以此类推,直到在当前目录下找不到更深层的调用关系。
只输出包含用户提及的字段,不要猜测任何用户未直接提及的字段。
DO NOT OUTPUT NULL-VALUED FILED! 确保输出能被 json.loads 加载。
"""
# 例子可以让输出更稳定
# 多轮对话的上下文就是调用链中的上下级关系,在 JSON 格式中,上文在 prev 字段中保存,下文在 next 字段中保存
examples = """
只有 1 个 main 函数的项目:
客服:有什么可以帮助你吗?
用户:在这个项目中包含了几个 main 函数?
客服:一个。
用户:请分析 main 文件中的代码。
客服:这是 main 函数的功能描述。
用户:在函数 main 中调用了哪些函数?
客服:函数 main 调用了 func1 和 func2。
用户:请分析 func1 和 func2 的功能。
客服:这是函数 func1 的功能描述,这是函数 func2 的功能描述。
用户:请根据你之前的回答,汇总一下,生成 JSON 格式的输出。
客服:
[
{
"main": {
"desc": "这是函数 main 的功能描述",
"prev": null,
"next": {
"func1": {
"desc": "这是函数 func1 的功能描述",
"prev": "main",
"next": {
"func2": {
"desc": "这是函数 func2 的功能描述",
"prev": "func1",
"next": null
},
},
},
"func3": {
"desc": "这是函数 func3 的功能描述",
"prev": "main",
"next": null
},
}
}
}
]
含有 2 个 main 函数的项目:
客服:有什么可以帮助你吗?
用户:在这个项目中包含了几个 main 函数?
客服:两个。
用户:请分析 main1 文件中的代码。
客服:这是函数 main1 的功能描述。
用户:请分析 main2 文件中的代码。
客服:这是函数 main2 的功能描述。
用户:请根据你之前的回答,汇总一下,生成 JSON 格式的输出。
客服:
[
{
"main1": {
"desc": "这是函数 main1 的功能描述",
"prev": null,
"next": null
}
},
{
"main2": {
"desc": "这是函数 main2 的功能描述",
"prev": null,
"next": null
}
}
]
"""
# 需要解析的文本
input_text = f"""
{content}
"""
prompt = f"""
{instruction}\n\n{output_format}\n\n例如:\n{examples}\n\n用户输入:\n{input_text}
"""
return prompt
class NLU:
"""自然语言理解(Nature Language Understanding, NLU),调用 GPT 获得反馈。"""
def __init__(self, filename):
self.prompt_template = code2prompt(filename)
def _get_completion(self, prompt, model="gpt-3.5-turbo"):
messages = [{"role": "user", "content": prompt}]
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0, # 模型输出的随机性,0 表示随机性最小
)
semantics = json.loads(response.choices[0].message.content)
return {k: v for k, v in semantics.items() if v}
def parse(self, user_input):
prompt = self.prompt_template.replace("__INPUT__", user_input)
return self._get_completion(prompt)
class DST:
"""对话状态跟踪(Dialogue State Tracking),实现多轮对话。"""
def __init__(self):
pass
def update(self, state, nlu_semantics):
if "name" in nlu_semantics:
state.clear()
if "sort" in nlu_semantics:
slot = nlu_semantics["sort"]["value"]
if slot in state and state[slot]["operator"] == "==":
del state[slot]
for k, v in nlu_semantics.items():
state[k] = v
return state
class MockedDB:
"""假数据库,为模型提供持久化能力。"""
def __init__(self):
self.data = [
{
"main": {
"desc": "这是函数 main 的功能描述。",
"prev": "null",
"next": {
"func1": {
"desc": "这是函数 func1 的功能描述。",
"prev": "main",
"next": "null",
},
},
},
}
]
def retrieve(self, **kwargs):
records = []
for r in self.data:
select = True
if r["requirement"]:
if "status" not in kwargs or kwargs["status"] != r["requirement"]:
continue
for k, v in kwargs.items():
if k == "sort":
continue
if k == "data" and v["value"] == "无上限":
if r[k] != 1000:
select = False
break
if "operator" in v:
if not eval(str(r[k]) + v["operator"] + str(v["value"])):
select = False
break
elif str(r[k]) != str(v):
select = False
break
if select:
records.append(r)
if len(records) <= 1:
return records
key = "price"
reverse = False
if "sort" in kwargs:
key = kwargs["sort"]["value"]
reverse = kwargs["sort"]["ordering"] == "descend"
return sorted(records, key=lambda x: x[key], reverse=reverse)
class DialogManager:
def __init__(self, prompt_templates):
self.state = {}
self.session = [
{
"role": "system",
"content": "你的任务是解释代码。我们只会提供 C/C++ 语言、Python 语言、Java 语言让你分析。",
}
]
self.nlu = NLU()
self.dst = DST()
self.db = MockedDB()
self.prompt_templates = prompt_templates
def _wrap(self, user_input, records):
if records:
prompt = self.prompt_templates["recommand"].replace("__INPUT__", user_input)
r = records[0]
for k, v in r.items():
prompt = prompt.replace(f"__{k.upper()}__", str(v))
else:
prompt = self.prompt_templates["not_found"].replace("__INPUT__", user_input)
for k, v in self.state.items():
if "operator" in v:
prompt = prompt.replace(
f"__{k.upper()}__", v["operator"] + str(v["value"])
)
else:
prompt = prompt.replace(f"__{k.upper()}__", str(v))
return prompt
def _call_chatgpt(self, prompt, model="gpt-3.5-turbo"):
session = copy.deepcopy(self.session)
session.append({"role": "user", "content": prompt})
response = client.chat.completions.create(
model=model,
messages=session,
temperature=0,
)
return response.choices[0].message.content
def run(self, user_input):
# 调用NLU获得语义解析
semantics = self.nlu.parse(user_input)
print("===semantics===")
print(semantics)
# 调用DST更新多轮状态
self.state = self.dst.update(self.state, semantics)
print("===state===")
print(self.state)
# 根据状态检索DB,获得满足条件的候选
records = self.db.retrieve(**self.state)
# 拼装prompt调用chatgpt
prompt_for_chatgpt = self._wrap(user_input, records)
print("===gpt-prompt===")
print(prompt_for_chatgpt)
# 调用chatgpt获得回复
response = self._call_chatgpt(prompt_for_chatgpt)
# 将当前用户输入和系统回复维护入chatgpt的session
self.session.append({"role": "user", "content": user_input})
self.session.append({"role": "assistant", "content": response})
return response
if __name__ == "__main__":
curr_dir = os.getcwd()
traverse_directory(curr_dir)
code2prompt(curr_dir + "/analyzer/main.py")