diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py index b0928c4..8786555 100644 --- a/llmengine/base_chat_llm.py +++ b/llmengine/base_chat_llm.py @@ -10,11 +10,14 @@ from appPublic.uniqueID import getID model_pathMap = { } -def llm_register(model_path, Klass): - model_pathMap[model_path] = Klass +def llm_register(model_key, Klass): + model_pathMap[model_key] = Klass def get_llm_class(model_path): - return model_pathMap.get(model_path) + for k,klass in model_pathMap.items(): + if len(model_path.split(k)) > 1: + return klass + return None class BaseChatLLM: def get_session_key(self): @@ -51,7 +54,7 @@ class BaseChatLLM: yield { "id":id, "object":"chat.completion.chunk", - "created":time.time(), + "created":time(), "model":self.model_id, "choices":[ { @@ -59,8 +62,8 @@ class BaseChatLLM: "delta":{ "content":txt }, - "logprobs":null, - "finish_reason":null + "logprobs":None, + "finish_reason":None } ] } @@ -71,7 +74,7 @@ class BaseChatLLM: yield { "id":id, "object":"chat.completion.chunk", - "created":time.time(), + "created":time(), "model":self.model_id, "response_time": t2 - t1, "finish_time": t3 - t1, @@ -82,7 +85,7 @@ class BaseChatLLM: "delta":{ "content":"" }, - "logprobs":null, + "logprobs":None, "finish_reason":"stop" } ] diff --git a/llmengine/devstral.py b/llmengine/devstral.py index e3caba2..6d71931 100644 --- a/llmengine/devstral.py +++ b/llmengine/devstral.py @@ -55,5 +55,5 @@ class DevstralLLM(T2TChatLLM): 'input_ids': torch.tensor([tokenized.tokens]) } -llm_register('/share/models/mistralai/Devstral-Small-2505', DevstralLLM) +llm_register('mistralai/Devstral', DevstralLLM) diff --git a/llmengine/gemma3_it.py b/llmengine/gemma3_it.py index 635ca98..1028b29 100644 --- a/llmengine/gemma3_it.py +++ b/llmengine/gemma3_it.py @@ -21,7 +21,7 @@ class Gemma3LLM(MMChatLLM): self.messages = [] self.model_id = model_id -llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM) +llm_register("google/gemma-3", Gemma3LLM) if __name__ == '__main__': gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it') diff --git a/llmengine/medgemma3_it.py b/llmengine/medgemma3_it.py index 56c22f2..db3d73a 100644 --- a/llmengine/medgemma3_it.py +++ b/llmengine/medgemma3_it.py @@ -29,7 +29,7 @@ class MedgemmaLLM(MMChatLLM): ).to(self.model.device, dtype=torch.bfloat16) return inputs -llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM) +llm_register("google/medgemma", MedgemmaLLM) if __name__ == '__main__': med = MedgemmaLLM('/share/models/google/medgemma-4b-it') diff --git a/llmengine/qwen3.py b/llmengine/qwen3.py index 72ffe7a..c536d4c 100644 --- a/llmengine/qwen3.py +++ b/llmengine/qwen3.py @@ -39,7 +39,7 @@ class Qwen3LLM(T2TChatLLM): ) return self.tokenizer([text], return_tensors="pt").to(self.model.device) -llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM) +llm_register("Qwen/Qwen3", Qwen3LLM) if __name__ == '__main__': q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')