bugfix
This commit is contained in:
parent
8cf533b4a4
commit
b6219805e6
@ -20,6 +20,11 @@ def get_llm_class(model_path):
|
||||
return None
|
||||
|
||||
class BaseChatLLM:
|
||||
def use_mps_if_prosible(self):
|
||||
if torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
self.model = self.model.to(devoce)
|
||||
|
||||
def get_session_key(self):
|
||||
return self.model_id + ':messages'
|
||||
|
||||
|
@ -17,6 +17,9 @@ class Qwen3LLM(T2TChatLLM):
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
if torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
self.model = self.model.to(devoce)
|
||||
self.model_id = model_id
|
||||
|
||||
def build_kwargs(self, inputs, streamer):
|
||||
|
Loading…
Reference in New Issue
Block a user