llmengine/llmengine/ltpentity.py

76 lines
2.6 KiB
Python

# Requires ltp>=0.2.0
from ltp import LTP
from typing import List
import logging
from llmengine.base_entity import BaseLtp, ltp_register
logger = logging.getLogger(__name__)
class LtpEntity(BaseLtp):
    """Entity extractor backed by an LTP model (word segmentation, POS tagging, NER)."""

    def __init__(self, model_id: str):
        """Load the LTP model.

        Args:
            model_id: Model name or path handed to ``LTP``. The last
                ``/``-separated component is stored as ``model_name``.
        """
        # Load LTP model for CWS, POS, and NER
        self.ltp = LTP(model_id)
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]

    @staticmethod
    def _merge_noun_runs(words: List[str], pos_list: List[str]) -> List[str]:
        """Return every maximal run of consecutive 'n'-tagged words, joined into one string.

        A lone noun forms a run of length one and is returned unchanged.
        Component words of a multi-word run are not returned on their own.
        """
        merged: List[str] = []
        run: List[str] = []

        def flush() -> None:
            # Emit the current run (if any) as one merged string and start a new run.
            if run:
                combined = "".join(run)
                merged.append(combined)
                logger.debug("合并连续名词: %s, 子词: %s", combined, list(run))
                run.clear()

        for word, pos in zip(words, pos_list):
            if pos == 'n':
                run.append(word)
            else:
                flush()
        flush()  # a run may end at the last word
        return merged

    def extract_entities(self, query: str) -> List[str]:
        """Extract candidate entities from ``query``.

        Entities are collected in this order, then de-duplicated with the
        first occurrence kept:

        - the text of every entity found by LTP NER (all types);
        - nouns (POS 'n'). Consecutive nouns are merged into one string
          (e.g. '苹果 公司' -> '苹果公司'), and the words of a merged run are
          not added separately for that run;
        - verbs (POS 'v').

        Args:
            query: Text to analyse.

        Returns:
            Unique entity strings in first-seen order. Returns an empty list
            when ``query`` is empty or when LTP processing raises; both cases
            are logged.
        """
        if not query:
            # Best-effort API: report the bad input and return nothing rather
            # than raising just to catch it again below.
            logger.error("实体抽取失败: 查询文本不能为空")
            return []
        try:
            result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"])
            # A single query was submitted, so take element 0 of each task output.
            words = result.cws[0]
            pos_list = result.pos[0]
            ner = result.ner[0]
            logger.debug("NER 结果: %s", ner)

            # Each NER item is unpacked as (type, text, start, end); only the text is kept.
            entities = [entity for _entity_type, entity, _start, _end in ner]
            entities.extend(self._merge_noun_runs(words, pos_list))
            entities.extend(word for word, pos in zip(words, pos_list) if pos == 'v')

            # dict preserves insertion order, so this de-duplicates while keeping first-seen order.
            unique_entities = list(dict.fromkeys(entities))
            logger.info("从查询中提取到 %d 个唯一实体: %s", len(unique_entities), unique_entities)
            return unique_entities
        except Exception as e:
            # Deliberate best-effort boundary: keep the traceback in the log and return no entities.
            logger.exception("实体抽取失败: %s", e)
            return []
# Register LtpEntity under the key 'LTP' at import time.
# NOTE(review): presumably ltp_register (llmengine.base_entity) records the class in a
# name -> implementation registry used to pick an LTP backend; confirm in base_entity.
ltp_register('LTP', LtpEntity)