Skip to content

Commit 720a327

Browse files
committed
Parakeet android
1 parent ff76165 commit 720a327

3 files changed

Lines changed: 1085 additions & 1 deletion

File tree

extension/android/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ endif()
172172
if(EXECUTORCH_BUILD_LLAMA_JNI)
173173
target_sources(
174174
executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/jni_layer_asr.cpp
175-
jni/log.cpp
175+
jni/jni_layer_parakeet.cpp jni/log.cpp
176176
)
177177
list(APPEND link_libraries extension_llm_runner extension_asr_runner)
178178
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch.extension.parakeet
10+
11+
import java.io.Closeable
12+
import java.io.File
13+
import java.util.concurrent.atomic.AtomicLong
14+
import org.pytorch.executorch.annotations.Experimental
15+
16+
/**
17+
* ParakeetModule is a wrapper around the ExecuTorch Parakeet TDT Runner. It provides a simple
18+
* interface to transcribe audio from WAV files using the NVIDIA Parakeet TDT speech recognition
19+
* model.
20+
*
21+
* The module loads a WAV file, runs preprocessing (mel-spectrogram extraction), encoding, and TDT
22+
* greedy decoding to generate transcriptions with optional timestamps.
23+
*
24+
* Warning: These APIs are experimental and subject to change without notice
25+
*
26+
* @param modelPath Path to the ExecuTorch Parakeet model file (.pte). The model must expose
27+
* callable methods: "preprocessor", "encoder", "decoder_step", "joint", and metadata methods.
28+
* @param tokenizerPath Path to the SentencePiece tokenizer model file.
29+
* @param dataPath Optional path to additional data file (e.g., for delegate data like CUDA).
30+
*/
31+
@Experimental
32+
class ParakeetModule(
33+
modelPath: String,
34+
tokenizerPath: String,
35+
dataPath: String? = null,
36+
) : Closeable {
37+
38+
private val nativeHandle = AtomicLong(0L)
39+
40+
init {
41+
val modelFile = File(modelPath)
42+
require(modelFile.canRead() && modelFile.isFile) { "Cannot load model path $modelPath" }
43+
val tokenizerFile = File(tokenizerPath)
44+
require(tokenizerFile.exists()) { "Cannot load tokenizer path $tokenizerPath" }
45+
46+
val handle = nativeCreate(modelPath, tokenizerPath, dataPath)
47+
if (handle == 0L) {
48+
throw RuntimeException("Failed to create native ParakeetModule")
49+
}
50+
nativeHandle.set(handle)
51+
}
52+
53+
companion object {
54+
init {
55+
System.loadLibrary("executorch")
56+
}
57+
58+
@JvmStatic
59+
private external fun nativeCreate(
60+
modelPath: String,
61+
tokenizerPath: String,
62+
dataPath: String?,
63+
): Long
64+
65+
@JvmStatic private external fun nativeDestroy(nativeHandle: Long)
66+
67+
@JvmStatic
68+
private external fun nativeTranscribe(
69+
nativeHandle: Long,
70+
wavPath: String,
71+
timestamps: String,
72+
): String
73+
}
74+
75+
/** Check if the native handle is valid (not yet closed). */
76+
val isValid: Boolean
77+
get() = nativeHandle.get() != 0L
78+
79+
/** Releases native resources. Call this when done with the module. */
80+
override fun close() {
81+
val handle = nativeHandle.getAndSet(0L)
82+
if (handle != 0L) {
83+
nativeDestroy(handle)
84+
}
85+
}
86+
87+
/**
88+
* Transcribe audio from a WAV file.
89+
*
90+
* This is a blocking call that returns the complete transcription.
91+
*
92+
* @param wavPath Path to the WAV audio file (must be 16kHz mono)
93+
* @param timestamps Timestamp output mode: "none", "token", "word", "segment", or "all". Default
94+
* is "segment" which returns sentence-level timestamps.
95+
* @return The transcribed text, optionally with timestamps depending on the mode
96+
* @throws IllegalStateException if the module has been destroyed
97+
* @throws RuntimeException if transcription fails
98+
*/
99+
@JvmOverloads
100+
fun transcribe(
101+
wavPath: String,
102+
timestamps: String = "segment",
103+
): String {
104+
val handle = nativeHandle.get()
105+
check(handle != 0L) { "ParakeetModule has been destroyed" }
106+
val wavFile = File(wavPath)
107+
require(wavFile.canRead() && wavFile.isFile) { "Cannot read WAV file: $wavPath" }
108+
109+
return nativeTranscribe(handle, wavPath, timestamps)
110+
}
111+
}

0 commit comments

Comments
 (0)