# Requirements
If you are running this tutorial on Colab, we strongly suggest switching to a GPU runtime, especially for the deep learning stage.

If you are running it locally and have no access to a dedicated and/or compatible GPU, or have run out of resources on Colab, worry not! Simply run the experiments for fewer epochs so that you can follow along.


In [1]:
!pip install datasets PyWavelets speechpy


In [2]:
import os

import librosa
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pywt
import scipy.io as sio
import scipy.signal
import random
import re

from scipy.io import wavfile

from PIL import Image
from tqdm import tqdm


# Part 1. Signal Processing

You have access to processed samples from the [2022 CirCor DigiScope dataset](https://physionet.org/content/circor-heart-sound/1.0.3/) available on the [HuggingFace Dataset Hub](https://huggingface.co/docs/hub/datasets-overview).

Please click [this link](https://huggingface.co/datasets/miguellmartins/circor-digiscope-physionet22-processed) and inspect the attributes of the original (raw) dataset. This HF Dataset has two splits: "original" and "processed".

The "processed" version contains the resulting sounds after filtering and denoising (using the techniques we will discuss in this section).

These data can be accessed through the [HuggingFace API](https://huggingface.co/docs/). The underlying files are stored in [Apache Parquet](https://parquet.apache.org/) format, both remotely on the Hub and locally once downloaded.


In [3]:
from datasets import load_dataset, Audio, DatasetDict

In [None]:
circor = load_dataset('miguellmartins/circor-digiscope-physionet22-processed')

Inspect the attributes of the datasets and store each split in a separate object.

In [5]:
print(circor)
original = circor['original']
processed = circor['processed']

DatasetDict({
    original: Dataset({
        features: ['filename', 'recording', 'recording_label', 'heart_state_labels'],
        num_rows: 3363
    })
    processed: Dataset({
        features: ['filename', 'recording', 'recording_label', 'heart_state_labels'],
        num_rows: 3363
    })
})


In [6]:
# You can check the name of the original wav file, the waveform,
# and the sampling rate of each recording
original[0]['recording']

{'path': '13918_AV.wav',
 'array': array([-0.0100708 , -0.00579834, -0.00692749, ..., -0.00238037,
         0.00396729,  0.00717163]),
 'sampling_rate': 4000}

## 1.1 Pick a random sound and use the function below to visualize the sound and its annotation

Remember that each heart state is encoded with a categorical label:
* 1: S1
* 2: Systole
* 3: S2
* 4: Diastole


In [None]:
def plot_sound_and_label(x, y, sampling_rate=4000):
    assert len(x) == len(y)
    number_of_samples = len(x)
    # Time (duration) = T_i / sample_rate
    time = np.arange(number_of_samples) / sampling_rate

    # Plotting x and y together
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Plot x on the primary y-axis
    ax1.plot(time, x, label='x (PCG)', color='b')
    ax1.set_xlabel('Time (s)')
    ax1.set_ylabel('Amplitude', color='b')
    ax1.tick_params(axis='y', labelcolor='b')

    # Create a secondary y-axis for y
    ax2 = ax1.twinx()
    ax2.step(time, y, label='y (Heart States)', color='r', where='post', linewidth=2)
    ax2.set_ylabel('y (Labels)', color='r')
    ax2.set_yticks([1, 2, 3, 4])
    ax2.tick_params(axis='y', labelcolor='r')

    fig.suptitle('PCG Amplitude and Heart Sound Labels')
    ax1.legend(loc='upper left')
    ax2.legend(loc='upper right')
    plt.show()

In [None]:
sample_idx = 0  # choose any index i and visualize it


plot_sound_and_label(x=original[sample_idx]['recording']['array'],
                     y=original[sample_idx]['heart_state_labels'],
                     sampling_rate=original[sample_idx]['recording']['sampling_rate'])

## 2.1.1 Band-pass filtering
Heart-sound information is typically assumed to lie in [the 25-400 Hz band](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10242001).

We are specifically using [Butterworth filters](https://en.wikipedia.org/wiki/Butterworth_filter) for this purpose.

Run the cells below to visualize these filters.

Apply the band-pass filters to a sound of your choice from the dataset.
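As a quick sanity check on the filter design, the sketch below evaluates each magnitude response exactly at its cutoff frequency. It assumes the same second-order designs as the filtering cell further down (45 Hz high-pass, 400 Hz low-pass at 4 kHz; note that this high-pass cutoff sits above the 25 Hz band edge quoted above). For a digital Butterworth filter, `scipy.signal.butter` pre-warps the cutoff so that the -3 dB point falls at the requested frequency, so both gains should be about $1/\sqrt{2}$.

```python
import numpy as np
from scipy.signal import butter, sosfreqz

fs = 4000                            # sampling rate of the CirCor recordings (Hz)
hp_cutoff, lp_cutoff = 45.0, 400.0   # cutoffs used in the filtering cell (Hz)

sos_hp = butter(N=2, Wn=hp_cutoff, btype='highpass', fs=fs, output='sos')
sos_lp = butter(N=2, Wn=lp_cutoff, btype='lowpass', fs=fs, output='sos')

# Evaluate each response only at its own cutoff (worN is in Hz because fs is given)
_, h_hp = sosfreqz(sos_hp, worN=[hp_cutoff], fs=fs)
_, h_lp = sosfreqz(sos_lp, worN=[lp_cutoff], fs=fs)

gain_hp = np.abs(h_hp[0])
gain_lp = np.abs(h_lp[0])
print(f"HP gain at {hp_cutoff} Hz: {gain_hp:.4f} ({20 * np.log10(gain_hp):.2f} dB)")
print(f"LP gain at {lp_cutoff} Hz: {gain_lp:.4f} ({20 * np.log10(gain_lp):.2f} dB)")
```

The filtering cell applies both filters causally with `sosfilt`, which delays the waveform slightly relative to the labels. If that alignment matters for your use case, `scipy.signal.sosfiltfilt` runs the same filter forward and backward for zero phase (at the cost of squaring the magnitude response).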

In [None]:
def plot_filter_responses(low_pass_fs, high_pass_fs, sampling_rate=4000, filter_order=2):
    """
    Plots the frequency responses of a highpass and a lowpass Butterworth filter.

    Parameters:
    - low_pass_fs: float, lower edge of the pass band, i.e. the cutoff of the highpass filter (Hz)
    - high_pass_fs: float, upper edge of the pass band, i.e. the cutoff of the lowpass filter (Hz)
    - sampling_rate: float, the sampling rate of the signals (Hz)
    - filter_order: int, the order of the Butterworth filter (default is 2)
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.signal import butter, sosfreqz

    # Design the highpass filter
    sos_hp = butter(N=filter_order, Wn=low_pass_fs, btype='highpass', analog=False, fs=sampling_rate, output='sos')
    # Design the lowpass filter
    sos_lp = butter(N=filter_order, Wn=high_pass_fs, btype='lowpass', analog=False, fs=sampling_rate, output='sos')

    # Frequency response for the highpass filter
    w_hp, h_hp = sosfreqz(sos_hp, fs=sampling_rate)
    # Frequency response for the lowpass filter
    w_lp, h_lp = sosfreqz(sos_lp, fs=sampling_rate)

    # Plot the frequency response of both filters
    plt.figure(figsize=(12, 6))

    # Plot for the highpass filter
    plt.subplot(2, 1, 1)
    plt.plot(w_hp, 20 * np.log10(np.abs(h_hp)), label='Highpass Filter')
    plt.title('Highpass Filter Frequency Response')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Amplitude (dB)')
    plt.grid()
    plt.axvline(low_pass_fs, color='red', linestyle='--', label=f'Cutoff: {low_pass_fs} Hz')
    plt.legend()

    # Plot for the lowpass filter
    plt.subplot(2, 1, 2)
    plt.plot(w_lp, 20 * np.log10(np.abs(h_lp)), label='Lowpass Filter')
    plt.title('Lowpass Filter Frequency Response')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Amplitude (dB)')
    plt.grid()
    plt.axvline(high_pass_fs, color='red', linestyle='--', label=f'Cutoff: {high_pass_fs} Hz')
    plt.legend()

    plt.tight_layout()
    plt.show()

plot_filter_responses(low_pass_fs=..., high_pass_fs=...)

In [None]:
sample_idx = 0  # choose any sample idx
sound = original[sample_idx]['recording']['array']
sampling_rate = original[sample_idx]['recording']['sampling_rate']
# 2nd-order Butterworth high-pass (45 Hz) and low-pass (400 Hz) filters as second-order sections
sos_hp = scipy.signal.butter(N=2, Wn=45, btype='highpass', analog=False, fs=sampling_rate,
                             output='sos')
sos_lp = scipy.signal.butter(N=2, Wn=400, btype='lowpass', analog=False, fs=sampling_rate,
                             output='sos')
# Apply the two filters in cascade (causal, single pass)
filtered = scipy.signal.sosfilt(sos_hp, sound)
filtered = scipy.signal.sosfilt(sos_lp, filtered)

## 2.1.2 Denoising - Averaging Theory
The denoising section of this tutorial follows [Messer et al.](https://www.sciencedirect.com/science/article/pii/S0026269201000957).


Suppose a source $\mathbf{S}$ is corrupted additively by i.i.d. Gaussian noise $\epsilon_i$ with variance $\sigma^2$. In our measurements we observe $\mathbf{X}_i$, not $\mathbf{S}$:

$$\mathbf{X}_i = \mathbf{S} + \epsilon_i, \quad i=1,\ldots, N$$

If one averages the $N$ observations, $\bar{\mathbf{X}} = \frac{1}{N}\sum_{j=1}^{N}\mathbf{X}_j$, and computes the variance of this average (assuming the noise is independent of $\mathbf{S}$):

$$\text{Var}(\bar{\mathbf{X}}) = \text{Var}\left(\mathbf{S} + \frac{1}{N} \sum_{j=1}^{N}  \epsilon_j\right) = \text{Var}(\mathbf{S}) + \text{Var}\left(\frac{1}{N} \sum_{j=1}^{N}  \epsilon_j \right) = \text{Var}(\mathbf{S}) + \frac{N\sigma^2}{N^2}= \text{Var}(\mathbf{S}) + \frac{\sigma^2}{N}$$

one observes that the **standard deviation** of the noise term shrinks as $\frac{\sigma}{\sqrt{N}}$.

Due to the periodic nature of heart sounds and their stationarity (at least over a "short" period of time), we can think of $\mathbf{S}$ as the expected waveform of the heart cycle in a recording, i.e. the *characteristic heart cycle* of a patient.
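The $\sigma/\sqrt{N}$ rule is easy to check numerically. The sketch below uses a synthetic periodic template (a hypothetical stand-in for a characteristic heart cycle, not CirCor data), adds i.i.d. Gaussian noise to $N$ copies of it, and compares the residual noise of the average with $\sigma/\sqrt{N}$.

```python
import numpy as np

rng = np.random.default_rng(0)
fs = 4000                                               # samples per second
t = np.arange(fs) / fs                                  # one second of "heart cycle"
template = np.sin(2 * np.pi * 5 * t) * np.exp(-3 * t)   # hypothetical clean cycle S
sigma = 0.5                                             # noise standard deviation

results = {}
for n_cycles in (1, 4, 16, 64):
    # X_i = S + eps_i for i = 1..N, one row per observed cycle
    observations = template + rng.normal(0.0, sigma, size=(n_cycles, t.size))
    average = observations.mean(axis=0)
    results[n_cycles] = np.std(average - template)     # noise left after averaging
    print(f"N={n_cycles:3d}  residual std={results[n_cycles]:.4f}  "
          f"sigma/sqrt(N)={sigma / np.sqrt(n_cycles):.4f}")
```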

## 2.1.3 Wavelet Denoising

In practice, the characteristic heartbeat alone is of limited use for most downstream applications. Even under (quasi-)stationary assumptions, the signal may contain other phenomena, such as murmurs, that can occur in any state of the heart cycle and across several frequency bands. These phenomena may also be **transient**, which immediately defeats the purpose of a characteristic heartbeat.



The wavelet decomposition allows us to filter the original signal adaptively. It trades off frequency resolution against time resolution as a function of scale.

The Discrete Wavelet Transform is given by:
$$ \text{DWT}_x^{\psi}(m, n) = \sum_{t} x(t)\, \psi^*_{m, n}(t) = \sum_{t} x(t)\, 2^{-m/2}\, \psi^*\left(\frac{t - n 2^m}{2^m}\right)$$
where $m$ indexes the dyadic scale $2^m$ (coarser for larger $m$) and $n$ the translation. $\psi$ is the so-called *mother wavelet*, and $^*$ denotes complex conjugation.
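To connect the formula with the fast algorithm used in practice, the NumPy-only sketch below evaluates the detail coefficients directly from the DWT definition (with the orthonormal factor $2^{-m/2}$ included) using the Haar mother wavelet, $\psi(u) = 1$ on $[0, \tfrac12)$ and $-1$ on $[\tfrac12, 1)$; it is real, so $\psi^* = \psi$. It then compares them with a filter-bank cascade of pairwise sums and differences, the kind of scheme `pywt.wavedec` implements. The helper names are hypothetical.

```python
import numpy as np

def haar_mother(u):
    """Haar mother wavelet: +1 on [0, 1/2), -1 on [1/2, 1), 0 elsewhere."""
    u = np.asarray(u, dtype=float)
    return np.where((u >= 0) & (u < 0.5), 1.0, np.where((u >= 0.5) & (u < 1), -1.0, 0.0))

def dwt_from_formula(x, m):
    """Level-m details via sum_t x(t) 2^{-m/2} psi((t - n 2^m) / 2^m)."""
    x = np.asarray(x, dtype=float)
    t = np.arange(len(x))
    n_coeffs = len(x) // 2 ** m          # one coefficient per dyadic block of length 2^m
    return np.array([
        np.sum(x * 2 ** (-m / 2) * haar_mother((t - n * 2 ** m) / 2 ** m))
        for n in range(n_coeffs)
    ])

def dwt_cascade(x, m):
    """Same level-m details (m >= 1) via repeated pairwise sums/differences."""
    approx = np.asarray(x, dtype=float)
    for _ in range(m):
        even, odd = approx[0::2], approx[1::2]
        detail = (even - odd) / np.sqrt(2)   # high-pass branch
        approx = (even + odd) / np.sqrt(2)   # low-pass branch feeds the next level
    return detail

x = np.random.default_rng(1).normal(size=64)
for m in (1, 2, 3):
    print(m, len(dwt_cascade(x, m)), np.allclose(dwt_from_formula(x, m), dwt_cascade(x, m)))
```

Each level halves the number of coefficients (32, 16, 8 for a 64-sample input): coarser scales see longer stretches of the signal, which is exactly the time/frequency trade-off described above.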

The following code visualizes a Haar wavelet decomposition with 5 levels, i.e. detail coefficients at the dyadic scales $2^1, \ldots, 2^5$ plus the level-5 approximation coefficients.

We will use the sound you previously filtered with the Butterworth filters for illustrative purposes.





In [None]:

n_levels = 5
coeffs = pywt.wavedec(filtered,
                      wavelet='haar',
                      level=n_levels)
# Plot the original signal alongside the wavelet coefficients
plt.figure(figsize=(12, 12))

# Plot the original noisy signal
plt.subplot(n_levels + 2, 1, 1)
plt.plot(filtered,
         color='green')
plt.title('Original Noisy Signal')
plt.grid(True)

# Plot the approximation coefficients at the highest level
plt.subplot(n_levels + 2, 1, 2)
plt.plot(coeffs[0], color='blue')
plt.title(f'Haar Approximation Coefficients (Level {n_levels})')
plt.grid(True)

# Plot the detail coefficients for each level
for i in range(1, n_levels + 1):
    plt.subplot(n_levels + 2, 1, i + 2)
    plt.plot(coeffs[i], color='red')
    plt.title(f'Haar Detail Coefficients (Level {n_levels - i + 1})')
    plt.grid(True)

plt.tight_layout()
plt.show()


## 2.1.3.a - Universal Thresholding
Since the DWT provides a multiresolution decomposition of the signal, we can use our **prior knowledge** that *most of the information is concentrated in the low frequencies* to mitigate noise adaptively. Remember, we may not necessarily want to get rid of all high-frequency content!



Suppose that the detail coefficients at the **finest scale** are pure noise, distributed according to a standard Gaussian scaled by $\sigma$. Then, using [extreme value theory](https://nobel.web.unc.edu/wp-content/uploads/sites/13591/2019/11/Gaussian_Extremes-1.pdf), the largest detail coefficient (in magnitude) is:

$$\max_{i=1, \ldots, n} |D_i| \approx O(\sigma\sqrt{2\log n})$$

The finer the scale (the higher the frequency), the more sensitive our estimates are to small perturbations, so we look for a robust estimator of $\sigma$. Typically one uses the **median absolute deviation** (MAD) estimate, based on the fact that $\text{median}(|Z|) \approx 0.6745$ for $Z \sim \mathcal{N}(0, 1)$:

$$\sigma \approx \frac{\text{median}(|D|)}{0.6745}$$
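Both facts can be checked on simulated pure-noise coefficients. In the sketch below (synthetic data only), the MAD estimate should recover $\sigma$ closely, and the largest magnitude should be of the order of $\sigma\sqrt{2\log n}$, typically a little below it for finite $n$.

```python
import numpy as np

rng = np.random.default_rng(42)
sigma_true = 2.0
n = 100_000
detail = rng.normal(0.0, sigma_true, size=n)   # finest-scale details that are pure noise

# Robust scale estimate: median(|D|) / 0.6745
sigma_mad = np.median(np.abs(detail)) / 0.6745

# Extreme-value scale for the largest |D_i|
universal = sigma_true * np.sqrt(2 * np.log(n))
largest = np.max(np.abs(detail))

print(f"sigma (MAD estimate): {sigma_mad:.3f}  (true {sigma_true})")
print(f"max |D|: {largest:.2f}  vs  sigma*sqrt(2 log n): {universal:.2f}")
```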



Now, we only need to define the thresholding function. We will implement the following soft-thresholding:

$$\hat{D}_j = \text{sign}(D_j) \cdot \max(|D_j| - \lambda, 0)$$

where $\lambda = \sigma\sqrt{2\log n}$. This zeroes out coefficients whose magnitude is at most $\lambda$ and shrinks the remaining $D_j$ towards 0 by $\lambda$.


Implement a function that receives a recording and outputs a denoised version of the signal. Use the `PyWavelets` package: `pywt.wavedec` to decompose the signal and `pywt.waverec` to reconstruct it after applying universal thresholding to the detail coefficients.

Run the cell below to see an example.
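If you want to check your implementation against something independent of PyWavelets, the NumPy-only sketch below implements the same pipeline for the Haar wavelet: a multi-level decomposition ordered like the output of `pywt.wavedec` (`[cA_n, cD_n, ..., cD_1]`), universal soft thresholding of the detail coefficients, and reconstruction. The function names are hypothetical, and for simplicity it assumes the signal length is divisible by $2^{\text{level}}$ (PyWavelets handles other lengths by extending the signal). It is exercised on a synthetic piecewise-constant signal, not on CirCor recordings.

```python
import numpy as np

def haar_wavedec(x, level):
    """Multi-level orthonormal Haar DWT returning [cA_level, cD_level, ..., cD_1]."""
    approx = np.asarray(x, dtype=float)
    if approx.size % 2 ** level != 0:
        raise ValueError("this sketch needs len(x) divisible by 2**level")
    details = []
    for _ in range(level):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2))   # detail (high-pass) coefficients
        approx = (even + odd) / np.sqrt(2)          # approximation (low-pass) coefficients
    return [approx] + details[::-1]                 # coarsest first, finest last

def haar_waverec(coeffs):
    """Inverse of haar_wavedec."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        out = np.empty(2 * approx.size)
        out[0::2] = (approx + detail) / np.sqrt(2)
        out[1::2] = (approx - detail) / np.sqrt(2)
        approx = out
    return approx

def soft_threshold(coeff, threshold):
    return np.sign(coeff) * np.maximum(np.abs(coeff) - threshold, 0)

def universal_threshold_denoise(x, level=5):
    coeffs = haar_wavedec(x, level)
    sigma = np.median(np.abs(coeffs[-1])) / 0.6745      # noise level from the finest details
    lam = sigma * np.sqrt(2 * np.log(len(x)))           # universal threshold
    shrunk = [coeffs[0]] + [soft_threshold(d, lam) for d in coeffs[1:]]  # keep cA unchanged
    return haar_waverec(shrunk)

# Synthetic example: piecewise-constant signal plus white Gaussian noise
rng = np.random.default_rng(0)
clean = np.repeat([0.0, 4.0, -2.0, 3.0], 256)           # length 1024 = 32 * 2**5
noisy = clean + rng.normal(0.0, 1.0, clean.size)
denoised = universal_threshold_denoise(noisy, level=5)

print("MSE noisy   :", np.mean((noisy - clean) ** 2))
print("MSE denoised:", np.mean((denoised - clean) ** 2))
```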

In [None]:
def soft_threshold(coeff, threshold):
    return np.sign(coeff) * np.maximum(np.abs(coeff) - threshold, 0)

n_levels = 5
coeffs = pywt.wavedec(filtered, wavelet='haar', level=n_levels)
# Plot the original signal alongside the wavelet coefficients
plt.figure(figsize=(12, 12))

# Define a threshold value (e.g., universal threshold)
sigma = np.median(np.abs(coeffs[-1])) / 0.6745  # Estimating noise level
threshold = sigma * np.sqrt(2 * np.log(len(filtered)))

# Apply soft thresholding to the detail coefficients
coeffs_thresholded = [coeffs[0]]  # Keep approximation coefficients unchanged
for coeff in coeffs[1:]:
    coeffs_thresholded.append(soft_threshold(coeff, threshold))


# Plot the original noisy signal
plt.subplot(n_levels + 2, 1, 1)
plt.plot(filtered, color='green')
plt.title('Original Noisy Signal')
plt.grid(True)

# Plot the approximation coefficients at the highest level
plt.subplot(n_levels + 2, 1, 2)
plt.plot(coeffs[0], color='blue')
plt.title(f'Haar Approximation Coefficients (Level {n_levels})')
plt.grid(True)

# Plot the detail coefficients for each level
for i in range(1, n_levels + 1):
    plt.subplot(n_levels + 2, 1, i + 2)
    plt.plot(coeffs[i], color='red', label='original')
    plt.plot(coeffs_thresholded[i], color='purple', label='after universal threshold')
    plt.title(f'Haar Detail Coefficients (Level {n_levels - i + 1})')
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()


## 2.2 - Visualize the original vs processed waveforms
Run the cell below to see how our signal-processing strategy affects the waveform.

In [None]:
def plot_denoised_signal(dataset, denoised_sounds, sample_idx):
    noisy_signal = dataset[sample_idx]['recording']['array']
    denoised_signal = denoised_sounds[sample_idx]['recording']['array']
    plt.figure(figsize=(10, 6))
    plt.plot(noisy_signal, label='Noisy Signal')
    plt.plot(denoised_signal, label='Denoised Signal', linewidth=2)
    plt.legend()
    plt.title(f'Recording {sample_idx}: original vs. denoised')
    plt.show()

# Inspect the results on a sample
sample_idx = 0  # choose any sample idx
plot_denoised_signal(original, processed, sample_idx)

# Part 2 - Deep Learning and Model serving

We have prepared a dataset with features pre-extracted from the processed recordings.

Specifically, we extracted amplitude and homomorphic envelograms and downsampled the signals from 4000 Hz to 50 Hz.

Details can be found in most papers in the literature, such as [this one](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10242001).
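The features are provided as-is, and the exact extraction parameters are not given here. As a rough sketch of what such features typically look like, the code below computes a Hilbert amplitude envelope and a homomorphic envelope (low-pass filtering in the log domain), then downsamples both from 4000 Hz to 50 Hz (a factor of 80). The function name, the 8 Hz cutoff, the filter order and the use of `resample_poly` are assumptions for illustration, not necessarily what was used to build the dataset.

```python
import numpy as np
from scipy.signal import butter, filtfilt, hilbert, resample_poly

def envelograms(x, fs=4000, target_fs=50, lp_cutoff=8.0):
    """Hypothetical feature extractor: amplitude and homomorphic envelopes at target_fs."""
    # Amplitude envelope: magnitude of the analytic signal
    amplitude = np.abs(hilbert(x))
    # Homomorphic envelope: low-pass filter the log-envelope, then map back with exp
    b, a = butter(1, lp_cutoff, btype='lowpass', fs=fs)
    homomorphic = np.exp(filtfilt(b, a, np.log(amplitude + 1e-12)))
    # Downsample fs -> target_fs (resample_poly applies an anti-aliasing FIR filter)
    factor = fs // target_fs
    return (resample_poly(amplitude, up=1, down=factor),
            resample_poly(homomorphic, up=1, down=factor))

# Synthetic check: a 100 Hz tone with constant amplitude 2, lasting 2 s
fs = 4000
t = np.arange(2 * fs) / fs
tone = 2.0 * np.sin(2 * np.pi * 100 * t)
amp_env, homo_env = envelograms(tone, fs=fs)
print(amp_env.shape, homo_env.shape)   # 8000 samples / 80 -> 100 samples each
print(amp_env[50], homo_env[50])       # both close to the tone amplitude
```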

In [None]:
dataset_dict = load_dataset("miguellmartins/circor-digiscope-physionet22-tutorial")

In [None]:
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['filename', 'recording', 'recording_label', 'heart_state_labels', 'amplitude_env', 'homomorphic_env', 'identifier'],
        num_rows: 2683
    })
    dev: Dataset({
        features: ['filename', 'recording', 'recording_label', 'heart_state_labels', 'amplitude_env', 'homomorphic_env', 'identifier'],
        num_rows: 337
    })
    val: Dataset({
        features: ['filename', 'recording', 'recording_label', 'heart_state_labels', 'amplitude_env', 'homomorphic_env', 'identifier'],
        num_rows: 343
    })
})


We will use a patch size of $P$ samples. With $P = 64$ (`PATCH_SIZE` below) and a sampling rate of 50 Hz, each window spans $64/50 = 1.28$ seconds.

We will process the sounds patch-by-patch, so we discard recordings shorter than one patch ($P$ samples).

Run the following code to apply these steps to each dataset.

Note that if you have more than 2 envelograms from Tutorial 1, you need to change `NUMBER_CHANNELS` (passed to the U-Net as `nch`) accordingly.

We recommend starting with `stride=32`, but you can revisit this part of the tutorial later and adapt all parameters to your liking.

In [None]:
PATCH_SIZE = 64
NUMBER_CHANNELS = 2
NUMBER_CLASSES = 4
STRIDE = 32
BATCH_SIZE = 32

Some of the sounds might be too short for the above configuration, so we filter them out before training.

In [None]:
def filter_datasets(dataset_dict: DatasetDict, patch_size: int = 64):
  for split in dataset_dict:
    # Both envelopes have the same length, so filtering on one of them is enough
    dataset_dict[split] = dataset_dict[split].filter(lambda x: len(x['amplitude_env']) >= patch_size)
  return dataset_dict


dataset_dict = filter_datasets(dataset_dict, patch_size=PATCH_SIZE)


The heart state labels are integer-coded (1 to 4). We convert them to a one-hot encoding.
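The conversion below relies on NumPy fancy indexing: row $k-1$ of the identity matrix is the one-hot vector for state $k$, so indexing `np.eye(4)` with the shifted label array encodes a whole sequence at once. A toy example (labels made up for illustration):

```python
import numpy as np

labels = np.array([1, 1, 2, 2, 3, 4, 4, 1])   # S1, S1, systole, systole, S2, diastole, diastole, S1
one_hot = np.eye(4)[labels - 1]               # shape (T, 4); shift 1..4 -> row indices 0..3
print(one_hot.shape)                          # (8, 4)
print(one_hot.argmax(axis=1))                 # [0 0 1 1 2 3 3 0]
```

Note that `argmax` recovers 0-based indices (0 = S1, 1 = systole, 2 = S2, 3 = diastole), which is the convention the predictions and metrics use later on.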

In [None]:
def one_hot_encoding(num_classes: int = 4):
  one_hot_labels = np.eye(num_classes)
  def _one_hot_encoding(example):
    example["heart_state_labels"] = one_hot_labels[np.array(example['heart_state_labels']) - 1]
    return example
  return _one_hot_encoding


for split in dataset_dict:
  dataset_dict[split] = dataset_dict[split].map(one_hot_encoding(num_classes=NUMBER_CLASSES))

The U-Net takes patches of a given size $P$ as input.
Note that our sounds are downsampled to 50 Hz by now (80$\times$ fewer samples than the original 4 kHz). However, materializing all sound patches at once, especially overlapping ones, may have unrealistic memory requirements in most scenarios.

With that in mind, we will have to compute our sound patches dynamically, i.e. online during training or inference.

We make use of the [HuggingFace Dataset API](https://huggingface.co/docs/hub/datasets-overview) to build a streamable dataset that can then be wrapped as a [TensorFlow Dataset](https://www.tensorflow.org/guide/data) to serve our U-Net.
Please inspect the following class `PatchIterableDataset` which will handle all the [ETL](https://tinyurl.com/527fak67) for our deep learning data pipeline.

Generators use the yield statement to produce a series of values, "pausing" the function each time a yield is encountered and resuming the execution in the next iteration. This makes them memory-efficient because they only produce items as needed, unlike lists that store all items in memory. Naturally, this solution is slower than pre-computing all patches.
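A stripped-down version of that idea, on plain NumPy arrays rather than Hugging Face examples: the hypothetical generator below yields the same windows as `PatchIterableDataset.__iter__` (regular windows every `stride` frames, plus a final patch aligned to the end of the recording when the windows do not reach it), but only creates each patch when the consumer asks for it.

```python
import numpy as np

def sliding_patches(sound, patch_size, stride):
    """Yield (start, patch) pairs lazily, mirroring PatchIterableDataset's windowing."""
    num_samples = len(sound)
    num_windows = (num_samples - patch_size) // stride + 1
    for window_idx in range(num_windows):
        start = window_idx * stride
        yield start, sound[start:start + patch_size]
    # Tail patch so that the last frames are also covered
    if (num_samples - patch_size) % stride > 0:
        start = num_samples - patch_size
        yield start, sound[start:]

sound = np.random.default_rng(0).normal(size=(100, 2))   # T=100 frames, 2 envelopes
gen = sliding_patches(sound, patch_size=64, stride=32)
print(next(gen)[0])                                       # 0: only the first patch exists so far
starts = [start for start, _ in gen]                      # consume the rest
print(starts)                                             # [32, 36]
```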

In [None]:
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader

class PatchIterableDataset(IterableDataset):
    """
    Iterates over a Hugging Face dataset, chunking each recording into patches of
    shape (patch_size, 2) with a sliding window.
    Does NOT store the entire dataset in memory; patches are produced on the fly.
    """
    def __init__(self, hf_dataset,
                 patch_size: int,
                 stride: int):
        """
        Args:
            hf_dataset: A Hugging Face (Iterable)Dataset or anything else iterable
                        where each item is a dict with 'amplitude_env' and
                        'homomorphic_env' (length T_i) and one-hot
                        'heart_state_labels' (shape (T_i, 4)).
            patch_size: Number of frames per patch.
            stride: Step (in frames) between consecutive patch starts. If the
                    windows do not reach the end of a recording, one extra patch
                    aligned to its end is also yielded.
        """
        super().__init__()
        self.dataset = hf_dataset
        self.patch_size = patch_size
        self.stride = stride

    def __iter__(self):
        # Stream over the original dataset, example by example:
        for example in self.dataset:
            amp_env, homo_env, label = example['amplitude_env'], example['homomorphic_env'], example['heart_state_labels']
            amp_env = np.array(amp_env)
            homo_env = np.array(homo_env)
            label = np.array(label)
            num_samples = len(homo_env)
            # Combine the two envelopes along a new last dimension => shape: (time, 2)
            sound = np.stack([amp_env, homo_env], axis=-1)
            num_windows = int((num_samples - self.patch_size) / self.stride) + 1
            for window_idx in range(num_windows):
                patch_start = window_idx * self.stride
                yield sound[patch_start:patch_start + self.patch_size, :], label[patch_start: patch_start + self.patch_size, :]

            window_remain = num_samples - self.patch_size
            if window_remain % self.stride > 0:
                yield sound[window_remain:, :], label[window_remain:, :]

In [None]:
patch_dataset_train = PatchIterableDataset(dataset_dict['train'],
                                           patch_size=PATCH_SIZE,
                                           stride=STRIDE)

patch_dataset_dev = PatchIterableDataset(dataset_dict['dev'],
                                           patch_size=PATCH_SIZE,
                                           stride=STRIDE)

patch_dataset_val = PatchIterableDataset(dataset_dict['val'],
                                           patch_size=PATCH_SIZE,
                                           stride=STRIDE)

In [None]:
import tensorflow as tf

def get_tf_dataset(patch_dataset, number_channels, number_classes, patch_size, batch_size, cache=False):
  gen_fn = lambda: ((x, y) for (x,y) in patch_dataset)
  tf_ds = tf.data.Dataset.from_generator(
      generator=gen_fn,
      output_signature=(
          tf.TensorSpec(shape=(patch_size, number_channels), dtype=tf.float32),    # sound shape: (time, channels)
          tf.TensorSpec(shape=(patch_size, number_classes), dtype=tf.float32)       # one-hot label shape: (time, classes)
      )
  )
  if cache:
    tf_ds = tf_ds.cache()
  tf_ds = tf_ds.batch(batch_size)
  tf_ds = tf_ds.prefetch(tf.data.AUTOTUNE)
  return tf_ds


Run the following code to create the TensorFlow Dataset objects using our custom data generator. We use `.cache()` and `.prefetch()` from the [tf.data API](https://www.tensorflow.org/guide/data_performance) to minimize the overhead of computing each batch of data online.

In [None]:
train = get_tf_dataset(patch_dataset=patch_dataset_train,
                       number_channels=NUMBER_CHANNELS,
                       number_classes=NUMBER_CLASSES,
                       patch_size=PATCH_SIZE,
                       batch_size=BATCH_SIZE,
                       cache=True) # caching dev/val as well is possible but needs more memory

dev = get_tf_dataset(patch_dataset=patch_dataset_dev,
                       number_channels=NUMBER_CHANNELS,
                       number_classes=NUMBER_CLASSES,
                       patch_size=PATCH_SIZE,
                       batch_size=BATCH_SIZE)

val = get_tf_dataset(patch_dataset=patch_dataset_val,
                       number_channels=NUMBER_CHANNELS,
                       number_classes=NUMBER_CLASSES,
                       patch_size=PATCH_SIZE,
                       batch_size=BATCH_SIZE)

Introduced by [Ronneberger et al.](https://arxiv.org/abs/1505.04597) in 2015, the U-Net is ubiquitous in biomedical signal and image processing tasks.

Although its efficacy results from several advances in modern deep learning and optimization, one of its characteristic architectural patterns is the skip connection from the encoder to the decoder.

The idea is that fine-grained, high-resolution information from the encoder is important in medical domains and complements the semantically rich features in the decoder.

Inspect and run the template below for a simple U-Net adapted to process signals instead of images, using 1D instead of 2D primitives.

In [None]:
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Dropout, UpSampling1D, concatenate

def example_unet(patch_size, nch, dropout=0.0):
    inputs = tf.keras.layers.Input(shape=(patch_size, nch))
    conv1 = tf.keras.layers.Conv1D(8, 3, activation='relu', padding='same')(inputs)
    conv1 = tf.keras.layers.Conv1D(8, 3, activation='relu', padding='same')(conv1)
    pool1 = tf.keras.layers.MaxPooling1D(pool_size=2)(conv1)
    pool1 = tf.keras.layers.Dropout(dropout)(pool1)

    conv2 = tf.keras.layers.Conv1D(16, 3, activation='relu', padding='same')(pool1)
    conv2 = tf.keras.layers.Conv1D(16, 3, activation='relu', padding='same')(conv2)

    up_prep = tf.keras.layers.UpSampling1D(size=2)(conv2)

    up = tf.keras.layers.concatenate([tf.keras.layers.Conv1D(8, 2, padding='same')(up_prep), conv1], axis=2)
    up = tf.keras.layers.Dropout(dropout)(up)
    convout = tf.keras.layers.Conv1D(8, 3, activation='relu', padding='same')(up)
    convout = tf.keras.layers.Conv1D(8, 3, activation='relu', padding='same')(convout)

    output_layer = tf.keras.layers.Conv1D(4, 1, activation='softmax')(convout)

    model = tf.keras.Model(inputs=[inputs], outputs=[output_layer])
    return model


In [None]:
model = example_unet(PATCH_SIZE, NUMBER_CHANNELS)
model.summary()  # prints the summary itself (the method returns None)

Instantiate your U-Net model and select a set of adequate metrics to track.

Model selection on the dev set is done with `ModelCheckpoint`, which keeps only the weights with the lowest `val_loss`. Note that this is not early stopping: training still runs for all epochs (Keras' `EarlyStopping` callback would stop it).

We will use the Adam optimizer.
If you have access to a GPU, run this code for 50 epochs.
If not, 10 epochs should still give you a sufficiently capable model to complete this tutorial.
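For reference, a single Adam update written out in NumPy. This is a sketch of the textbook update rule; the hyperparameter values mirror Keras' defaults ($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-7}$) and the learning rate matches the cell below, but Keras' internal implementation may differ in small details. One consequence worth knowing: thanks to bias correction, the very first step moves each parameter by roughly the learning rate, opposite to the sign of its gradient, regardless of the gradient's scale.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-7):
    """One Adam update; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first moment (running mean of gradients)
    v = beta2 * v + (1 - beta2) * grad ** 2     # second moment (running mean of squared gradients)
    m_hat = m / (1 - beta1 ** t)                # bias corrections for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta0 = np.array([1.0, -2.0, 0.5])
grad = np.array([100.0, -0.01, 3.0])            # gradients on very different scales
theta1, m, v = adam_step(theta0, grad, np.zeros(3), np.zeros(3), t=1)
print(theta1 - theta0)                          # each entry is about -1e-4 * sign(grad)
```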

In [None]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint

EPOCHS = 50 # your epochs
learning_rate = 1e-4  # your lr
model =  example_unet(PATCH_SIZE, NUMBER_CHANNELS)  # use your unet_here
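checkpoint_path = 'unet_weights/unet.keras'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)  # make sure the checkpoint folder exists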
model.compile(optimizer=Adam(learning_rate=learning_rate), loss='categorical_crossentropy',
                  metrics=['categorical_accuracy', 'precision', 'recall'])
model_checkpoint = ModelCheckpoint(filepath=checkpoint_path,
                                   monitor='val_loss',
                                   save_best_only=True)

In [None]:
history = model.fit(train,
                    validation_data=dev,
                    epochs=EPOCHS,
                    verbose=1,
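                    # note: `shuffle` is ignored when the input is a tf.data.Dataset;
                    # to shuffle patches, call `.shuffle()` on the dataset before batching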
                    callbacks=[model_checkpoint])

model.load_weights(checkpoint_path)

Let us perform inference on the external validation dataset.

In [None]:
predictions_train = model.predict(train)
predictions_dev = model.predict(dev)
predictions_test = model.predict(val)



`predictions_test` holds the patch-wise predictions on the test set in a recording-agnostic fashion: nothing in the array says which recording (or patient) a patch came from.

We now need to map these outputs back to their recordings. Moreover, because we used an overlapping sliding window, we have several heart state estimates for a given time $t$.

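We average the U-Net's predictions over all patches that cover each time step. The strategy mirrors the `__iter__` method of `PatchIterableDataset`: recording $i$ of length $T_i$ yields patches starting at $s_k = k\tau$ for $k = 0, \ldots, \lfloor \frac{T_i - P}{\tau} \rfloor$, plus one extra patch starting at $T_i - P$ when $(T_i - P) \bmod \tau > 0$, for a total of $N_{P_i} = \lceil \frac{T_i - P}{\tau} \rceil + 1$ patches. Let $\mathcal{K}_i(t) = \{k : s_k \le t < s_k + P\}$ be the set of patches covering time $t$. Then

$$\tilde{\mathbf{y}}_i(t) = \frac{1}{|\mathcal{K}_i(t)|}\sum_{k \in \mathcal{K}_i(t)} \text{U-Net}\left(\mathbf{x}_i^{(k)} \mid \theta^*\right)(t - s_k)$$

where $\mathbf{x}_i^{(k)}$ is the $k$-th patch of recording $i$ and $\text{U-Net}(\cdot)(t - s_k)$ is the output row at position $t - s_k$ within that patch. So $\tilde{\mathbf{y}}_i$ should have shape $(T_i, 4)$.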

Inspect the function `process_unet_predictions`. It returns the averaged probability estimates and the corresponding predicted label sequence (as integer class indices 0 to 3 from `argmax`, not one-hot-encoded).
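A quick way to validate this kind of stitching is a round trip: cut a known per-frame probability sequence into patches with the same windowing as `PatchIterableDataset`, average them back, and check that the sequence is recovered exactly. The NumPy sketch below does that on random data (the helper names are hypothetical); if the patch bookkeeping were off by one, the round trip would fail.

```python
import math
import numpy as np

def patch_starts(num_samples, patch_size, stride):
    """Patch start indices in the same order as PatchIterableDataset.__iter__."""
    starts = [k * stride for k in range((num_samples - patch_size) // stride + 1)]
    if (num_samples - patch_size) % stride > 0:
        starts.append(num_samples - patch_size)   # tail patch aligned to the end
    return starts

def stitch(patch_preds, starts, num_samples):
    """Average overlapping patch predictions back onto the full time axis."""
    num_classes = patch_preds.shape[-1]
    total = np.zeros((num_samples, num_classes))
    counts = np.zeros((num_samples, 1))
    for start, pred in zip(starts, patch_preds):
        total[start:start + len(pred)] += pred
        counts[start:start + len(pred)] += 1
    return total / counts

rng = np.random.default_rng(0)
T, P, tau = 100, 64, 32
probs = rng.dirichlet(np.ones(4), size=T)              # (T, 4) rows summing to 1
starts = patch_starts(T, P, tau)
patches = np.stack([probs[s:s + P] for s in starts])   # what a "perfect" U-Net would output
recovered = stitch(patches, starts, T)

print(len(starts), math.ceil((T - P) / tau) + 1)       # 3 3
print(np.allclose(recovered, probs))                   # True
```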

In [None]:
def process_unet_predictions(preds, dataframe, patch_size, stride, num_classes=4):
  num_observations = len(dataframe)
  output_probs = np.empty(num_observations, dtype=object)
  output_seqs = np.empty(num_observations, dtype=object)
  preds_idx = 0
  for idx, sample in tqdm(enumerate(dataframe), total=num_observations):
    sound_duration = len(sample['heart_state_labels'])
    # Patch start positions, in the same order as PatchIterableDataset.__iter__:
    # regular windows every `stride` frames, plus a tail patch aligned to the end
    num_windows = (sound_duration - patch_size) // stride + 1
    starts = [w * stride for w in range(num_windows)]
    if (sound_duration - patch_size) % stride > 0:
      starts.append(sound_duration - patch_size)
    # Accumulate the softmax outputs and count how many patches cover each frame
    prob_sum = np.zeros((sound_duration, num_classes))
    counts = np.zeros((sound_duration, 1))
    for start in starts:
      prob_sum[start:start + patch_size, :] += preds[preds_idx, :, :]
      counts[start:start + patch_size] += 1
      preds_idx += 1
    # Average over the covering patches (guard against uncovered frames if stride > patch_size)
    probs = tf.convert_to_tensor(prob_sum / np.maximum(counts, 1))
    output_probs[idx] = probs
    output_seqs[idx] = tf.argmax(probs, axis=1)
  return output_probs, output_seqs

In [None]:
probs, predictions = process_unet_predictions(predictions_test, dataset_dict['val'], PATCH_SIZE, STRIDE)



The U-Net outputs estimates patch-by-patch, meaning it can output invalid heart sequences, e.g. going from S1 directly to S2 (we are assuming that screening is not interrupted here).

Implement a decoding function that takes the sequence you recovered with `process_unet_predictions` and post-processes it into a valid heart-state sequence. You can correct the output according to any criteria you want.

The function you implement should pass the unit test in the cell below.
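One possible rule, shown only as a sketch (the function name is hypothetical and other criteria are equally valid): with 0-based states (0 = S1, 1 = systole, 2 = S2, 3 = diastole), a valid sequence may only stay in the same state or advance to the next one, $s \to (s + 1) \bmod 4$. Any frame that would break this is replaced by the previous (already corrected) state. This simple hold rule reproduces the expected output of the unit test.

```python
import numpy as np

def hold_on_invalid_transition(seq, num_states=4):
    """Replace frames that break the cyclic S1 -> systole -> S2 -> diastole order."""
    out = np.asarray(seq).copy()
    for t in range(1, len(out)):
        prev, cur = out[t - 1], out[t]
        # Valid: stay in the same state or advance to the next state in the cycle
        if cur != prev and cur != (prev + 1) % num_states:
            out[t] = prev
    return out

test_seq = np.array([1, 1, 1, 2, 1, 1, 2, 2, 3, 3, 0, 3])
print(hold_on_invalid_transition(test_seq))   # [1 1 1 2 2 2 2 2 3 3 0 0]
```

A greedy hold like this can get stuck in a wrong state if the model skips a state entirely. A more principled alternative is Viterbi decoding with a transition matrix that forbids invalid moves, using the averaged probabilities from `process_unet_predictions` as emission scores.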

In [None]:
def your_deconding_function(seq, num_states=4):
  """your code here"""
  return seq

test_seq = np.array([1, 1, 1, 2, 1, 1, 2, 2, 3, 3, 0, 3])
exp_seq = np.array([1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 0, 0])

out_seq = your_deconding_function(test_seq)
assert np.all(exp_seq == out_seq)

In [None]:
from tqdm import tqdm
test_labels = dataset_dict['val']['heart_state_labels']
ground_truth = np.array([np.argmax(y, axis=1) for y in tqdm(test_labels)], dtype=object)
predictions = np.array([your_deconding_function(prediction.numpy()) for prediction in predictions], dtype=object)

We will follow [Schmidt et al.](https://iopscience.iop.org/article/10.1088/0967-3334/31/4/004/pdf). A detected sound is a true positive (TP), i.e. correctly located, if its midpoint is closer than 60 ms to the midpoint of the corresponding reference (annotated) sound; all other detected sounds are false positives (FP).

Sensitivity is defined as:
\begin{equation}
\text{Sensitivity} = \frac{\text{number of TP sounds}}{\text{total number of S1 and S2 sounds}}
\end{equation}

and positive predictivity ($P_+$):
\begin{equation}
P_+ = \frac{\text{number of TP sounds}}{\text{number of TP sounds} + \text{number of FP sounds}}.
\end{equation}

These metrics are tricky to implement, so we provide them beforehand (the authors may or may not have had a little help from o1-preview for this step :))
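Before the provided implementation, a simplified toy sketch of the matching rule may help. It takes midpoints (in samples) of annotated and detected sounds, greedily pairs each detection with the nearest unmatched annotation, and counts it as TP only if the two midpoints are closer than the tolerance. The function name, the greedy one-to-one matching and the strict "closer than" comparison are assumptions of this sketch; it ignores sound type and is not the code below. Keep in mind that at 50 Hz a 60 ms tolerance is only 3 samples.

```python
import numpy as np

def match_midpoints(gt_midpoints, pred_midpoints, max_distance):
    """Greedy one-to-one matching of detected to annotated sound midpoints."""
    gt = np.asarray(gt_midpoints, dtype=float)
    matched = np.zeros(len(gt), dtype=bool)
    tp = 0
    for p in pred_midpoints:
        if len(gt) == 0:
            break
        dist = np.abs(gt - p)
        dist[matched] = np.inf                 # each annotation can be matched once
        j = int(np.argmin(dist))
        if dist[j] < max_distance:             # "closer than" the tolerance
            matched[j] = True
            tp += 1
    fp = len(pred_midpoints) - tp
    sensitivity = tp / len(gt) if len(gt) else 0.0
    ppv = tp / (tp + fp) if (tp + fp) else 0.0
    return tp, fp, sensitivity, ppv

# Toy example at 50 Hz: tolerance of 60 ms = 3 samples
gt_mid = [10, 50, 90]
pred_mid = [11, 48, 70, 92]
print(match_midpoints(gt_mid, pred_mid, max_distance=3))   # (3, 1, 1.0, 0.75)
```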

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score
def extract_state_runs(labels, desired_states):
    """
    Extract continuous runs of the desired states from labels.

    Args:
        labels: numpy array of labels.
        desired_states: set of desired state values.

    Returns:
        A list of dictionaries with keys:
            'start': start index of the run
            'end': end index of the run (inclusive)
            'midpoint': midpoint index of the run
            'state': the state value (0 or 2)
    """
    # Ensure labels is a 1D array
    labels = np.asarray(labels).flatten()

    runs = []
    N = len(labels)
    in_run = False
    run_start = 0
    run_state = None

    for i in range(N):
        label_i = labels[i]
        # If label_i is an array (e.g., from a structured array), extract scalar
        if isinstance(label_i, np.ndarray):
            label_i = label_i.item()
        if in_run and label_i != run_state:
            # End of the run: the label left the run's state. This also splits
            # directly adjacent runs of different desired states (e.g. 0 then 2)
            run_end = i - 1
            midpoint = (run_start + run_end) // 2
            runs.append({
                'start': run_start,
                'end': run_end,
                'midpoint': midpoint,
                'state': run_state
            })
            in_run = False
            run_state = None
        if label_i in desired_states and not in_run:
            # Start of a new run
            in_run = True
            run_start = i
            run_state = label_i
    # Check if we're still in a run at the end
    if in_run:
        run_end = N - 1
        midpoint = (run_start + run_end) // 2
        runs.append({
            'start': run_start,
            'end': run_end,
            'midpoint': midpoint,
            'state': run_state
        })
    return runs

def compute_ppv_sensitivity(ground_truth, predictions, sample_rate, threshold=60e-3):
    """
    Compute PPV and sensitivity for states 0 (S1) and 2 (S2).

    Args:
        ground_truth: numpy array of ground truth labels.
        predictions: numpy array of predicted labels.
        sample_rate: sampling rate in Hz.
        threshold: maximum distance between matched midpoints, in seconds (default 60 ms).

    Returns:
        ppv: Positive Predictive Value.
        sensitivity: Sensitivity (Recall).
    """
    # Ensure ground_truth and predictions are 1D arrays
    ground_truth = np.asarray(ground_truth).flatten()
    predictions = np.asarray(predictions).flatten()

    # Desired states: 0 (S1) and 2 (S2)
    desired_states = {0, 2}

    # Maximum distance in samples (threshold in seconds times fs);
    # round() guards against floating-point truncation before int()
    max_distance_samples = int(round(threshold * sample_rate))

    # Extract runs from ground truth and predictions
    gt_runs = extract_state_runs(ground_truth, desired_states)
    pred_runs = extract_state_runs(predictions, desired_states)

    # Get midpoints and states
    gt_midpoints = np.array([run['midpoint'] for run in gt_runs])
    gt_states = np.array([run['state'] for run in gt_runs])

    pred_midpoints = np.array([run['midpoint'] for run in pred_runs])
    pred_states = np.array([run['state'] for run in pred_runs])

    # Initialize matches
    matched_gt_indices = set()
    matched_pred_indices = set()

    # Build potential matches
    potential_matches = []
    for i, (p_mid, p_state) in enumerate(zip(pred_midpoints, pred_states)):
        for j, (gt_mid, gt_state) in enumerate(zip(gt_midpoints, gt_states)):
            if gt_state == p_state:
                distance = abs(p_mid - gt_mid)
                if distance <= max_distance_samples:
                    potential_matches.append((i, j, distance))

    # Sort potential matches by distance
    potential_matches.sort(key=lambda x: x[2])

    # Greedy one-to-one matching: accept the closest pairs first
    TP = 0
    for i, j, d in potential_matches:
        if i not in matched_pred_indices and j not in matched_gt_indices:
            matched_pred_indices.add(i)
            matched_gt_indices.add(j)
            TP += 1

    # Compute FP and FN
    total_pred = len(pred_midpoints)
    total_gt = len(gt_midpoints)
    FP = total_pred - len(matched_pred_indices)
    FN = total_gt - len(matched_gt_indices)

    # Compute PPV and Sensitivity
    ppv = TP / (TP + FP) if (TP + FP) > 0 else 0
    sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0

    return ppv, sensitivity
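To see what `extract_state_runs` computes without the explicit state machine, here is a hypothetical vectorized sketch (`state_runs_vectorized` is our own name, not part of the tutorial). It finds the positions where the label changes with `np.diff`, treats each constant segment as one run, and keeps only the segments whose state is 0 (S1) or 2 (S2), returning `(start, end, midpoint, state)` tuples.

```python
import numpy as np


def state_runs_vectorized(labels, desired_states=(0, 2)):
    """Hypothetical vectorized alternative to extract_state_runs."""
    labels = np.asarray(labels).ravel()
    if labels.size == 0:
        return []
    # Index of the first sample of every new constant segment
    change = np.flatnonzero(np.diff(labels)) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change - 1, [labels.size - 1]))  # inclusive ends
    runs = []
    for s, e in zip(starts, ends):
        state = labels[s].item()
        # Keep only segments belonging to the states of interest
        if state in desired_states:
            runs.append((int(s), int(e), int((s + e) // 2), state))
    return runs


print(state_runs_vectorized([0, 0, 1, 1, 2, 2, 2, 3, 3, 0]))
# [(0, 1, 0, 0), (4, 6, 5, 2), (9, 9, 9, 0)]
```

The loop-based version in the tutorial is easier to step through; the vectorized form only replaces the per-sample scan with array operations.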

In [None]:
def compute_schmidt_metrics(ground_truth, sequences, sampling_rate):
  ppvs, sensitivities, accuracies = [], [], []
  for i in tqdm(range(len(ground_truth))):
    ppv, sensitivity = compute_ppv_sensitivity(ground_truth[i],
                                               sequences[i],
                                               sampling_rate)
    ppvs.append(ppv)
    sensitivities.append(sensitivity)
    accuracies.append(accuracy_score(ground_truth[i], sequences[i]))
  return np.array(ppvs), np.array(sensitivities), np.array(accuracies)


ppv, sens, acc = compute_schmidt_metrics(ground_truth, predictions, sampling_rate=50)

In [None]:
np.mean(ppv), np.mean(sens), np.mean(acc)

Inspect one recording where the model performed well and one where it performed poorly. Discuss the differences.
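One way to pick the two recordings is to rank them by a per-recording score. The sketch below is a hypothetical helper (`best_and_worst` is our name); which score to rank by is your choice, e.g. `sens`, `ppv`, or their average.

```python
import numpy as np


def best_and_worst(scores):
    """Return (index of highest score, index of lowest score); hypothetical helper."""
    scores = np.asarray(scores, dtype=float)
    # nanargmax/nanargmin skip NaN entries, should any appear
    return int(np.nanargmax(scores)), int(np.nanargmin(scores))


best_idx, worst_idx = best_and_worst([0.5, 0.9, 0.1])
print(best_idx, worst_idx)  # 1 2
```

With the arrays computed above you might call `best_and_worst((ppv + sens) / 2)` and pass each returned index to `visualize_predictions`.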

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def visualize_predictions(ground_truth, seqs, idx, sample_rate=50, threshold=60e-3):

  # Express the 60 ms tolerance window of the metric in samples
  sample_interval = 1 / sample_rate  # Time per sample in seconds (20 ms per sample at 50 Hz)
  window_width_samples = threshold / sample_interval  # 60 ms / 20 ms = 3 samples

  # Create the plot
  plt.figure(figsize=(24, 6))

  # Plot ground truth and predictions with discrete markers and dashed lines
  plt.plot(ground_truth[idx], 'o--', label='Ground Truth', markersize=6)
  plt.plot(seqs[idx], 'o--', color='red', label='Predictions', markersize=6)

  # Shade the +/- 60 ms window around each annotated S1/S2 midpoint:
  # a predicted sound of the same state can only be a TP if its midpoint falls inside
  for run in extract_state_runs(ground_truth[idx], {0, 2}):
    plt.axvspan(run['midpoint'] - window_width_samples,
                run['midpoint'] + window_width_samples,
                color='gray', alpha=0.2)

  # Set labels and legend
  plt.title(f'Signal at idx {idx} with {threshold * 1e3:.0f} ms reference window')
  plt.xlabel('Sample Index')
  plt.ylabel('Heart State')
  plt.legend()
  plt.grid(True)
  plt.show()



In [None]:
idx = 0
visualize_predictions(ground_truth, predictions, idx)
ppv[idx], sens[idx], acc[idx]

In [None]:
idx = 100
visualize_predictions(ground_truth, predictions, idx)
ppv[idx], sens[idx], acc[idx]

Although our model is far from perfect, we can still see that it is indeed learning the rhythmic characteristics of the heart sound sequence.

What would you do to improve this preliminary model?


* Do you think the filtering and feature-extraction stage plays a pivotal role? What if we used only the original waveform, without envelograms?
* We know how heart sounds behave statistically (at least during screening): the sequence of heart states has a strong Markovian character. Could we use this to make our model aware of that process?
* Could we decode the heart states in a smarter way?
* What architectural changes could be implemented? How about recurrent networks or Transformers?
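To get a feel for the Markovian structure mentioned in the second question, one could estimate an empirical state-transition matrix from the ground-truth label sequences. This is a minimal sketch under stated assumptions: four heart states labeled 0 to 3 (we assume 0 = S1, 1 = systole, 2 = S2, 3 = diastole, consistent with S1/S2 being states 0 and 2 in the metrics above), and `estimate_transition_matrix` is a hypothetical helper, not part of the tutorial.

```python
import numpy as np


def estimate_transition_matrix(sequences, n_states=4):
    """Row-normalized counts of consecutive state pairs (hypothetical helper).

    Assumes integer labels in [0, n_states); n_states=4 is an assumption.
    """
    counts = np.zeros((n_states, n_states))
    for seq in sequences:
        seq = np.asarray(seq, dtype=int)
        # Count every consecutive (from_state, to_state) pair in this sequence
        np.add.at(counts, (seq[:-1], seq[1:]), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    # Normalize each row to a probability distribution; unvisited rows stay zero
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)


print(estimate_transition_matrix([[0, 0, 1, 2, 3, 0]]))
```

Applied to this notebook's data, `estimate_transition_matrix(ground_truth)` would show how strongly each state tends to persist or hand over to the next one, which is the kind of prior a smarter decoder could exploit.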

Feel free to reach out if you want to discuss these questions!
Contacts can be found at: https://miguelmartins.github.io/


