Skip to content

Commit 026f053

Browse files
Alex-Weng and claude committed
Fix Kokoro ZH model shape mismatch and add mixed language support
- Fix check_array_shape to properly detect MLX vs PyTorch conv weight formats - Update weight sanitization in kokoro.py and istftnet.py to use format detection - Add Chinese-to-Bopomofo conversion using pypinyin for ZH model compatibility - Add number-to-Chinese conversion for proper TTS of numeric content - Add mixed Chinese/English text processing in pipeline - Update tests for check_array_shape function Fixes #226 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 1153d67 commit 026f053

File tree

6 files changed

+319
-52
lines changed

6 files changed

+319
-52
lines changed

mlx_audio/base.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,68 @@ def from_dict(cls, params):
1616

1717

1818
def check_array_shape(arr):
    """
    Decide whether a conv weight tensor is already laid out for MLX.

    1D conv weights:
        MLX:     (out_channels, kernel_size, in_channels)
        PyTorch: (out_channels, in_channels, kernel_size)

    2D conv weights:
        MLX:     (out_channels, kH, kW, in_channels)
        PyTorch: (out_channels, in_channels, kH, kW)

    Returns True when the layout looks like MLX (no transpose required),
    False when it looks like PyTorch (transpose required).

    Heuristic: kernel extents are small (<= 15, e.g. 1, 3, 5, 7, 9, 11)
    while channel counts are usually much larger (64, 128, 256, 512, ...).
    """
    kernel_max = 15  # dims above this are assumed to be channel dims
    shape = arr.shape
    ndim = len(shape)

    if ndim == 4:
        out_ch, a, b, c = shape
        a_small, b_small, c_small = (d <= kernel_max for d in (a, b, c))
        if a_small and b_small and not c_small:
            return True   # (out, kH, kW, in) -> MLX
        if b_small and c_small and not a_small:
            return False  # (out, in, kH, kW) -> PyTorch
        # Ambiguous: fall back to the legacy square-kernel check.
        return out_ch >= a and out_ch >= b and a == b

    if ndim == 3:
        _, mid, last = shape
        if mid <= kernel_max < last:
            return True   # small middle dim = kernel -> MLX
        if last <= kernel_max < mid:
            return False  # small last dim = kernel -> PyTorch
        # Both remaining dims are kernel-sized. A singleton usually means
        # in_channels=1; pointwise (kernel=1) convs are rarer than 3/5/7.
        if mid == 1 and last > 1:
            return False  # (out, in=1, kernel) -> PyTorch
        if last == 1 and mid > 1:
            return True   # (out, kernel, in=1) -> MLX
        # Kernel tends to be <= in_channels when both are small.
        return mid <= last

    return False

mlx_audio/tts/models/base.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,51 @@ def from_dict(cls, params):
1919

2020

2121
def check_array_shape(arr):
    """
    Decide whether a 1D conv weight tensor is already laid out for MLX.

    MLX format:     (out_channels, kernel_size, in_channels)
    PyTorch format: (out_channels, in_channels, kernel_size)

    Returns True when the layout looks like MLX (no transpose required),
    False when it looks like PyTorch (transpose required) or the array
    is not 3-dimensional.

    Heuristic: kernel extents are small (<= 15, e.g. 1, 3, 5, 7, 9, 11)
    while channel counts are usually much larger (64, 128, 256, 512, ...).
    """
    kernel_max = 15  # dims above this are assumed to be channel dims
    shape = arr.shape

    # Only 1D conv weights (rank 3) are handled here.
    if len(shape) != 3:
        return False

    _, mid, last = shape

    if mid <= kernel_max < last:
        return True   # small middle dim = kernel -> MLX
    if last <= kernel_max < mid:
        return False  # small last dim = kernel -> PyTorch

    # Both remaining dims are kernel-sized. A singleton usually means
    # in_channels=1; pointwise (kernel=1) convs are rarer than 3/5/7.
    if mid == 1 and last > 1:
        return False  # (out, in=1, kernel) -> PyTorch
    if last == 1 and mid > 1:
        return True   # (out, kernel, in=1) -> MLX

    # Kernel tends to be <= in_channels when both are small.
    return mid <= last
3567

3668

3769
def adjust_speed(audio_array, speed_factor):

mlx_audio/tts/models/kokoro/istftnet.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,11 @@ def __call__(self, asr, F0_curve, N, s):
965965
def sanitize(self, key, weights):
966966
sanitized_weights = None
967967
if "noise_convs" in key and key.endswith(".weight"):
968-
sanitized_weights = weights.transpose(0, 2, 1)
968+
# Only transpose if in PyTorch format
969+
if check_array_shape(weights):
970+
sanitized_weights = weights
971+
else:
972+
sanitized_weights = weights.transpose(0, 2, 1)
969973

970974
elif "weight_v" in key:
971975
if check_array_shape(weights):

mlx_audio/tts/models/kokoro/kokoro.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,12 @@ def sanitize(self, weights):
209209
sanitized_weights[key] = state_dict
210210

211211
if key.startswith("predictor"):
212-
if "F0_proj.weight" in key:
213-
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
214-
215-
elif "N_proj.weight" in key:
216-
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
212+
if "F0_proj.weight" in key or "N_proj.weight" in key:
213+
# Only transpose if in PyTorch format
214+
if check_array_shape(state_dict):
215+
sanitized_weights[key] = state_dict
216+
else:
217+
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
217218

218219
elif "weight_v" in key:
219220
if check_array_shape(state_dict):

mlx_audio/tts/models/kokoro/pipeline.py

Lines changed: 167 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,21 @@ def __init__(
115115
raise
116116
elif lang_code == "z":
117117
try:
118-
from misaki import zh
119-
120-
self.g2p = zh.ZHG2P()
118+
from pypinyin import pinyin, Style
119+
120+
self.pinyin = pinyin
121+
self.pinyin_style = Style
122+
# Also initialize English G2P for mixed Chinese/English text
123+
try:
124+
self.en_g2p = en.G2P(trf=False, fallback=None, unk="")
125+
except Exception as e:
126+
logging.warning(f"English G2P not available for mixed text: {e}")
127+
self.en_g2p = None
128+
# Use a simple wrapper as g2p for compatibility
129+
self.g2p = lambda text: (self._chinese_to_bopomofo(text), None)
121130
except ImportError:
122131
logging.error(
123-
"You need to `pip install misaki[zh]` to use lang_code='z'"
132+
"You need to `pip install pypinyin` to use lang_code='z'"
124133
)
125134
raise
126135
else:
@@ -190,6 +199,154 @@ def load_voice(self, voice: str, delimiter: str = ",") -> mx.array:
190199
self.voices[voice] = mx.mean(mx.stack(packs), axis=0)
191200
return self.voices[voice]
192201

202+
def _number_to_chinese(self, num_str: str) -> str:
203+
"""Convert Arabic numerals to Chinese characters.
204+
205+
Examples:
206+
"23" -> "二十三"
207+
"100" -> "一百"
208+
"1000" -> "一千"
209+
"""
210+
digits = "零一二三四五六七八九"
211+
units = ["", "十", "百", "千"]
212+
big_units = ["", "万", "亿"]
213+
214+
if not num_str:
215+
return ""
216+
217+
# Handle decimal numbers
218+
if "." in num_str:
219+
integer_part, decimal_part = num_str.split(".", 1)
220+
integer_chinese = self._number_to_chinese(integer_part) if integer_part else ""
221+
decimal_chinese = "".join(digits[int(d)] for d in decimal_part)
222+
return f"{integer_chinese}{decimal_chinese}"
223+
224+
num = int(num_str)
225+
if num == 0:
226+
return "零"
227+
228+
if num < 0:
229+
return "负" + self._number_to_chinese(str(-num))
230+
231+
result = ""
232+
unit_index = 0
233+
234+
while num > 0:
235+
section = num % 10000
236+
if section > 0:
237+
section_str = ""
238+
for i, unit in enumerate(units):
239+
digit = section % 10
240+
section = section // 10
241+
if digit > 0:
242+
section_str = digits[digit] + unit + section_str
243+
elif section_str and not section_str.startswith("零"):
244+
section_str = "零" + section_str
245+
if section == 0:
246+
break
247+
result = section_str + big_units[unit_index] + result
248+
num = num // 10000
249+
unit_index += 1
250+
251+
# Special case: 10-19 don't need leading "一"
252+
if result.startswith("一十"):
253+
result = result[1:]
254+
255+
return result
256+
257+
def _chinese_to_bopomofo(self, text: str) -> str:
258+
"""Convert Chinese text to Bopomofo with numeric tones.
259+
260+
The Kokoro ZH model expects Bopomofo symbols with numeric tones (1-5).
261+
"""
262+
# Tone mark to number mapping
263+
tone_map = {
264+
"\u02ca": "2", # ˊ tone 2
265+
"\u02c7": "3", # ˇ tone 3
266+
"\u02cb": "4", # ˋ tone 4
267+
"\u02d9": "5", # ˙ neutral tone
268+
}
269+
270+
# First, convert numbers to Chinese characters
271+
# Match sequences of digits (including decimals)
272+
text = re.sub(
273+
r"(\d+\.?\d*)",
274+
lambda m: self._number_to_chinese(m.group(1)),
275+
text,
276+
)
277+
278+
result = []
279+
for char in text:
280+
# Chinese character range
281+
if "\u4e00" <= char <= "\u9fff":
282+
bpmf = self.pinyin(char, style=self.pinyin_style.BOPOMOFO)[0][0]
283+
284+
# Extract tone mark and convert to number
285+
tone = "1" # default tone 1
286+
clean_bpmf = ""
287+
for c in bpmf:
288+
if c in tone_map:
289+
tone = tone_map[c]
290+
else:
291+
clean_bpmf += c
292+
293+
result.append(clean_bpmf + tone)
294+
elif char.isascii() and char.isalpha():
295+
# English letters - will be processed separately
296+
result.append(char)
297+
else:
298+
# Punctuation and other characters
299+
result.append(char)
300+
301+
return " ".join(result)
302+
303+
def _process_mixed_zh_en(self, text: str) -> str:
304+
"""Process mixed Chinese/English text by using appropriate G2P for each part.
305+
306+
Args:
307+
text: Input text containing Chinese and/or English
308+
309+
Returns:
310+
Combined phoneme string with proper phonemes for both languages
311+
"""
312+
# Pattern to match English sequences (letters, spaces, and common punctuation)
313+
pattern = r"([a-zA-Z][a-zA-Z\s,.'\"!\?\-]*)"
314+
315+
parts = re.split(pattern, text)
316+
phonemes = []
317+
318+
for part in parts:
319+
if not part.strip():
320+
continue
321+
322+
# Check if this part starts with English letter
323+
if re.match(r"^[a-zA-Z]", part):
324+
# Process as English
325+
if self.en_g2p:
326+
try:
327+
_, tokens = self.en_g2p(part)
328+
ps = "".join(
329+
t.phonemes + (" " if t.whitespace else "")
330+
for t in tokens
331+
if t.phonemes
332+
)
333+
if ps.strip():
334+
phonemes.append(ps.strip())
335+
except Exception as e:
336+
logging.warning(f"English G2P failed for '{part}': {e}")
337+
# Keep English as-is if G2P fails
338+
phonemes.append(part.strip())
339+
else:
340+
# No English G2P available, keep as-is
341+
phonemes.append(part.strip())
342+
else:
343+
# Process as Chinese using Bopomofo
344+
ps = self._chinese_to_bopomofo(part)
345+
if ps.strip():
346+
phonemes.append(ps.strip())
347+
348+
return " ".join(phonemes)
349+
193350
@classmethod
194351
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
195352
return "".join(
@@ -470,7 +627,12 @@ def __call__(
470627
if not chunk.strip():
471628
continue
472629

473-
ps, _ = self.g2p(chunk)
630+
# For Chinese, use mixed language processing if English G2P is available
631+
if self.lang_code == "z" and hasattr(self, "en_g2p") and self.en_g2p:
632+
ps = self._process_mixed_zh_en(chunk)
633+
else:
634+
ps, _ = self.g2p(chunk)
635+
474636
if not ps:
475637
continue
476638
elif len(ps) > 510:

0 commit comments

Comments
 (0)