bugfix
This commit is contained in:
parent
221f063d6b
commit
ddedfdd95a
@ -10,11 +10,14 @@ from appPublic.uniqueID import getID
|
|||||||
|
|
||||||
model_pathMap = {
|
model_pathMap = {
|
||||||
}
|
}
|
||||||
def llm_register(model_path, Klass):
|
def llm_register(model_key, Klass):
|
||||||
model_pathMap[model_path] = Klass
|
model_pathMap[model_key] = Klass
|
||||||
|
|
||||||
def get_llm_class(model_path):
|
def get_llm_class(model_path):
|
||||||
return model_pathMap.get(model_path)
|
for k,klass in model_pathMap.items():
|
||||||
|
if len(model_path.split(k)) > 1:
|
||||||
|
return klass
|
||||||
|
return None
|
||||||
|
|
||||||
class BaseChatLLM:
|
class BaseChatLLM:
|
||||||
def get_session_key(self):
|
def get_session_key(self):
|
||||||
@ -51,7 +54,7 @@ class BaseChatLLM:
|
|||||||
yield {
|
yield {
|
||||||
"id":id,
|
"id":id,
|
||||||
"object":"chat.completion.chunk",
|
"object":"chat.completion.chunk",
|
||||||
"created":time.time(),
|
"created":time(),
|
||||||
"model":self.model_id,
|
"model":self.model_id,
|
||||||
"choices":[
|
"choices":[
|
||||||
{
|
{
|
||||||
@ -59,8 +62,8 @@ class BaseChatLLM:
|
|||||||
"delta":{
|
"delta":{
|
||||||
"content":txt
|
"content":txt
|
||||||
},
|
},
|
||||||
"logprobs":null,
|
"logprobs":None,
|
||||||
"finish_reason":null
|
"finish_reason":None
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -71,7 +74,7 @@ class BaseChatLLM:
|
|||||||
yield {
|
yield {
|
||||||
"id":id,
|
"id":id,
|
||||||
"object":"chat.completion.chunk",
|
"object":"chat.completion.chunk",
|
||||||
"created":time.time(),
|
"created":time(),
|
||||||
"model":self.model_id,
|
"model":self.model_id,
|
||||||
"response_time": t2 - t1,
|
"response_time": t2 - t1,
|
||||||
"finish_time": t3 - t1,
|
"finish_time": t3 - t1,
|
||||||
@ -82,7 +85,7 @@ class BaseChatLLM:
|
|||||||
"delta":{
|
"delta":{
|
||||||
"content":""
|
"content":""
|
||||||
},
|
},
|
||||||
"logprobs":null,
|
"logprobs":None,
|
||||||
"finish_reason":"stop"
|
"finish_reason":"stop"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -55,5 +55,5 @@ class DevstralLLM(T2TChatLLM):
|
|||||||
'input_ids': torch.tensor([tokenized.tokens])
|
'input_ids': torch.tensor([tokenized.tokens])
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_register('/share/models/mistralai/Devstral-Small-2505', DevstralLLM)
|
llm_register('mistralai/Devstral', DevstralLLM)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ class Gemma3LLM(MMChatLLM):
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM)
|
llm_register("google/gemma-3", Gemma3LLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
||||||
|
@ -29,7 +29,7 @@ class MedgemmaLLM(MMChatLLM):
|
|||||||
).to(self.model.device, dtype=torch.bfloat16)
|
).to(self.model.device, dtype=torch.bfloat16)
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM)
|
llm_register("google/medgemma", MedgemmaLLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
|
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
|
||||||
|
@ -39,7 +39,7 @@ class Qwen3LLM(T2TChatLLM):
|
|||||||
)
|
)
|
||||||
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
||||||
|
|
||||||
llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM)
|
llm_register("Qwen/Qwen3", Qwen3LLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
|
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
|
||||||
|
Loading…
Reference in New Issue
Block a user