mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 19:57:04 +08:00
116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
|
import sys
|
||
|
import re
|
||
|
import inspect
|
||
|
|
||
|
|
||
|
def get_func_schema(function: callable) -> dict:
    """Build a function-calling schema from a function's signature and docstring.

    The docstring is expected to be Google style: a description paragraph,
    a blank line, then an ``Args:`` section with one ``name(type): text``
    entry per line (``name: text`` and ``name (type, optional): text`` are
    accepted as well). Any later sections (e.g. ``Returns:``) are ignored.

    The result has this shape:

        {
            "function": function,
            "description": "function description",
            "parameters": {
                "type": "object",
                "required": ["parameter_a"],
                "properties": {
                    "parameter_a": {
                        "type": "string",
                        "description": "parameter_a description"
                    },
                    "parameter_b": {
                        "type": "array",
                        "description": "parameter_b description",
                        "items": {"type": "integer"}
                    }
                }
            }
        }

    Args:
        function: The callable to describe. Parameters named ``self`` and
            ``query`` are left out of the schema.

    Returns:
        A dict with the keys ``function`` (the callable itself),
        ``description`` (first paragraph of the docstring) and
        ``parameters`` (a JSON-schema-like object description).

    Raises:
        Exception: If ``function`` has no docstring.
    """
    func_doc = function.__doc__
    if func_doc is None:
        raise Exception("Function {} has no docstring.".format(function.__name__))

    # Remove indentation line by line. Deleting every space from the whole
    # docstring would also glue the words of each description together.
    func_doc = '\n'.join(line.strip() for line in func_doc.strip().split('\n'))

    # Sections are separated by blank lines: description first, then Args.
    doc_spt = func_doc.split('\n\n')
    desc = doc_spt[0]
    args = doc_spt[1] if len(doc_spt) > 1 else ""

    # Map each documented argument name to its description. The first line
    # of the section is its header ("Args:") and is skipped.
    args_doc = {}
    for arg_line in args.split('\n')[1:]:
        matched = re.search(r'(\w+)\s*(?:\(([^()]*)\))?\s*:\s*(.*)', arg_line)
        if matched is None:
            continue
        args_doc[matched.group(1)] = matched.group(3)

    # Python type names -> JSON schema type names. Unknown names pass through.
    type_name_mapping = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
    }

    parameters = {
        "type": "object",
        "required": [],
        "properties": {},
    }

    for param in inspect.signature(function).parameters.values():
        # Exclude self and query: they are not part of the schema.
        if param.name in ('self', 'query'):
            continue

        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            # No annotation: fall back to the same default used for
            # untyped array elements below.
            param_type = "string"
        elif isinstance(annotation, str):
            # Postponed/string annotation such as "list[int]": keep the base
            # name only.
            param_type = annotation.split('[', 1)[0].strip()
        else:
            param_type = getattr(annotation, '__name__', None) or str(annotation)
        param_type = type_name_mapping.get(param_type, param_type)

        prop = {
            "type": param_type,
            # Undocumented parameters get an empty description instead of
            # aborting the whole schema with a KeyError.
            "description": args_doc.get(param.name, ""),
        }

        # Arrays also describe their element type, e.g. the int of list[int].
        if param_type == "array":
            element_types = re.findall(r'list\[(\w+)\]', str(annotation))
            array_type = element_types[0] if element_types else 'string'
            prop["items"] = {
                "type": type_name_mapping.get(array_type, array_type),
            }

        parameters['properties'][param.name] = prop

        # A parameter without a default value must be supplied by the caller.
        if param.default is inspect.Parameter.empty:
            parameters["required"].append(param.name)

    return {
        "function": function,
        "description": desc,
        "parameters": parameters,
    }
|