commit ddedfdd95a (parent 221f063d6b)
yumoqing, 2025-06-14 20:54:27 +08:00
5 changed files with 15 additions and 12 deletions


@@ -10,11 +10,14 @@ from appPublic.uniqueID import getID
 model_pathMap = {
 }
 
-def llm_register(model_path, Klass):
-    model_pathMap[model_path] = Klass
+def llm_register(model_key, Klass):
+    model_pathMap[model_key] = Klass
 
 def get_llm_class(model_path):
-    return model_pathMap.get(model_path)
+    for k,klass in model_pathMap.items():
+        if len(model_path.split(k)) > 1:
+            return klass
+    return None
 
 class BaseChatLLM:
     def get_session_key(self):
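
The rewritten get_llm_class switches from exact-key lookup to substring matching: model_path.split(k) has more than one element exactly when k occurs somewhere in model_path, so the test is equivalent to "k in model_path". A minimal sketch of the resulting behavior, using the Qwen3 key registered later in this commit (the asserted paths are illustrative):

    model_pathMap = {}

    def llm_register(model_key, Klass):
        model_pathMap[model_key] = Klass

    def get_llm_class(model_path):
        # first registered key that occurs inside the path wins
        for k, klass in model_pathMap.items():
            if len(model_path.split(k)) > 1:  # equivalent to: k in model_path
                return klass
        return None

    class Qwen3LLM: pass

    llm_register('Qwen/Qwen3', Qwen3LLM)
    assert get_llm_class('/share/models/Qwen/Qwen3-32B') is Qwen3LLM
    assert get_llm_class('/share/models/something-else') is None

Since the first match in insertion order wins, a short key registered early can shadow a longer, more specific one registered later.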
@@ -51,7 +54,7 @@ class BaseChatLLM:
         yield {
             "id":id,
             "object":"chat.completion.chunk",
-            "created":time.time(),
+            "created":time(),
             "model":self.model_id,
             "choices":[
                 {
@@ -59,8 +62,8 @@ class BaseChatLLM:
                     "delta":{
                         "content":txt
                     },
-                    "logprobs":null,
-                    "finish_reason":null
+                    "logprobs":None,
+                    "finish_reason":None
                 }
             ]
         }
@@ -71,7 +74,7 @@ class BaseChatLLM:
         yield {
             "id":id,
             "object":"chat.completion.chunk",
-            "created":time.time(),
+            "created":time(),
             "model":self.model_id,
             "response_time": t2 - t1,
             "finish_time": t3 - t1,
@@ -82,7 +85,7 @@ class BaseChatLLM:
                     "delta":{
                         "content":""
                     },
-                    "logprobs":null,
+                    "logprobs":None,
                     "finish_reason":"stop"
                 }
             ]
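
The four hunks above make two fixes inside the streamed chunks: the JSON literal null, which raises NameError in Python, becomes None, and time.time() becomes time(), which resolves only if the module imports the function directly (from time import time; no such import change appears in this diff, so that is an assumption about the surrounding file). A self-contained sketch of one post-fix chunk; the helper name and parameters are illustrative, not from the repo:

    from time import time

    def make_chunk(chunk_id, model_id, txt):
        # mirrors the chunk shape in the hunks above
        return {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": time(),  # was time.time()
            "model": model_id,
            "choices": [
                {
                    "delta": {"content": txt},
                    "logprobs": None,       # was null
                    "finish_reason": None,  # "stop" in the final chunk
                }
            ],
        }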


@@ -55,5 +55,5 @@ class DevstralLLM(T2TChatLLM):
             'input_ids': torch.tensor([tokenized.tokens])
         }
 
-llm_register('/share/models/mistralai/Devstral-Small-2505', DevstralLLM)
+llm_register('mistralai/Devstral', DevstralLLM)
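
With the shortened key, any model path containing 'mistralai/Devstral' now resolves to this class, including the old absolute path. A quick check, assuming llm_register and get_llm_class from the first file (DevstralLLM here is a stand-in for the real class):

    class DevstralLLM: pass

    llm_register('mistralai/Devstral', DevstralLLM)
    assert get_llm_class('/share/models/mistralai/Devstral-Small-2505') is DevstralLLM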


@@ -21,7 +21,7 @@ class Gemma3LLM(MMChatLLM):
         self.messages = []
         self.model_id = model_id
 
-llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM)
+llm_register("google/gemma-3", Gemma3LLM)
 
 if __name__ == '__main__':
     gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')


@@ -29,7 +29,7 @@ class MedgemmaLLM(MMChatLLM):
         ).to(self.model.device, dtype=torch.bfloat16)
         return inputs
 
-llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM)
+llm_register("google/medgemma", MedgemmaLLM)
 
 if __name__ == '__main__':
     med = MedgemmaLLM('/share/models/google/medgemma-4b-it')


@@ -39,7 +39,7 @@ class Qwen3LLM(T2TChatLLM):
         )
         return self.tokenizer([text], return_tensors="pt").to(self.model.device)
 
-llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM)
+llm_register("Qwen/Qwen3", Qwen3LLM)
 
 if __name__ == '__main__':
     q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')