Skip to content

Commit 4956aa6

Browse files
authored
Merge pull request xinnan-tech#1399 from xinnan-tech/hot-fix
update:增强TTS流会话关闭和开启逻辑
2 parents 87a9425 + fd17465 commit 4956aa6

File tree

1 file changed

+62
-17
lines changed

1 file changed

+62
-17
lines changed

main/xiaozhi-server/core/providers/tts/huoshan_double_stream.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,12 @@ def __init__(self, config, delete_audio_file):
161161
)
162162
check_model_key("TTS", self.access_token)
163163

164+
# 添加会话状态控制
165+
self._session_lock = asyncio.Lock() # 会话操作的并发锁
166+
self._current_session_id = None # 当前会话ID
167+
self._session_started = False # 会话是否已开始
168+
self._session_finished = False # 会话是否已结束
169+
164170
###################################################################################
165171
# 火山双流式TTS重写父类的方法--开始
166172
###################################################################################
@@ -444,27 +450,66 @@ async def finish_connection(self):
444450
return
445451

446452
async def start_session(self, session_id):
447-
header = Header(
448-
message_type=FULL_CLIENT_REQUEST,
449-
message_type_specific_flags=MsgTypeFlagWithEvent,
450-
serial_method=JSON,
451-
).as_bytes()
452-
optional = Optional(event=EVENT_StartSession, sessionId=session_id).as_bytes()
453-
payload = self.get_payload_bytes(event=EVENT_StartSession, speaker=self.speaker)
454-
await self.send_event(header, optional, payload)
455453
logger.bind(tag=TAG).debug(f"开始会话~~{session_id}")
454+
async with self._session_lock:
455+
# 如果已有会话未结束,先关闭它
456+
if self._session_started and not self._session_finished:
457+
logger.bind(tag=TAG).warning(
458+
f"发现未关闭的会话 {self._current_session_id},正在关闭..."
459+
)
460+
await self.finish_session(self._current_session_id)
461+
462+
# 重置会话状态
463+
self._current_session_id = session_id
464+
self._session_started = True
465+
self._session_finished = False
466+
467+
header = Header(
468+
message_type=FULL_CLIENT_REQUEST,
469+
message_type_specific_flags=MsgTypeFlagWithEvent,
470+
serial_method=JSON,
471+
).as_bytes()
472+
optional = Optional(
473+
event=EVENT_StartSession, sessionId=session_id
474+
).as_bytes()
475+
payload = self.get_payload_bytes(
476+
event=EVENT_StartSession, speaker=self.speaker
477+
)
478+
await self.send_event(header, optional, payload)
456479

457480
async def finish_session(self, session_id):
458481
logger.bind(tag=TAG).debug(f"关闭会话~~{session_id}")
459-
header = Header(
460-
message_type=FULL_CLIENT_REQUEST,
461-
message_type_specific_flags=MsgTypeFlagWithEvent,
462-
serial_method=JSON,
463-
).as_bytes()
464-
optional = Optional(event=EVENT_FinishSession, sessionId=session_id).as_bytes()
465-
payload = str.encode("{}")
466-
await self.send_event(header, optional, payload)
467-
return
482+
async with self._session_lock:
483+
# 检查会话状态
484+
if not self._session_started:
485+
logger.bind(tag=TAG).warning(f"尝试关闭未开始的会话 {session_id}")
486+
return
487+
488+
if self._session_finished:
489+
logger.bind(tag=TAG).warning(f"会话 {session_id} 已经关闭")
490+
return
491+
492+
if self._current_session_id != session_id:
493+
logger.bind(tag=TAG).warning(
494+
f"尝试关闭错误的会话 {session_id},当前会话为 {self._current_session_id}"
495+
)
496+
return
497+
498+
header = Header(
499+
message_type=FULL_CLIENT_REQUEST,
500+
message_type_specific_flags=MsgTypeFlagWithEvent,
501+
serial_method=JSON,
502+
).as_bytes()
503+
optional = Optional(
504+
event=EVENT_FinishSession, sessionId=session_id
505+
).as_bytes()
506+
payload = str.encode("{}")
507+
await self.send_event(header, optional, payload)
508+
509+
# 更新会话状态
510+
self._session_finished = True
511+
self._session_started = False
512+
self._current_session_id = None
468513

469514
async def reset(self):
470515
# 关闭之前的对话

0 commit comments

Comments
 (0)