diff --git a/lib/ogg/include/ogg/os_types.h b/lib/ogg/include/ogg/os_types.h
index e655a1d6..9d2da360 100644
--- a/lib/ogg/include/ogg/os_types.h
+++ b/lib/ogg/include/ogg/os_types.h
@@ -16,11 +16,16 @@
 #ifndef _OS_TYPES_H
 #define _OS_TYPES_H
 
+#include "esp_heap_caps.h"
+
+#define OGG_MALLOC_CAPS_1 (MALLOC_CAP_INTERNAL | MALLOC_CAP_8BIT)
+#define OGG_MALLOC_CAPS_2 (MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT)
+
 /* make it easy on the folks that want to compile the libs with a
    different malloc than stdlib */
-#define _ogg_malloc  malloc
-#define _ogg_calloc  calloc
-#define _ogg_realloc realloc
+#define _ogg_malloc(x) heap_caps_malloc_prefer(x, 2, OGG_MALLOC_CAPS_1, OGG_MALLOC_CAPS_2) /* 2 = number of caps sets that follow */
+#define _ogg_calloc(x, y)  heap_caps_calloc_prefer(x, y, 2, OGG_MALLOC_CAPS_1, OGG_MALLOC_CAPS_2)
+#define _ogg_realloc(x, y) heap_caps_realloc_prefer(x, y, 2, OGG_MALLOC_CAPS_1, OGG_MALLOC_CAPS_2)
 #define _ogg_free    free
 
 #if defined(_WIN32)
diff --git a/src/audio/audio_decoder.cpp b/src/audio/audio_decoder.cpp
index 86394a37..78b0d3f5 100644
--- a/src/audio/audio_decoder.cpp
+++ b/src/audio/audio_decoder.cpp
@@ -119,6 +119,8 @@ void Decoder::Main() {
 }
 
 auto Decoder::BeginDecoding(std::shared_ptr<codecs::IStream> stream) -> bool {
+  // Ensure any previous codec is freed before creating a new one.
+  codec_.reset();
   codec_.reset(codecs::CreateCodecForType(stream->type()).value_or(nullptr));
   if (!codec_) {
     ESP_LOGE(kTag, "no codec found");
@@ -164,6 +166,10 @@ auto Decoder::ContinueDecoding() -> bool {
     timer_->AddSamples(res->samples_written);
   }
 
+  if (res->is_stream_finished) {
+    codec_.reset();
+  }
+
   return res->is_stream_finished;
 }
 
diff --git a/src/codecs/include/mad.hpp b/src/codecs/include/mad.hpp
index ef4e91f8..994e989c 100644
--- a/src/codecs/include/mad.hpp
+++ b/src/codecs/include/mad.hpp
@@ -38,16 +38,16 @@ class MadMp3Decoder : public ICodec {
   MadMp3Decoder& operator=(const MadMp3Decoder&) = delete;
 
  private:
-  auto SkipID3Tags(IStream &stream) -> void;
+  auto SkipID3Tags(IStream& stream) -> void;
   auto GetVbrLength(const mad_header& header) -> std::optional<uint32_t>;
   auto GetBytesUsed() -> std::size_t;
 
   std::shared_ptr<IStream> input_;
   SourceBuffer buffer_;
 
-  mad_stream stream_;
-  mad_frame frame_;
-  mad_synth synth_;
+  mad_stream* stream_;
+  mad_frame* frame_;
+  mad_synth* synth_;
 
   int current_sample_;
   bool is_eof_;
diff --git a/src/codecs/mad.cpp b/src/codecs/mad.cpp
index 224f7391..6e1546d8 100644
--- a/src/codecs/mad.cpp
+++ b/src/codecs/mad.cpp
@@ -12,6 +12,7 @@
 #include <cstring>
 #include <optional>
 
+#include "esp_heap_caps.h"
 #include "mad.h"
 
 #include "codec.hpp"
@@ -24,23 +25,44 @@ namespace codecs {
 
 static constexpr char kTag[] = "mad";
 
+static constexpr uint32_t kMallocCapsPrefer =
+    MALLOC_CAP_INTERNAL | MALLOC_CAP_8BIT;
+static constexpr uint32_t kMallocCapsFallback =
+    MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT;
+
 MadMp3Decoder::MadMp3Decoder()
-    : input_(), buffer_(), current_sample_(-1), is_eof_(false), is_eos_(false) {
-  mad_stream_init(&stream_);
-  mad_frame_init(&frame_);
-  mad_synth_init(&synth_);
+    : input_(),
+      buffer_(),
+      stream_(static_cast<mad_stream*>(  // '2' = number of caps that follow.
+          heap_caps_malloc_prefer(sizeof(mad_stream), 2, kMallocCapsPrefer,
+                                  kMallocCapsFallback))),
+      frame_(static_cast<mad_frame*>(
+          heap_caps_malloc_prefer(sizeof(mad_frame), 2, kMallocCapsPrefer,
+                                  kMallocCapsFallback))),
+      synth_(static_cast<mad_synth*>(
+          heap_caps_malloc_prefer(sizeof(mad_synth), 2, kMallocCapsPrefer,
+                                  kMallocCapsFallback))),
+      current_sample_(-1),
+      is_eof_(false),
+      is_eos_(false) {
+  mad_stream_init(stream_);
+  mad_frame_init(frame_);
+  mad_synth_init(synth_);
 }
 MadMp3Decoder::~MadMp3Decoder() {
-  mad_stream_finish(&stream_);
-  mad_frame_finish(&frame_);
-  mad_synth_finish(&synth_);
+  mad_stream_finish(stream_);
+  mad_frame_finish(frame_);
+  mad_synth_finish(synth_);
+  heap_caps_free(stream_);  // The structs are heap-allocated in the ctor;
+  heap_caps_free(frame_);   // mad_*_finish() does not free the struct
+  heap_caps_free(synth_);   // storage itself, so release it here.
 }
 
 auto MadMp3Decoder::GetBytesUsed() -> std::size_t {
-  if (stream_.next_frame) {
-    return stream_.next_frame - stream_.buffer;
+  if (stream_->next_frame) {
+    return stream_->next_frame - stream_->buffer;
   } else {
-    return stream_.bufend - stream_.buffer;
+    return stream_->bufend - stream_->buffer;
   }
 }
 
@@ -60,17 +82,17 @@ auto MadMp3Decoder::OpenStream(std::shared_ptr<IStream> input)
     eof = buffer_.Refill(input_.get());
 
     buffer_.ConsumeBytes([&](cpp::span<std::byte> buf) -> size_t {
-      mad_stream_buffer(&stream_,
+      mad_stream_buffer(stream_,
                         reinterpret_cast<const unsigned char*>(buf.data()),
                         buf.size_bytes());
 
-      while (mad_header_decode(&header, &stream_) < 0) {
-        if (MAD_RECOVERABLE(stream_.error)) {
+      while (mad_header_decode(&header, stream_) < 0) {
+        if (MAD_RECOVERABLE(stream_->error)) {
           // Recoverable errors are usually malformed parts of the stream.
           // We can recover from them by just retrying the decode.
           continue;
         }
-        if (stream_.error == MAD_ERROR_BUFLEN) {
+        if (stream_->error == MAD_ERROR_BUFLEN) {
           return GetBytesUsed();
         }
         eof = true;
@@ -120,19 +142,19 @@ auto MadMp3Decoder::DecodeTo(cpp::span<sample::Sample> output)
     }
 
     buffer_.ConsumeBytes([&](cpp::span<std::byte> buf) -> size_t {
-      mad_stream_buffer(&stream_,
+      mad_stream_buffer(stream_,
                         reinterpret_cast<const unsigned char*>(buf.data()),
                         buf.size());
 
       // Decode the next frame. To signal errors, this returns -1 and
       // stashes an error code in the stream structure.
-      while (mad_frame_decode(&frame_, &stream_) < 0) {
-        if (MAD_RECOVERABLE(stream_.error)) {
+      while (mad_frame_decode(frame_, stream_) < 0) {
+        if (MAD_RECOVERABLE(stream_->error)) {
           // Recoverable errors are usually malformed parts of the stream.
           // We can recover from them by just retrying the decode.
           continue;
         }
-        if (stream_.error == MAD_ERROR_BUFLEN) {
+        if (stream_->error == MAD_ERROR_BUFLEN) {
           if (is_eof_) {
             is_eos_ = true;
           }
@@ -146,7 +168,7 @@ auto MadMp3Decoder::DecodeTo(cpp::span<sample::Sample> output)
 
       // We've successfully decoded a frame! Now synthesize samples to write
       // out.
-      mad_synth_frame(&synth_, &frame_);
+      mad_synth_frame(synth_, frame_);
       current_sample_ = 0;
       return GetBytesUsed();
     });
@@ -154,16 +176,16 @@ auto MadMp3Decoder::DecodeTo(cpp::span<sample::Sample> output)
 
   size_t output_sample = 0;
   if (current_sample_ >= 0) {
-    while (current_sample_ < synth_.pcm.length) {
-      if (output_sample + synth_.pcm.channels >= output.size()) {
+    while (current_sample_ < synth_->pcm.length) {
+      if (output_sample + synth_->pcm.channels >= output.size()) {
         // We can't fit the next full frame into the buffer.
         return OutputInfo{.samples_written = output_sample,
                           .is_stream_finished = false};
       }
 
-      for (int channel = 0; channel < synth_.pcm.channels; channel++) {
+      for (int channel = 0; channel < synth_->pcm.channels; channel++) {
         output[output_sample++] =
-            sample::FromMad(synth_.pcm.samples[channel][current_sample_]);
+            sample::FromMad(synth_->pcm.samples[channel][current_sample_]);
       }
       current_sample_++;
     }
@@ -212,13 +234,13 @@ auto MadMp3Decoder::SkipID3Tags(IStream& stream) -> void {
  */
 auto MadMp3Decoder::GetVbrLength(const mad_header& header)
     -> std::optional<uint32_t> {
-  if (!stream_.this_frame || !stream_.next_frame ||
-      stream_.next_frame <= stream_.this_frame ||
-      (stream_.next_frame - stream_.this_frame) < 48) {
+  if (!stream_->this_frame || !stream_->next_frame ||
+      stream_->next_frame <= stream_->this_frame ||
+      (stream_->next_frame - stream_->this_frame) < 48) {
     return {};
   }
 
-  int mpeg_version = (stream_.this_frame[1] >> 3) & 0x03;
+  int mpeg_version = (stream_->this_frame[1] >> 3) & 0x03;
 
   int xing_offset = 0;
   switch (mpeg_version) {
@@ -244,17 +266,17 @@ auto MadMp3Decoder::GetVbrLength(const mad_header& header)
   uint32_t frames_count = 0;
   // TODO(jacqueline): we should also look up any toc fields here, to make
   // seeking faster.
-  if (std::memcmp(stream_.this_frame + xing_offset, "Xing", 4) == 0 ||
-      std::memcmp(stream_.this_frame + xing_offset, "Info", 4) == 0) {
+  if (std::memcmp(stream_->this_frame + xing_offset, "Xing", 4) == 0 ||
+      std::memcmp(stream_->this_frame + xing_offset, "Info", 4) == 0) {
     /* Xing header to get the count of frames for VBR */
-    frames_count_raw = stream_.this_frame + xing_offset + 8;
+    frames_count_raw = stream_->this_frame + xing_offset + 8;
     frames_count = ((uint32_t)frames_count_raw[0] << 24) +
                    ((uint32_t)frames_count_raw[1] << 16) +
                    ((uint32_t)frames_count_raw[2] << 8) +
                    ((uint32_t)frames_count_raw[3]);
-  } else if (std::memcmp(stream_.this_frame + xing_offset, "VBRI", 4) == 0) {
+  } else if (std::memcmp(stream_->this_frame + xing_offset, "VBRI", 4) == 0) {
     /* VBRI header to get the count of frames for VBR */
-    frames_count_raw = stream_.this_frame + xing_offset + 14;
+    frames_count_raw = stream_->this_frame + xing_offset + 14;
     frames_count = ((uint32_t)frames_count_raw[0] << 24) +
                    ((uint32_t)frames_count_raw[1] << 16) +
                    ((uint32_t)frames_count_raw[2] << 8) +