Whisper JAX vs PyTorch: Uncovering the Truth about ASR Performance on GPUs
Introduction
PyTorch vs. JAX
Constructing the ASR System that uses PyTorch or JAX
Whisper JAX vs. PyTorch Performance Comparison
Conclusion

Introduction

In the world of Automatic Speech Recognition (ASR), speed and accuracy are of great importance. The scale of the data and models has been growing substantially in recent years, making efficiency hard to achieve. Nonetheless, the race is only starting, and we see new developments every week. In this article, we focus on Whisper JAX, a recent implementation of Whisper that uses a different backend framework and reportedly runs 70 times faster than OpenAI’s PyTorch implementation. We tested both CPU and GPU implementations and measured accuracy and execution time. We also defined experiments for small and large model sizes while parametrizing batch size and data type to see if we could improve performance further.

As we saw in our previous article, Whisper is a versatile speech recognition model that excels in multiple speech-processing tasks. It can perform multilingual speech recognition, translation, and even voice activity detection. It uses a Transformer sequence-to-sequence architecture to predict words and tasks jointly. Whisper works as a meta-model for speech-processing tasks. One of the downsides of Whisper is its efficiency; it is often found to be fairly slow compared to other state-of-the-art models.

In the following sections, we go through the details of what changed with this new approach. We compare Whisper and Whisper JAX, highlight the main differences between PyTorch and JAX, and develop a pipeline to evaluate the speed and accuracy of both implementations.

Figure 1: Can we make sense of sound efficiently? (source)

This article belongs to “Large Language Models Chronicles: Navigating the NLP Frontier”, a new weekly series of articles that will explore how to leverage the power of large models for various NLP tasks. By diving into these cutting-edge technologies, we aim to empower developers, researchers, and enthusiasts to harness the potential of NLP and unlock new possibilities.

Articles published to date:

  1. Summarizing the latest Spotify releases with ChatGPT
  2. Master Semantic Search at Scale: Index Millions of Documents with Lightning-Fast Inference Times using FAISS and Sentence Transformers
  3. Unlock the Power of Audio Data: Advanced Transcription and Diarization with Whisper, WhisperX, and PyAnnotate

As always, the code is available on my GitHub.

PyTorch vs. JAX

The Machine Learning community extensively uses powerful libraries like PyTorch and JAX. While they share some similarities, their inner workings are quite different. Let’s understand the main differences.

The AI Research Lab at Meta developed PyTorch and actively maintains it today. It is an open-source library based on the Torch library. Researchers widely use PyTorch due to its dynamic computation graph, intuitive interface, and solid debugging capabilities. The fact that it uses dynamic graphs gives it greater flexibility in building new models and simplifies modifying them at runtime. Its API is close to Python and, specifically, to the NumPy API. The main difference is that we are not working with arrays but with tensors, which can run on a GPU and support automatic differentiation.
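To make the tensor-plus-autograd workflow concrete, here is a minimal PyTorch sketch (a generic illustration, not part of the original code):

import torch

# Tensors behave much like NumPy arrays, but they can be moved to a GPU
# and track gradients for automatic differentiation.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()
y.backward()
print(x.grad)  # tensor([2., 4., 6.]), since dy/dx = 2x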

JAX is a high-performance library developed by Google. Conversely to PyTorch, JAX combines the benefits of static and dynamic computation graphs. It does this through its just-in-time compilation feature, which provides both flexibility and performance. We can think of JAX as a stack of interpreters that progressively rewrite your program. It eventually offloads the actual computation to XLA — the Accelerated Linear Algebra compiler, also designed and developed by Google — to accelerate Machine Learning computations.
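As a small illustration of the just-in-time compilation mentioned above (again a generic sketch, not code from the Whisper JAX project):

import jax
import jax.numpy as jnp

# jax.jit traces the Python function once and hands the traced computation
# to XLA, which compiles it for the target accelerator (CPU, GPU, or TPU).
@jax.jit
def scaled_sum_of_squares(x):
    return 0.5 * jnp.sum(x ** 2)

x = jnp.arange(1_000_000, dtype=jnp.float32)
print(scaled_sum_of_squares(x))        # first call: traced and compiled
print(scaled_sum_of_squares(x + 1.0))  # later calls reuse the compiled kernel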

Constructing the ASR System that uses PyTorch or JAX

Let’s start by building a class to handle audio transcriptions using Whisper with PyTorch (OpenAI’s implementation) or Whisper with JAX. Our class is a wrapper for the models and an interface to easily set up experiments. We want to perform several experiments, including specifying the device, model type, and additional hyperparameters for Whisper JAX. Note that we use a singleton pattern to ensure that, as we run several experiments, we don’t end up with several instances of the model consuming our memory.

from typing import Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import torch
import whisper
from whisper_jax import FlaxWhisperPipline


class Transcription:
    """
    A class to handle audio transcriptions using either the Whisper or Whisper JAX model.

    Attributes:
        audio_file_path (str): Path to the audio file to transcribe.
        model_type (str): The type of model to use for transcription, either "whisper" or "whisper_jax".
        device (str): The device to use for inference (e.g., "cpu" or "gpu").
        model_name (str): The specific model to use (e.g., "base", "medium", "large", or "large-v2").
        dtype (Optional[str]): The data type to use for Whisper JAX, either "bfloat16" or "float32".
        batch_size (Optional[int]): The batch size to use for Whisper JAX.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Singleton: reuse the same instance across experiments so we never
        # hold more than one copy of the model in memory.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        audio_file_path: str,
        model_type: str = "whisper",
        device: str = "cpu",
        model_name: str = "base",
        dtype: Optional[str] = None,
        batch_size: Optional[int] = None,
    ):
        self.audio_file_path = audio_file_path
        self.device = device
        self.model_type = model_type
        self.model_name = model_name
        self.dtype = dtype
        self.batch_size = batch_size
        self.pipeline = None
The set_pipeline method sets up the pipeline for the specified model type. Depending on the value of the model_type attribute, the method initializes the pipeline either by instantiating the FlaxWhisperPipline class for Whisper JAX or by calling the whisper.load_model() function for the PyTorch implementation of Whisper.

    def set_pipeline(self) -> None:
        """
        Set up the pipeline for the specified model type.

        Returns:
            None
        """
        if self.model_type == "whisper_jax":
            # Only pass optional hyperparameters that were explicitly set.
            pipeline_kwargs = {}
            if self.dtype:
                pipeline_kwargs["dtype"] = getattr(jnp, self.dtype)
            if self.batch_size:
                pipeline_kwargs["batch_size"] = self.batch_size

            self.pipeline = FlaxWhisperPipline(
                f"openai/whisper-{self.model_name}", **pipeline_kwargs
            )
        elif self.model_type == "whisper":
            self.pipeline = whisper.load_model(
                self.model_name,
                torch.device("cuda:0") if self.device == "gpu" else self.device,
            )
        else:
            raise ValueError(f"Invalid model type: {self.model_type}")

The run_pipeline method transcribes the audio file and returns the results as a list of dictionaries containing the transcribed text and timestamps. In the case of Whisper JAX, it considers optional parameters like data type and batch size, if provided. Notice that you can set return_timestamps to False if you are only interested in getting the transcription. The model output differs when we run the transcription process with the PyTorch implementation, so we must build a new object that aligns both return formats.

    def run_pipeline(self) -> List[Dict[str, Union[Tuple[float, float], str]]]:
        """
        Run the transcription pipeline for the specified model type.

        Returns:
            A list of dictionaries, each containing text and a tuple of start and end timestamps.
        """
        if self.pipeline is None:
            raise ValueError("Pipeline not initialized. Call set_pipeline() first.")

        if self.model_type == "whisper_jax":
            outputs = self.pipeline(
                self.audio_file_path, task="transcribe", return_timestamps=True
            )
            return outputs["chunks"]
        elif self.model_type == "whisper":
            result = self.pipeline.transcribe(self.audio_file_path)
            # Align the PyTorch output with the Whisper JAX "chunks" format.
            formatted_result = [
                {
                    "timestamp": (segment["start"], segment["end"]),
                    "text": segment["text"],
                }
                for segment in result["segments"]
            ]
            return formatted_result
        else:
            raise ValueError(f"Invalid model type: {self.model_type}")

Finally, the transcribe_multiple() method enables the transcription of multiple audio files. It takes a list of audio file paths and returns a list of transcriptions for each audio file, where each transcription is a list of dictionaries containing text and a tuple of start and end timestamps.

    def transcribe_multiple(
        self, audio_file_paths: List[str]
    ) -> List[List[Dict[str, Union[Tuple[float, float], str]]]]:
        """
        Transcribe multiple audio files using the specified model type.

        Args:
            audio_file_paths (List[str]): A list of audio file paths to transcribe.

        Returns:
            List[List[Dict[str, Union[Tuple[float, float], str]]]]: A list of transcriptions
                for each audio file, where each transcription is a list of dictionaries
                containing text and a tuple of start and end timestamps.
        """
        transcriptions = []

        for audio_file_path in audio_file_paths:
            self.audio_file_path = audio_file_path
            self.set_pipeline()
            transcription = self.run_pipeline()

            transcriptions.append(transcription)

        return transcriptions
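To make the interface concrete, here is a minimal usage sketch; the file names and hyperparameter values below are illustrative assumptions, not the exact ones used in our experiments.

# Whisper JAX with half precision and batching, transcribing several files in a row.
transcription = Transcription(
    audio_file_path="audio/interview_part1.mp3",  # hypothetical path
    model_type="whisper_jax",
    device="gpu",
    model_name="large-v2",
    dtype="bfloat16",
    batch_size=16,
)
results = transcription.transcribe_multiple(
    ["audio/interview_part1.mp3", "audio/interview_part2.mp3"]
)

# Each entry is a list of {"timestamp": (start, end), "text": ...} dictionaries.
print(results[0][:2])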

Whisper JAX vs. PyTorch Performance Comparison

Experimental Setup

We used a long audio clip of more than 30 minutes to evaluate the performance of the Whisper variants, with a PyTorch and a JAX implementation. The researchers who developed Whisper JAX claim that the difference is more significant when transcribing long audio files.

Our experimental hardware setup consists of the following key components. For the CPU, we have an x86_64 architecture with a total of 112 cores, powered by an Intel(R) Xeon(R) Gold 6258R CPU running at 2.70GHz. Regarding the GPU, we use an NVIDIA Quadro RTX 8000 with 48 GB of VRAM.
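For reference, the execution times discussed below can be collected with a simple timing loop along these lines; the configurations and file name are illustrative assumptions, and time.perf_counter is just one reasonable choice of timer.

import time

# Illustrative configurations; the real experiments sweep more combinations.
configurations = [
    {"model_type": "whisper", "device": "gpu", "model_name": "large-v2"},
    {"model_type": "whisper_jax", "device": "gpu", "model_name": "large-v2",
     "dtype": "bfloat16", "batch_size": 16},
]

for config in configurations:
    transcription = Transcription(audio_file_path="audio/long_interview.mp3", **config)
    transcription.set_pipeline()

    # Time the first and second transcriptions separately: for Whisper JAX the
    # first call includes JIT compilation, so the second call isolates the
    # cached, compiled path.
    for run in ("first", "second"):
        start = time.perf_counter()
        transcription.run_pipeline()
        elapsed = time.perf_counter() - start
        print(f"{config['model_type']} ({run} run): {elapsed:.1f} s")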

Results and Discussion

In this section, we discuss the results obtained from the experiments comparing the performance of the Whisper JAX and PyTorch implementations. Our results provide insights into the speed and efficiency of the two implementations on both GPU and CPU platforms.

Our first experiment involved transcribing a long audio file (over 30 minutes) using a GPU and the large Whisper model (large-v2), which requires roughly 10GB of VRAM. Contrary to the claim made by the authors of Whisper JAX, our results indicate that the JAX implementation is slower than the PyTorch version. Even with the incorporation of half precision and batching, we couldn’t surpass the performance of the PyTorch implementation using Whisper JAX. Whisper JAX took almost twice as long as the PyTorch implementation to perform the same transcription. We also observed an unusually long transcription time when both half precision and batching were employed.

Figure 2: Transcription execution time using Whisper’s PyTorch implementation against Whisper JAX on GPU for the large model (image by author)

On the other hand, when comparing CPU performance, our results show that Whisper JAX outperforms the PyTorch implementation. Whisper JAX was roughly two times faster than the PyTorch version. We observed this pattern for both the base and large model versions.

Figure 3: Transcription execution time using Whisper’s PyTorch implementation against Whisper JAX on CPU for the base and large models (image by author)

Regarding the claim made by the authors of Whisper JAX that the second transcription should be much faster, our experiments did not provide supporting evidence. The difference in speed between the first and second transcriptions was not significant. Moreover, we found that the pattern was similar between the Whisper and Whisper JAX implementations.

Conclusion

In this article, we presented a comprehensive evaluation of the Whisper JAX implementation, comparing its performance to the original PyTorch implementation of Whisper. Our experiments aimed to evaluate the claimed 70x speed improvement using a variety of setups, including different hardware and hyperparameters for the Whisper JAX model.

The results showed that Whisper JAX outperformed the PyTorch implementation on CPU platforms, with a speedup factor of roughly two. Nonetheless, our experiments did not support the authors’ claims that Whisper JAX is significantly faster on GPU platforms. In fact, the PyTorch implementation performed better when transcribing long audio files using a GPU.

Moreover, we found no significant difference in speed between the first and second transcriptions, contrary to a claim made by the Whisper JAX authors. Both implementations exhibited a similar pattern in this regard.

Keep in touch: LinkedIn
