This commit is contained in:
yumoqing 2025-06-19 13:40:07 +08:00
parent 8cf533b4a4
commit b6219805e6
2 changed files with 8 additions and 0 deletions

View File

@ -20,6 +20,11 @@ def get_llm_class(model_path):
return None
class BaseChatLLM:
def use_mps_if_prosible(self):
if torch.backends.mps.is_available():
device = torch.device("mps")
self.model = self.model.to(devoce)
def get_session_key(self):
return self.model_id + ':messages'

View File

@ -17,6 +17,9 @@ class Qwen3LLM(T2TChatLLM):
torch_dtype="auto",
device_map="auto"
)
if torch.backends.mps.is_available():
device = torch.device("mps")
self.model = self.model.to(devoce)
self.model_id = model_id
def build_kwargs(self, inputs, streamer):