""" OpenAI 接口底层封装
目前使用的对话接口有 :
ChatCompletion - gpt - 3.5 - turbo 等模型
Completion - text - davinci - 003 等模型
此模块封装此两个接口的请求实现 , 为上层提供统一的调用方式
"""
import openai, logging, threading, asyncio
import openai.error as aiE
import tiktoken

from pkg.openai.api.model import RequestBase
from pkg.openai.api.completion import CompletionRequest
from pkg.openai.api.chat_completion import ChatCompletionRequest

COMPLETION_MODELS = {
    'text-davinci-003',
    'text-davinci-002',
    'code-davinci-002',
    'code-cushman-001',
    'text-curie-001',
    'text-babbage-001',
    'text-ada-001',
}

CHAT_COMPLETION_MODELS = {
    'gpt-3.5-turbo',
    'gpt-3.5-turbo-16k',
    'gpt-3.5-turbo-0613',
    'gpt-3.5-turbo-16k-0613',
    # 'gpt-3.5-turbo-0301',
    'gpt-4',
    'gpt-4-0613',
    'gpt-4-32k',
    'gpt-4-32k-0613',
}

# Empty for now; declared with set() so membership checks stay consistent
# with the model sets above (a bare {} would create a dict instead).
EDIT_MODELS = set()

IMAGE_MODELS = set()


def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBase:
    """Select and build the request wrapper matching the given model name."""
    if model_name in CHAT_COMPLETION_MODELS:
        return ChatCompletionRequest(model_name, messages, **args)
    elif model_name in COMPLETION_MODELS:
        return CompletionRequest(model_name, messages, **args)
    raise ValueError("Unsupported model [{}], please check the config file".format(model_name))
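
# Hedged usage sketch (kept as a comment: it needs a configured
# openai.api_key and network access to actually run). The `args` values
# below are hypothetical examples, normally read from the bot's config:
#
#     request = select_request_cls(
#         'gpt-3.5-turbo',
#         [{'role': 'user', 'content': 'Hello!'}],
#         {'temperature': 0.7},
#     )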


def count_chat_completion_tokens(messages: list, model: str) -> int:
    """Return the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }:
        tokens_per_message = 3
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
        return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return count_chat_completion_tokens(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
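
# Worked example (a sketch; exact counts depend on the tiktoken version):
# for [{"role": "user", "content": "Hello!"}] on gpt-3.5-turbo-0613 the count
# is tokens_per_message (3) + encode("user") + encode("Hello!") + 3 for the
# assistant reply priming, i.e. roughly 9 tokens in total.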


def count_completion_tokens(messages: list, model: str) -> int:
    """Estimate prompt tokens for Completion models by flattening the
    messages into a single prompt string."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    text = ""
    for message in messages:
        text += message['role'] + message['content'] + "\n"
    text += "assistant: "
    return len(encoding.encode(text))
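
# Sketch of the flattened prompt this builds for one user message:
#     "userHello!\nassistant: "
# (role and content are concatenated directly, as the loop above does).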


def count_tokens(messages: list, model: str) -> int:
    """Dispatch token counting by model family."""
    if model in CHAT_COMPLETION_MODELS:
        return count_chat_completion_tokens(messages, model)
    elif model in COMPLETION_MODELS:
        return count_completion_tokens(messages, model)
    raise ValueError("Unsupported model [{}], please check the config file".format(model))
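

# Minimal smoke test, a sketch rather than part of the module's API; the
# sample messages are hypothetical, and the first run downloads the
# tiktoken encoding files.
if __name__ == '__main__':
    sample_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    print(count_tokens(sample_messages, 'gpt-3.5-turbo'))      # ChatCompletion path
    print(count_tokens(sample_messages, 'text-davinci-003'))   # Completion path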