bugfix
This commit is contained in:
parent
8cf533b4a4
commit
b6219805e6
@ -20,6 +20,11 @@ def get_llm_class(model_path):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
class BaseChatLLM:
|
class BaseChatLLM:
|
||||||
|
def use_mps_if_prosible(self):
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = torch.device("mps")
|
||||||
|
self.model = self.model.to(devoce)
|
||||||
|
|
||||||
def get_session_key(self):
|
def get_session_key(self):
|
||||||
return self.model_id + ':messages'
|
return self.model_id + ':messages'
|
||||||
|
|
||||||
|
@ -17,6 +17,9 @@ class Qwen3LLM(T2TChatLLM):
|
|||||||
torch_dtype="auto",
|
torch_dtype="auto",
|
||||||
device_map="auto"
|
device_map="auto"
|
||||||
)
|
)
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = torch.device("mps")
|
||||||
|
self.model = self.model.to(devoce)
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
def build_kwargs(self, inputs, streamer):
|
||||||
|
Loading…
Reference in New Issue
Block a user