aboutsummaryrefslogtreecommitdiff
path: root/examples/server/api_like_OAI.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/server/api_like_OAI.py')
-rwxr-xr-xexamples/server/api_like_OAI.py219
1 files changed, 219 insertions, 0 deletions
diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py
new file mode 100755
index 0000000..aa325a0
--- /dev/null
+++ b/examples/server/api_like_OAI.py
@@ -0,0 +1,219 @@
+import argparse
+from flask import Flask, jsonify, request, Response
+import urllib.parse
+import requests
+import time
+import json
+
+
+app = Flask(__name__)
+
+parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
+parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
+parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
+parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
+parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
+parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
+parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
+parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
+parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1')
+parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081)
+
+args = parser.parse_args()
+
+def is_present(json, key):
+ try:
+ buf = json[key]
+ except KeyError:
+ return False
+ return True
+
+
+
+#convert chat to prompt
+def convert_chat(messages):
+ prompt = "" + args.chat_prompt.replace("\\n", "\n")
+
+ system_n = args.system_name.replace("\\n", "\n")
+ user_n = args.user_name.replace("\\n", "\n")
+ ai_n = args.ai_name.replace("\\n", "\n")
+ stop = args.stop.replace("\\n", "\n")
+
+
+ for line in messages:
+ if (line["role"] == "system"):
+ prompt += f"{system_n}{line['content']}"
+ if (line["role"] == "user"):
+ prompt += f"{user_n}{line['content']}"
+ if (line["role"] == "assistant"):
+ prompt += f"{ai_n}{line['content']}{stop}"
+ prompt += ai_n.rstrip()
+
+ return prompt
+
+def make_postData(body, chat=False, stream=False):
+ postData = {}
+ if (chat):
+ postData["prompt"] = convert_chat(body["messages"])
+ else:
+ postData["prompt"] = body["prompt"]
+ if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
+ if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
+ if(is_present(body, "top_p")): postData["top_p"] = body["top_p"]
+ if(is_present(body, "max_tokens")): postData["n_predict"] = body["max_tokens"]
+ if(is_present(body, "presence_penalty")): postData["presence_penalty"] = body["presence_penalty"]
+ if(is_present(body, "frequency_penalty")): postData["frequency_penalty"] = body["frequency_penalty"]
+ if(is_present(body, "repeat_penalty")): postData["repeat_penalty"] = body["repeat_penalty"]
+ if(is_present(body, "mirostat")): postData["mirostat"] = body["mirostat"]
+ if(is_present(body, "mirostat_tau")): postData["mirostat_tau"] = body["mirostat_tau"]
+ if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
+ if(is_present(body, "seed")): postData["seed"] = body["seed"]
+ if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
+ if (args.stop != ""):
+ postData["stop"] = [args.stop]
+ else:
+ postData["stop"] = []
+ if(is_present(body, "stop")): postData["stop"] += body["stop"]
+ postData["n_keep"] = -1
+ postData["stream"] = stream
+
+ return postData
+
+def make_resData(data, chat=False, promptToken=[]):
+ resData = {
+ "id": "chatcmpl" if (chat) else "cmpl",
+ "object": "chat.completion" if (chat) else "text_completion",
+ "created": int(time.time()),
+ "truncated": data["truncated"],
+ "model": "LLaMA_CPP",
+ "usage": {
+ "prompt_tokens": data["tokens_evaluated"],
+ "completion_tokens": data["tokens_predicted"],
+ "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
+ }
+ }
+ if (len(promptToken) != 0):
+ resData["promptToken"] = promptToken
+ if (chat):
+ #only one choice is supported
+ resData["choices"] = [{
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": data["content"],
+ },
+ "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+ }]
+ else:
+ #only one choice is supported
+ resData["choices"] = [{
+ "text": data["content"],
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+ }]
+ return resData
+
+def make_resData_stream(data, chat=False, time_now = 0, start=False):
+ resData = {
+ "id": "chatcmpl" if (chat) else "cmpl",
+ "object": "chat.completion.chunk" if (chat) else "text_completion.chunk",
+ "created": time_now,
+ "model": "LLaMA_CPP",
+ "choices": [
+ {
+ "finish_reason": None,
+ "index": 0
+ }
+ ]
+ }
+ if (chat):
+ if (start):
+ resData["choices"][0]["delta"] = {
+ "role": "assistant"
+ }
+ else:
+ resData["choices"][0]["delta"] = {
+ "content": data["content"]
+ }
+ if (data["stop"]):
+ resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+ else:
+ resData["choices"][0]["text"] = data["content"]
+ if (data["stop"]):
+ resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+
+ return resData
+
+
+@app.route('/chat/completions', methods=['POST'])
+@app.route('/v1/chat/completions', methods=['POST'])
+def chat_completions():
+ if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
+ return Response(status=403)
+ body = request.get_json()
+ stream = False
+ tokenize = False
+ if(is_present(body, "stream")): stream = body["stream"]
+ if(is_present(body, "tokenize")): tokenize = body["tokenize"]
+ postData = make_postData(body, chat=True, stream=stream)
+
+ promptToken = []
+ if (tokenize):
+ tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+ promptToken = tokenData["tokens"]
+
+ if (not stream):
+ data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
+ print(data.json())
+ resData = make_resData(data.json(), chat=True, promptToken=promptToken)
+ return jsonify(resData)
+ else:
+ def generate():
+ data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
+ time_now = int(time.time())
+ resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
+ yield 'data: {}\n'.format(json.dumps(resData))
+ for line in data.iter_lines():
+ if line:
+ decoded_line = line.decode('utf-8')
+ resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
+ yield 'data: {}\n'.format(json.dumps(resData))
+ return Response(generate(), mimetype='text/event-stream')
+
+
+@app.route('/completions', methods=['POST'])
+@app.route('/v1/completions', methods=['POST'])
+def completion():
+ if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
+ return Response(status=403)
+ body = request.get_json()
+ stream = False
+ tokenize = False
+ if(is_present(body, "stream")): stream = body["stream"]
+ if(is_present(body, "tokenize")): tokenize = body["tokenize"]
+ postData = make_postData(body, chat=False, stream=stream)
+
+ promptToken = []
+ if (tokenize):
+ tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+ promptToken = tokenData["tokens"]
+
+ if (not stream):
+ data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
+ print(data.json())
+ resData = make_resData(data.json(), chat=False, promptToken=promptToken)
+ return jsonify(resData)
+ else:
+ def generate():
+ data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
+ time_now = int(time.time())
+ for line in data.iter_lines():
+ if line:
+ decoded_line = line.decode('utf-8')
+ resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
+ yield 'data: {}\n'.format(json.dumps(resData))
+ return Response(generate(), mimetype='text/event-stream')
+
+if __name__ == '__main__':
+ app.run(args.host, port=args.port)