Skip to content

Commit

Permalink
Export gtcrn models to sherpa-onnx (#1975)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Mar 10, 2025
1 parent 362ddf2 commit 6e261ed
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 0 deletions.
103 changes: 103 additions & 0 deletions .github/workflows/export-gtcrn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
name: export-gtcrn-to-onnx

on:
push:
branches:
- export-gtcrn

workflow_dispatch:

concurrency:
group: export-gtcrn-to-onnx-${{ github.ref }}
cancel-in-progress: true

jobs:
export-gtcrn-to-onnx:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: export gtcrn ${{ matrix.version }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]

steps:
- uses: actions/checkout@v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install Python dependencies
shell: bash
run: |
pip install "numpy<=1.26.4" onnx==1.16.0 onnxruntime==1.17.1 librosa soundfile torch==2.6.0+cpu -f https://download.pytorch.org/whl/torch "kaldi-native-fbank>=1.21.1"
- name: Run
shell: bash
run: |
cd scripts/gtcrn
./run.sh
./test.py
ls -lh
- name: Collect results
shell: bash
run: |
src=scripts/gtcrn
cp -v $src/*.onnx ./
ls -lh *.onnx
- name: Publish to huggingface 0.19
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "[email protected]"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
git clone https://csukuangfj:[email protected]/csukuangfj/speech-enhancement-models huggingface
cd huggingface
git fetch
git pull
cp -v ../gtcrn_simple.onnx ./
git lfs track "*.onnx"
git add .
ls -lh
git status
git commit -m "add models"
git push https://csukuangfj:[email protected]/csukuangfj/speech-enhancement-models main || true
- name: Release
if: github.repository_owner == 'csukuangfj'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.onnx
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: speech-enhancement-models

- name: Release
if: github.repository_owner == 'k2-fsa'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.onnx
overwrite: true
tag: speech-enhancement-models
4 changes: 4 additions & 0 deletions scripts/gtcrn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Introduction

This folder contains scripts for adding metadata to models from
https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx
72 changes: 72 additions & 0 deletions scripts/gtcrn/add_meta_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

"""
NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
-----
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
"""

import onnx
import onnxruntime as ort


def show(filename):
session_opts = ort.SessionOptions()
session_opts.log_severity_level = 3
sess = ort.InferenceSession(filename, session_opts)
for i in sess.get_inputs():
print(i)

print("-----")

for i in sess.get_outputs():
print(i)


def main():
filename = "./gtcrn_simple.onnx"
show(filename)
model = onnx.load(filename)

meta_data = {
"model_type": "gtcrn",
"comment": "gtcrn_simple",
"version": 1,
"sample_rate": 16000,
"model_url": "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx",
"maintainer": "k2-fsa",
"comment2": "Please see also https://github.com/Xiaobin-Rong/gtcrn",
"conv_cache_shape": "2,1,16,16,33",
"tra_cache_shape": "2,3,1,1,16",
"inter_cache_shape": "2,1,33,16",
"n_fft": 512,
"hop_length": 256,
"window_length": 512,
"window_type": "hann_sqrt",
}

print(model.metadata_props)

while len(model.metadata_props):
model.metadata_props.pop()

for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
print("--------------------")

print(model.metadata_props)

onnx.save(model, filename)


if __name__ == "__main__":
main()
12 changes: 12 additions & 0 deletions scripts/gtcrn/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env bash
#

if [ ! -f gtcrn_simple.onnx ]; then
wget https://github.com/Xiaobin-Rong/gtcrn/raw/refs/heads/main/stream/onnx_models/gtcrn_simple.onnx
fi

if [ ! -f ./inp_16k.wav ]; then
wget https://github.com/yuyun2000/SpeechDenoiser/raw/refs/heads/main/16k/inp_16k.wav
fi

python3 ./add_meta_data.py
136 changes: 136 additions & 0 deletions scripts/gtcrn/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

from typing import Tuple

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
data, sample_rate = sf.read(
filename,
always_2d=True,
dtype="float32",
)
data = data[:, 0] # use only the first channel
samples = np.ascontiguousarray(data)
return samples, sample_rate


class OnnxModel:
def __init__(self):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1

self.session_opts = session_opts
self.model = ort.InferenceSession(
"./gtcrn_simple.onnx",
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

meta = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(meta["sample_rate"])
self.n_fft = int(meta["n_fft"])
self.hop_length = int(meta["hop_length"])
self.window_length = int(meta["window_length"])
assert meta["window_type"] == "hann_sqrt", meta["window_type"]

self.window = torch.hann_window(self.window_length).pow(0.5)

def get_init_states(self):
meta = self.model.get_modelmeta().custom_metadata_map
conv_cache_shape = list(map(int, meta["conv_cache_shape"].split(",")))
tra_cache_shape = list(map(int, meta["tra_cache_shape"].split(",")))
inter_cache_shape = list(map(int, meta["inter_cache_shape"].split(",")))

conv_cache_shape = np.zeros(conv_cache_shape, dtype=np.float32)
tra_cache = np.zeros(tra_cache_shape, dtype=np.float32)
inter_cache = np.zeros(inter_cache_shape, dtype=np.float32)

return conv_cache_shape, tra_cache, inter_cache

def __call__(self, x, states):
"""
Args:
x: (1, n_fft/2+1, 1, 2)
Returns:
o: (1, n_fft/2+1, 1, 2)
"""
out, next_conv_cache, next_tra_cache, next_inter_cache = self.model.run(
[
self.model.get_outputs()[0].name,
self.model.get_outputs()[1].name,
self.model.get_outputs()[2].name,
self.model.get_outputs()[3].name,
],
{
self.model.get_inputs()[0].name: x,
self.model.get_inputs()[1].name: states[0],
self.model.get_inputs()[2].name: states[1],
self.model.get_inputs()[3].name: states[2],
},
)

return out, (next_conv_cache, next_tra_cache, next_inter_cache)


def main():
model = OnnxModel()

filename = "./inp_16k.wav"
wave, sample_rate = load_audio(filename)
if sample_rate != model.sample_rate:
import librosa

wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=model.sample_rate)
sample_rate = model.sample_rate

stft_config = knf.StftConfig(
n_fft=model.n_fft,
hop_length=model.hop_length,
win_length=model.window_length,
window=model.window.tolist(),
)
stft = knf.Stft(stft_config)
stft_result = stft(wave)
num_frames = stft_result.num_frames
real = np.array(stft_result.real, dtype=np.float32).reshape(num_frames, -1)
imag = np.array(stft_result.imag, dtype=np.float32).reshape(num_frames, -1)

states = model.get_init_states()
outputs = []
for i in range(num_frames):
x_real = real[i : i + 1]
x_imag = imag[i : i + 1]
x = np.vstack([x_real, x_imag]).transpose()
x = np.expand_dims(x, axis=0)
x = np.expand_dims(x, axis=2)

o, states = model(x, states)
outputs.append(o)

outputs = np.concatenate(outputs, axis=2)
outputs = outputs.squeeze(0).transpose(1, 0, 2)

enhanced_real = outputs[:, :, 0]
enhanced_imag = outputs[:, :, 1]
enhanced_stft_result = knf.StftResult(
real=enhanced_real.reshape(-1).tolist(),
imag=enhanced_imag.reshape(-1).tolist(),
num_frames=enhanced_real.shape[0],
)

istft = knf.IStft(stft_config)
enhanced = istft(enhanced_stft_result)

sf.write("./enhanced_16k.wav", enhanced, model.sample_rate)


if __name__ == "__main__":
main()

0 comments on commit 6e261ed

Please sign in to comment.