diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py index ec659c6..1559b86 100644 --- a/llmengine/base_chat_llm.py +++ b/llmengine/base_chat_llm.py @@ -20,6 +20,11 @@ def get_llm_class(model_path): return None class BaseChatLLM: + def use_mps_if_prosible(self): + if torch.backends.mps.is_available(): + device = torch.device("mps") + self.model = self.model.to(devoce) + def get_session_key(self): return self.model_id + ':messages' diff --git a/llmengine/qwen3.py b/llmengine/qwen3.py index c536d4c..7e7c24a 100644 --- a/llmengine/qwen3.py +++ b/llmengine/qwen3.py @@ -17,6 +17,9 @@ class Qwen3LLM(T2TChatLLM): torch_dtype="auto", device_map="auto" ) + if torch.backends.mps.is_available(): + device = torch.device("mps") + self.model = self.model.to(devoce) self.model_id = model_id def build_kwargs(self, inputs, streamer):