Skip to content

Commit cb98657

Browse files
committed
Improve gpt performance
1 parent f8ad220 commit cb98657

File tree

3 files changed

+32
-10
lines changed

3 files changed

+32
-10
lines changed

mathtranslate/chatgpt.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import openai
66
import sys
77
import time
8+
import re
89

910
class GPTTranslator:
1011
def __init__(self):
@@ -15,23 +16,43 @@ def __init__(self):
1516

1617

1718
def format_prompt(self, text, language_to, language_from):
18-
PROMPT_PROTOTYPE = "As an academic expert with specialized knowledge in various fields, please provide a proficient and precise translation translation from {} to {} of the academic text enclosed in 🔤. It is crucial to maintaining the original phrase or sentence and ensure accuracy while utilizing the appropriate language. Please provide only the translated result without any additional explanation and remove 🔤. The text is as follows: 🔤 {} 🔤 "
19+
PROMPT_PROTOTYPE = 'As an academic expert with specialized knowledge in various fields, please provide a proficient and precise translation translation from {} to {} of the academic text enclosed in 🔤. It is crucial to maintaining the original phrase or sentence and ensure accuracy while utilizing the appropriate language. Please provide only the translated result without any additional explanation and remove 🔤. Do not modify or delete any word contains "/XMATHX_" such as /XMATHX_0, /XMATHX_1, /XMATHX_3_4. The text is as follows: 🔤 {} 🔤 '
1920
#prompt prototype changed from https://github.com/windingwind/zotero-pdf-translate
20-
return PROMPT_PROTOTYPE.format(language_from,language_to,text)
21+
SYSTEM_PROMPT_PROTOTYPE = 'You are an academic translator with specialized knowledge in various fields, please provide a proficient and precise translation translation from {} to {} of the academic text enclosed in 🔤.Do not modify or delete any word contains "/XMATHX_" such as /XMATHX_0, /XMATHX_1, /XMATHX_3_4.'
22+
return {'system':SYSTEM_PROMPT_PROTOTYPE.format(language_from,language_to),'user':PROMPT_PROTOTYPE.format(language_from,language_to,text)}
2123

2224
def get_server_errormsg(self,error):
2325
try :
2426
return error.response.json()['error']['message']
2527
except Exception :
2628
return error.message
2729

30+
def find_all_mathmask(self,text):
31+
mask_pattern=re.compile(r'/XMATHX(_[0-9])+')
32+
masks = set([i.group() for i in re.finditer(pattern=mask_pattern,string=text)])
33+
return masks
34+
35+
def is_gpt_output_valid(self,masks,text_translated):
36+
masks_translated = self.find_all_mathmask(text_translated)
37+
return (masks_translated==masks)
38+
39+
def is_text_all_mask(self,masks,text):
40+
for mask in masks:
41+
text = text.replace(mask,'')
42+
return text.isspace()
43+
44+
2845
def call_openai_api(self,prompt):
2946
messages= [{
47+
"role":"system",
48+
"content": prompt['system']
49+
},
50+
{
3051
"role": "user",
31-
"content": prompt
52+
"content": prompt['user']
3253
}]
3354
try:
34-
return self.client.chat.completions.create(model=self.model,temperature=0.8,messages=messages)
55+
return self.client.chat.completions.create(model=self.model,temperature=1,messages=messages)
3556
except openai.RateLimitError as e:
3657
print('API rate limit exceeded, retry after 15s')
3758
time.sleep(15)
@@ -50,10 +71,13 @@ def call_openai_api(self,prompt):
5071

5172

5273
def translate(self, text, language_to, language_from):
74+
masks = self.find_all_mathmask(text)
75+
if self.is_text_all_mask(masks,text):
76+
return text
5377
while True:
5478
result = self.call_openai_api(self.format_prompt(text, language_to, language_from))
55-
content_translated = result.choices[0].message.content
56-
if '🔤' not in content_translated:
79+
content_translated = result.choices[0].message.content.replace('🔤','')
80+
if self.is_gpt_output_valid(masks,content_translated):
5781
return content_translated
5882

5983

mathtranslate/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Config:
2929
openai_api_endpoint_default = 'https://api.openai.com'
3030
openai_api_key_default = None
3131

32-
math_code = 'XMATHX'
32+
math_code = '/XMATHX' #better for gpt to understand
3333
log_file = f'{app_dir}/translate_log'
3434
raw_mularg_command_list = [('textcolor', 2, (1, ))]
3535
mularg_command_list = [('textcolor', 2, (1, ))]

mathtranslate/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def process_options(options):
116116
config.set_variable(config.openai_api_endpoint_path, config.openai_api_endpoint_default)
117117
print('OpenAI api key (something like sk-xxx...):')
118118
config.set_variable(config.openai_api_key_path, config.openai_api_key_default)
119-
print('ChatGPT model name: (leave empty for default {}'.format(config.openai_model_name_default))
119+
print('ChatGPT model name: (leave empty for default {})'.format(config.openai_model_name_default))
120120
config.set_variable(config.openai_model_name_path,config.openai_model_name_default)
121121
print('saved!')
122122
config.load()
@@ -169,8 +169,6 @@ def process_options(options):
169169
print('Please setup api info for openAI api first by')
170170
print('translate_tex --setgpt')
171171
sys.exit()
172-
options.threads = 1
173-
print('disable mult-threading for chatGPT api')
174172

175173
if options.threads < 0:
176174
print('threads must be a non-zero integer number (>=0 where 0 means auto), set to auto')

0 commit comments

Comments
 (0)