# webrtc_aec_demo.py
"""WebRTC acoustic echo cancellation (AEC) demo script.

Demonstrates the echo-cancellation feature of the WebRTC APM library:
1. Play a given audio file (used as the far-end reference signal)
2. Simultaneously record the microphone input (echo + room sound)
3. Run WebRTC echo cancellation on the captured audio
4. Save both the raw and the processed recordings for comparison

Usage:
    python webrtc_aec_demo.py [audio file path]
Example:
    python webrtc_aec_demo.py 鞠婧祎.wav
"""
  12. import ctypes
  13. import os
  14. import sys
  15. import threading
  16. import time
  17. import wave
  18. from ctypes import POINTER, Structure, byref, c_bool, c_float, c_int, c_short, c_void_p
  19. import numpy as np
  20. import pyaudio
  21. import pygame
  22. import soundfile as sf
  23. from pygame import mixer
# Build the absolute path to the native WebRTC APM DLL, which is expected
# under <project-root>/libs/webrtc_apm/win/x86_64 relative to this script.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
dll_path = os.path.join(
    project_root, "libs", "webrtc_apm", "win", "x86_64", "libwebrtc_apm.dll"
)
# Load the DLL; the demo cannot run without it, so exit immediately on failure.
try:
    apm_lib = ctypes.CDLL(dll_path)
    print(f"成功加载WebRTC APM库: {dll_path}")
except Exception as e:
    print(f"加载WebRTC APM库失败: {e}")
    sys.exit(1)
# Enum-like constant holders and ctypes structures mirroring the native
# library's C enums/structs. Values must match the C definitions exactly.
class DownmixMethod(ctypes.c_int):
    """How multi-channel capture audio is reduced to a single channel."""

    AverageChannels = 0
    UseFirstChannel = 1
class NoiseSuppressionLevel(ctypes.c_int):
    """Aggressiveness of the noise suppressor (higher = more suppression)."""

    Low = 0
    Moderate = 1
    High = 2
    VeryHigh = 3
class GainControllerMode(ctypes.c_int):
    """Operating mode for gain controller 1 (AGC1)."""

    AdaptiveAnalog = 0
    AdaptiveDigital = 1
    FixedDigital = 2
class ClippingPredictorMode(ctypes.c_int):
    """Prediction strategy used by the analog AGC's clipping predictor."""

    ClippingEventPrediction = 0
    AdaptiveStepClippingPeakPrediction = 1
    FixedStepClippingPeakPrediction = 2
# Pipeline struct: top-level processing-pipeline settings.
class Pipeline(Structure):
    """Field order/types must match the native Pipeline struct layout."""

    _fields_ = [
        ("MaximumInternalProcessingRate", c_int),
        ("MultiChannelRender", c_bool),
        ("MultiChannelCapture", c_bool),
        ("CaptureDownmixMethod", c_int),  # holds a DownmixMethod value
    ]
# PreAmplifier struct: fixed gain applied before processing.
class PreAmplifier(Structure):
    _fields_ = [("Enabled", c_bool), ("FixedGainFactor", c_float)]
# AnalogMicGainEmulation struct: emulated analog microphone gain.
class AnalogMicGainEmulation(Structure):
    _fields_ = [("Enabled", c_bool), ("InitialLevel", c_int)]
# CaptureLevelAdjustment struct: pre/post gain around the capture path.
class CaptureLevelAdjustment(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("PreGainFactor", c_float),
        ("PostGainFactor", c_float),
        ("MicGainEmulation", AnalogMicGainEmulation),
    ]
# HighPassFilter struct: removes low-frequency content.
class HighPassFilter(Structure):
    _fields_ = [("Enabled", c_bool), ("ApplyInFullBand", c_bool)]
# EchoCanceller struct: AEC configuration (the core of this demo).
class EchoCanceller(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("MobileMode", c_bool),
        ("ExportLinearAecOutput", c_bool),
        ("EnforceHighPassFiltering", c_bool),
    ]
# NoiseSuppression struct: stationary-noise suppression settings.
class NoiseSuppression(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("NoiseLevel", c_int),  # holds a NoiseSuppressionLevel value
        ("AnalyzeLinearAecOutputWhenAvailable", c_bool),
    ]
# TransientSuppression struct: suppression of short transients (e.g. keyclicks).
class TransientSuppression(Structure):
    _fields_ = [("Enabled", c_bool)]
# ClippingPredictor struct: predicts clipping for the analog gain controller.
class ClippingPredictor(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("PredictorMode", c_int),  # holds a ClippingPredictorMode value
        ("WindowLength", c_int),
        ("ReferenceWindowLength", c_int),
        ("ReferenceWindowDelay", c_int),
        ("ClippingThreshold", c_float),
        ("CrestFactorMargin", c_float),
        ("UsePredictedStep", c_bool),
    ]
# AnalogGainController struct: analog-side AGC settings (nested in AGC1).
class AnalogGainController(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("StartupMinVolume", c_int),
        ("ClippedLevelMin", c_int),
        ("EnableDigitalAdaptive", c_bool),
        ("ClippedLevelStep", c_int),
        ("ClippedRatioThreshold", c_float),
        ("ClippedWaitFrames", c_int),
        ("Predictor", ClippingPredictor),
    ]
# GainController1 struct: legacy AGC (AGC1).
class GainController1(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("ControllerMode", c_int),  # holds a GainControllerMode value
        ("TargetLevelDbfs", c_int),
        ("CompressionGainDb", c_int),
        ("EnableLimiter", c_bool),
        ("AnalogController", AnalogGainController),
    ]
# InputVolumeController struct: input-volume control toggle (nested in AGC2).
class InputVolumeController(Structure):
    _fields_ = [("Enabled", c_bool)]
# AdaptiveDigital struct: adaptive digital gain settings (nested in AGC2).
class AdaptiveDigital(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("HeadroomDb", c_float),
        ("MaxGainDb", c_float),
        ("InitialGainDb", c_float),
        ("MaxGainChangeDbPerSecond", c_float),
        ("MaxOutputNoiseLevelDbfs", c_float),
    ]
# FixedDigital struct: fixed digital gain in dB (nested in AGC2).
class FixedDigital(Structure):
    _fields_ = [("GainDb", c_float)]
# GainController2 struct: modern AGC (AGC2).
class GainController2(Structure):
    _fields_ = [
        ("Enabled", c_bool),
        ("VolumeController", InputVolumeController),
        ("AdaptiveController", AdaptiveDigital),
        ("FixedController", FixedDigital),
    ]
# Complete Config struct passed to WebRTC_APM_ApplyConfig. Field order must
# match the native struct layout exactly; do not reorder.
class Config(Structure):
    _fields_ = [
        ("PipelineConfig", Pipeline),
        ("PreAmp", PreAmplifier),
        ("LevelAdjustment", CaptureLevelAdjustment),
        ("HighPass", HighPassFilter),
        ("Echo", EchoCanceller),
        ("NoiseSuppress", NoiseSuppression),
        ("TransientSuppress", TransientSuppression),
        ("GainControl1", GainController1),
        ("GainControl2", GainController2),
    ]
# Declare argument/return types for the DLL entry points so ctypes marshals
# values correctly (pointer returns would otherwise be truncated to C int).
apm_lib.WebRTC_APM_Create.restype = c_void_p
apm_lib.WebRTC_APM_Create.argtypes = []
apm_lib.WebRTC_APM_Destroy.restype = None
apm_lib.WebRTC_APM_Destroy.argtypes = [c_void_p]
apm_lib.WebRTC_APM_CreateStreamConfig.restype = c_void_p
apm_lib.WebRTC_APM_CreateStreamConfig.argtypes = [c_int, c_int]  # (rate, channels)
apm_lib.WebRTC_APM_DestroyStreamConfig.restype = None
apm_lib.WebRTC_APM_DestroyStreamConfig.argtypes = [c_void_p]
apm_lib.WebRTC_APM_ApplyConfig.restype = c_int
apm_lib.WebRTC_APM_ApplyConfig.argtypes = [c_void_p, POINTER(Config)]
# Far-end (render/reference) path: (apm, src, src_cfg, dest_cfg, dest).
apm_lib.WebRTC_APM_ProcessReverseStream.restype = c_int
apm_lib.WebRTC_APM_ProcessReverseStream.argtypes = [
    c_void_p,
    POINTER(c_short),
    c_void_p,
    c_void_p,
    POINTER(c_short),
]
# Near-end (capture) path: (apm, src, src_cfg, dest_cfg, dest).
apm_lib.WebRTC_APM_ProcessStream.restype = c_int
apm_lib.WebRTC_APM_ProcessStream.argtypes = [
    c_void_p,
    POINTER(c_short),
    c_void_p,
    c_void_p,
    POINTER(c_short),
]
apm_lib.WebRTC_APM_SetStreamDelayMs.restype = None
apm_lib.WebRTC_APM_SetStreamDelayMs.argtypes = [c_void_p, c_int]
  197. def create_apm_config():
  198. """创建WebRTC APM配置 - 优化为保留自然语音,减少错误码-11问题"""
  199. config = Config()
  200. # 设置Pipeline配置 - 使用标准采样率避免重采样问题
  201. config.PipelineConfig.MaximumInternalProcessingRate = 16000 # WebRTC优化频率
  202. config.PipelineConfig.MultiChannelRender = False
  203. config.PipelineConfig.MultiChannelCapture = False
  204. config.PipelineConfig.CaptureDownmixMethod = DownmixMethod.AverageChannels
  205. # 设置PreAmplifier配置 - 减少预放大干扰
  206. config.PreAmp.Enabled = False # 关闭预放大,避免失真
  207. config.PreAmp.FixedGainFactor = 1.0 # 不增益
  208. # 设置LevelAdjustment配置 - 简化电平调整
  209. config.LevelAdjustment.Enabled = False # 禁用电平调整以减少处理冲突
  210. config.LevelAdjustment.PreGainFactor = 1.0
  211. config.LevelAdjustment.PostGainFactor = 1.0
  212. config.LevelAdjustment.MicGainEmulation.Enabled = False
  213. config.LevelAdjustment.MicGainEmulation.InitialLevel = 100 # 降低初始电平避免过饱和
  214. # 设置HighPassFilter配置 - 使用标准高通滤波
  215. config.HighPass.Enabled = True # 启用高通滤波器移除低频噪声
  216. config.HighPass.ApplyInFullBand = True # 在全频段应用,更好的兼容性
  217. # 设置EchoCanceller配置 - 优化回声消除
  218. config.Echo.Enabled = True # 启用回声消除
  219. config.Echo.MobileMode = False # 使用标准模式而非移动模式以获取更好效果
  220. config.Echo.ExportLinearAecOutput = False
  221. config.Echo.EnforceHighPassFiltering = True # 启用强制高通滤波,帮助消除低频回声
  222. # 设置NoiseSuppression配置 - 中等强度噪声抑制
  223. config.NoiseSuppress.Enabled = True
  224. config.NoiseSuppress.NoiseLevel = NoiseSuppressionLevel.Moderate # 中等级别抑制
  225. config.NoiseSuppress.AnalyzeLinearAecOutputWhenAvailable = True
  226. # 设置TransientSuppression配置
  227. config.TransientSuppress.Enabled = False # 关闭瞬态抑制,避免切割语音
  228. # 设置GainController1配置 - 轻度增益控制
  229. config.GainControl1.Enabled = True # 启用增益控制
  230. config.GainControl1.ControllerMode = GainControllerMode.AdaptiveDigital
  231. config.GainControl1.TargetLevelDbfs = 3 # 降低目标电平(更积极的控制)
  232. config.GainControl1.CompressionGainDb = 9 # 适中的压缩增益
  233. config.GainControl1.EnableLimiter = True # 启用限制器
  234. # AnalogGainController
  235. config.GainControl1.AnalogController.Enabled = False # 关闭模拟增益控制
  236. config.GainControl1.AnalogController.StartupMinVolume = 0
  237. config.GainControl1.AnalogController.ClippedLevelMin = 70
  238. config.GainControl1.AnalogController.EnableDigitalAdaptive = False
  239. config.GainControl1.AnalogController.ClippedLevelStep = 15
  240. config.GainControl1.AnalogController.ClippedRatioThreshold = 0.1
  241. config.GainControl1.AnalogController.ClippedWaitFrames = 300
  242. # ClippingPredictor
  243. predictor = config.GainControl1.AnalogController.Predictor
  244. predictor.Enabled = False
  245. predictor.PredictorMode = ClippingPredictorMode.ClippingEventPrediction
  246. predictor.WindowLength = 5
  247. predictor.ReferenceWindowLength = 5
  248. predictor.ReferenceWindowDelay = 5
  249. predictor.ClippingThreshold = -1.0
  250. predictor.CrestFactorMargin = 3.0
  251. predictor.UsePredictedStep = True
  252. # 设置GainController2配置 - 禁用以避免冲突
  253. config.GainControl2.Enabled = False
  254. config.GainControl2.VolumeController.Enabled = False
  255. config.GainControl2.AdaptiveController.Enabled = False
  256. config.GainControl2.AdaptiveController.HeadroomDb = 5.0
  257. config.GainControl2.AdaptiveController.MaxGainDb = 30.0
  258. config.GainControl2.AdaptiveController.InitialGainDb = 15.0
  259. config.GainControl2.AdaptiveController.MaxGainChangeDbPerSecond = 6.0
  260. config.GainControl2.AdaptiveController.MaxOutputNoiseLevelDbfs = -50.0
  261. config.GainControl2.FixedController.GainDb = 0.0
  262. return config
# Reference-audio buffer (stores what the speakers are playing), shared
# between the loopback-recording thread and any consumer; guard with the lock.
reference_buffer = []
reference_lock = threading.Lock()
def record_playback_audio(chunk_size, sample_rate, channels):
    """Record the speaker (loopback) output as a more accurate reference signal.

    Appends raw PCM chunks to the module-level ``reference_buffer``,
    keeping only roughly the most recent 100 chunks (~2 s). Runs until
    the stream errors out; all failures are reported and swallowed.
    """
    global reference_buffer
    # NOTE: this is the ideal implementation, but on Windows PyAudio usually
    # cannot record the speaker output directly; a real application needs a
    # different mechanism to capture system audio.
    try:
        p = pyaudio.PyAudio()
        # Try to open a capture stream on the default device (only some
        # systems support output loopback this way; shown here as an example).
        loopback_stream = p.open(
            format=pyaudio.paInt16,
            channels=channels,
            rate=sample_rate,
            input=True,
            frames_per_buffer=chunk_size,
            input_device_index=None,  # try the default device as the source
        )
        # Record until reading fails (stream closed / device error).
        while True:
            try:
                data = loopback_stream.read(chunk_size, exception_on_overflow=False)
                with reference_lock:
                    reference_buffer.append(data)
            except OSError:
                break
            # Keep the buffer bounded (about 2 seconds of audio).
            with reference_lock:
                if len(reference_buffer) > 100:
                    reference_buffer = reference_buffer[-100:]
    except Exception as e:
        print(f"无法录制系统音频: {e}")
    finally:
        # Best-effort cleanup: the stream/PyAudio instance may never have
        # been created if p.open() raised, hence the locals() checks.
        try:
            if "loopback_stream" in locals() and loopback_stream:
                loopback_stream.stop_stream()
                loopback_stream.close()
            if "p" in locals() and p:
                p.terminate()
        except Exception:
            pass
def aec_demo(audio_file):
    """Run the WebRTC echo-cancellation demo.

    Plays *audio_file* through the speakers while recording the
    microphone, feeds the played frames to the APM as the far-end
    reference and the mic frames as the near-end capture, and writes
    three WAV files next to this script: the raw recording, the
    AEC-processed recording, and the reference playback.
    """
    # Bail out early if the audio file is missing.
    if not os.path.exists(audio_file):
        print(f"错误: 找不到音频文件 {audio_file}")
        return
    # Audio parameters tuned for WebRTC: 16 kHz mono, 10 ms (160-sample)
    # frames, 16-bit PCM — the frame size the APM expects.
    SAMPLE_RATE = 16000
    CHANNELS = 1
    CHUNK = 160
    FORMAT = pyaudio.paInt16
    # Initialise PyAudio and list the available devices for reference.
    p = pyaudio.PyAudio()
    print("\n可用音频设备:")
    for i in range(p.get_device_count()):
        dev_info = p.get_device_info_by_index(i)
        print(f"设备 {i}: {dev_info['name']}")
        print(f" - 输入通道: {dev_info['maxInputChannels']}")
        print(f" - 输出通道: {dev_info['maxOutputChannels']}")
        print(f" - 默认采样率: {dev_info['defaultSampleRate']}")
    print("")
    # Open the microphone input stream.
    input_stream = p.open(
        format=FORMAT,
        channels=CHANNELS,
        rate=SAMPLE_RATE,
        input=True,
        frames_per_buffer=CHUNK,
    )
    # Initialise pygame's mixer for playback at the same rate/format.
    pygame.init()
    mixer.init(frequency=SAMPLE_RATE, size=-16, channels=CHANNELS, buffer=CHUNK * 4)
    print(f"加载音频文件: {audio_file}")
    # Load the reference audio via soundfile (supports many formats) and
    # convert it to 16 kHz mono int16 to match the APM stream config.
    try:
        print("加载参考音频...")
        ref_audio_data, orig_sr = sf.read(audio_file, dtype="int16")
        print(
            f"原始音频: 采样率={orig_sr}, 通道数="
            f"{ref_audio_data.shape[1] if len(ref_audio_data.shape) > 1 else 1}"
        )
        # Downmix stereo to mono by averaging channels.
        if len(ref_audio_data.shape) > 1 and ref_audio_data.shape[1] > 1:
            ref_audio_data = ref_audio_data.mean(axis=1).astype(np.int16)
        # Resample to 16 kHz if needed (scipy imported lazily: only required
        # when the source rate differs).
        if orig_sr != SAMPLE_RATE:
            print(f"重采样参考音频从{orig_sr}Hz到{SAMPLE_RATE}Hz...")
            from scipy import signal
            ref_audio_data = signal.resample(
                ref_audio_data, int(len(ref_audio_data) * SAMPLE_RATE / orig_sr)
            ).astype(np.int16)
        # Write a temporary WAV for pygame to play.
        temp_wav_path = os.path.join(current_dir, "temp_reference.wav")
        with wave.open(temp_wav_path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 2 bytes = 16-bit samples
            wf.setframerate(SAMPLE_RATE)
            wf.writeframes(ref_audio_data.tobytes())
        # Split the reference audio into CHUNK-sized frames, zero-padding
        # the final partial frame.
        ref_audio_frames = []
        for i in range(0, len(ref_audio_data), CHUNK):
            if i + CHUNK <= len(ref_audio_data):
                ref_audio_frames.append(ref_audio_data[i : i + CHUNK])
            else:
                last_frame = np.zeros(CHUNK, dtype=np.int16)
                last_frame[: len(ref_audio_data) - i] = ref_audio_data[i:]
                ref_audio_frames.append(last_frame)
        print(f"参考音频准备完成,共{len(ref_audio_frames)}帧")
        # Queue the temporary WAV for playback.
        mixer.music.load(temp_wav_path)
    except Exception as e:
        print(f"加载参考音频时出错: {e}")
        sys.exit(1)
    # Create the APM instance and apply the demo configuration.
    apm = apm_lib.WebRTC_APM_Create()
    config = create_apm_config()
    result = apm_lib.WebRTC_APM_ApplyConfig(apm, byref(config))
    if result != 0:
        print(f"警告: APM配置应用失败,错误码: {result}")
    # One stream config is shared by both the capture and reverse paths.
    stream_config = apm_lib.WebRTC_APM_CreateStreamConfig(SAMPLE_RATE, CHANNELS)
    # Small delay estimate so reference and mic signals line up better.
    apm_lib.WebRTC_APM_SetStreamDelayMs(apm, 50)
    # Buffers for the three output recordings.
    original_frames = []
    processed_frames = []
    reference_frames = []
    # Give the audio system a moment to settle before starting.
    time.sleep(0.5)
    print("开始录制和处理...")
    print("播放参考音频...")
    mixer.music.play()
    # Recording duration: length of the reference audio (fallback 10 s),
    # plus one extra second to capture the tail.
    try:
        sound_length = mixer.Sound(temp_wav_path).get_length()
        recording_time = sound_length if sound_length > 0 else 10
    except Exception:
        recording_time = 10
    recording_time += 1
    start_time = time.time()
    current_ref_frame_index = 0
    try:
        while time.time() - start_time < recording_time:
            # Read one 10 ms frame from the microphone and keep the raw copy.
            input_data = input_stream.read(CHUNK, exception_on_overflow=False)
            original_frames.append(input_data)
            # View the capture bytes as int16 and get a C pointer for the DLL.
            input_array = np.frombuffer(input_data, dtype=np.int16)
            input_ptr = input_array.ctypes.data_as(POINTER(c_short))
            # Pick the matching reference frame; silence once playback ends.
            if current_ref_frame_index < len(ref_audio_frames):
                ref_array = ref_audio_frames[current_ref_frame_index]
                reference_frames.append(ref_array.tobytes())
                current_ref_frame_index += 1
            else:
                ref_array = np.zeros(CHUNK, dtype=np.int16)
                reference_frames.append(ref_array.tobytes())
            ref_ptr = ref_array.ctypes.data_as(POINTER(c_short))
            # Output buffer for the processed capture frame.
            output_array = np.zeros(CHUNK, dtype=np.int16)
            output_ptr = output_array.ctypes.data_as(POINTER(c_short))
            # The reverse (far-end) call needs an output buffer too, even
            # though we discard it.
            ref_output_array = np.zeros(CHUNK, dtype=np.int16)
            ref_output_ptr = ref_output_array.ctypes.data_as(POINTER(c_short))
            # Important ordering: feed the reference (speaker) frame first…
            result_reverse = apm_lib.WebRTC_APM_ProcessReverseStream(
                apm, ref_ptr, stream_config, stream_config, ref_output_ptr
            )
            if result_reverse != 0:
                print(f"\r警告: 参考信号处理失败,错误码: {result_reverse}")
            # …then process the mic frame so the AEC can subtract the echo.
            result = apm_lib.WebRTC_APM_ProcessStream(
                apm, input_ptr, stream_config, stream_config, output_ptr
            )
            if result != 0:
                print(f"\r警告: 处理失败,错误码: {result}")
            # Keep the processed frame and show progress on one line.
            processed_frames.append(output_array.tobytes())
            progress = (time.time() - start_time) / recording_time * 100
            sys.stdout.write(f"\r处理进度: {progress:.1f}%")
            sys.stdout.flush()
    except KeyboardInterrupt:
        print("\n录制被用户中断")
    finally:
        # Always stop playback, close the mic stream and release native
        # resources, even if the loop was interrupted.
        print("\n录制和处理完成")
        mixer.music.stop()
        input_stream.stop_stream()
        input_stream.close()
        apm_lib.WebRTC_APM_DestroyStreamConfig(stream_config)
        apm_lib.WebRTC_APM_Destroy(apm)
        p.terminate()
    # Save the three recordings next to this script.
    original_output_path = os.path.join(current_dir, "original_recording.wav")
    save_wav(original_output_path, original_frames, SAMPLE_RATE, CHANNELS)
    processed_output_path = os.path.join(current_dir, "processed_recording.wav")
    save_wav(processed_output_path, processed_frames, SAMPLE_RATE, CHANNELS)
    reference_output_path = os.path.join(current_dir, "reference_playback.wav")
    save_wav(reference_output_path, reference_frames, SAMPLE_RATE, CHANNELS)
    # Remove the temporary playback file (best effort).
    if os.path.exists(temp_wav_path):
        try:
            os.remove(temp_wav_path)
        except Exception:
            pass
    print(f"原始录音已保存至: {original_output_path}")
    print(f"处理后的录音已保存至: {processed_output_path}")
    print(f"参考音频已保存至: {reference_output_path}")
    # Shut pygame down cleanly.
    pygame.quit()
  495. def save_wav(file_path, frames, sample_rate, channels):
  496. """
  497. 将音频帧保存为WAV文件.
  498. """
  499. with wave.open(file_path, "wb") as wf:
  500. wf.setnchannels(channels)
  501. wf.setsampwidth(2) # 2字节(16位)
  502. wf.setframerate(sample_rate)
  503. if isinstance(frames[0], bytes):
  504. wf.writeframes(b"".join(frames))
  505. else:
  506. wf.writeframes(b"".join([f for f in frames if isinstance(f, bytes)]))
if __name__ == "__main__":
    # Use the audio file given on the command line, if any.
    if len(sys.argv) > 1:
        audio_file = sys.argv[1]
    else:
        # Default to the sample WAV bundled next to this script.
        audio_file = os.path.join(current_dir, "鞠婧祎.wav")
        # If the WAV is missing, fall back to the MP3 version.
        if not os.path.exists(audio_file):
            audio_file = os.path.join(current_dir, "鞠婧祎.mp3")
            if not os.path.exists(audio_file):
                print("错误: 找不到默认音频文件,请指定要播放的音频文件路径")
                print("用法: python webrtc_aec_demo.py [音频文件路径]")
                sys.exit(1)
    # Run the demo.
    aec_demo(audio_file)