
Support diffusion models: Add Dream 7B #14644


Merged
merged 11 commits into ggml-org:master from add_dream_lm on Jul 16, 2025

Conversation

am17an
Collaborator

@am17an am17an commented Jul 12, 2025

This PR adds basic support for diffusion models, using Dream 7B Instruct as the base. DiffuCoder-7B is built on the same architecture, so it should be trivial to add after this.

It introduces a diffusion step which samples the entire context on the CPU and unmasks tokens following a timestep regime. At the moment it is slow because we don't do sampling on the GPU, so this model is currently CPU-bound. Additionally, the context is limited to 2048, which may not make it useful for much beyond educational purposes. Added diffusion-cli as a basic example.
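
For readers unfamiliar with the approach, here is a rough sketch (not this PR's actual code) of what such a diffusion step can look like on top of the public llama.cpp API. sample_at() and tokens_for_this_timestep() are hypothetical helpers standing in for the confidence-based sampler and the timestep schedule; needs <algorithm>, <vector> and llama.h.

// rough sketch of a confidence-based diffusion loop, assuming ctx is an initialized
// llama_context, mask_id is the vocab's mask token and prompt holds the prompt tokens
std::vector<llama_token> seq(prompt.begin(), prompt.end());
seq.resize(max_length, mask_id);                          // pad the context with mask tokens
std::vector<llama_token> candidate(max_length);

llama_batch batch = llama_batch_init(max_length, 0, 1);
for (int step = 0; step < n_steps; ++step) {
    llama_memory_clear(llama_get_memory(ctx), true);      // no KV reuse between steps

    batch.n_tokens = max_length;
    for (int i = 0; i < max_length; ++i) {
        batch.token[i]     = seq[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = true;                        // logits needed at every position
    }
    if (llama_decode(ctx, batch) != 0) break;

    const float * logits = llama_get_logits(ctx);         // (max_length x n_vocab), on the CPU

    // sample a candidate for every still-masked position ...
    std::vector<std::pair<float, int>> conf_pos;
    for (int i = 0; i < max_length; ++i) {
        if (seq[i] != mask_id) continue;
        float conf;
        candidate[i] = sample_at(logits, i, &conf);       // hypothetical per-position sampler
        conf_pos.push_back({conf, i});
    }

    // ... and unmask the most confident ones for this timestep
    std::sort(conf_pos.rbegin(), conf_pos.rend());
    const int n_unmask = tokens_for_this_timestep(step);  // hypothetical schedule
    for (int k = 0; k < n_unmask && k < (int) conf_pos.size(); ++k) {
        seq[conf_pos[k].second] = candidate[conf_pos[k].second];
    }
}
llama_batch_free(batch);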

If there is interest, I can spend some time implementing sampling on the GPU. Theoretically, diffusion models can outperform AR models, and very high tokens/sec speeds have been advertised, so I think it makes sense to add support for them.

I uploaded two GGUF variants here on HF

Example command -
./build/bin/llama-diffusion-cli -m dream7b-q8-0.gguf -p "Write python code to train MNIST in pytorch" --temp 0.2 --diffusion-algorithm 0 --diffusion-alg-temp 0.2 -ngl 99 --top_p 0.95 --diffusion-steps 128

Another cool/gimmicky thing is that you can watch the diffusion unfold, which you can do with this tool using --diffusion-visual. Here it is using the prompt given in the HF example.

demo

TODO:

  • Add profiling for sampling vs decode
  • Add multithreading to sampler and measure

@github-actions github-actions bot added the examples and python (python script changes) labels Jul 12, 2025
@am17an am17an requested review from ggerganov and CISC July 12, 2025 07:50
Member

@ggerganov ggerganov left a comment

Cool!

Which part of the sampling exactly is so slow?

Is it correct to say that each step runs pp2048 + sampling?

@am17an
Collaborator Author

am17an commented Jul 12, 2025

> Cool!
>
> Which part of the sampling exactly is so slow?
>
> Is it correct to say that each step runs pp2048 + sampling?

Yes that's correct.

Why is it slow (as I understand it): there is a logits buffer of size (n_ctx x n_vocab) at each step that we need to sample from. I think transferring this buffer from GPU to CPU at each step, plus the sampling itself, is killing performance; the PyTorch HF version, where the buffer remains on the GPU, is much faster.
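
As a back-of-the-envelope illustration (assumed numbers: the 152064-entry vocab mentioned later in this thread and the full 2048-token context from the OP; the later measurements use a 512-token generation length), the per-step buffer is on the order of a gigabyte:

// rough size of the per-step logits buffer at full context (assumed numbers)
const size_t n_ctx   = 2048;                            // context limit from the OP
const size_t n_vocab = 152064;                          // Dream/Qwen-style vocab size
const size_t bytes   = n_ctx * n_vocab * sizeof(float); // ~1.25 GB moved GPU->CPU every step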

@CISC CISC added the hot (Something that is hot) label Jul 12, 2025
@ggerganov
Member

Could you add the command that you use to the OP?

@ggerganov
Member

You can avoid the KV cache with this patch:

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 43bf13856..dcc80a052 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7694,7 +7694,7 @@ struct llm_build_dream : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_no_cache();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16192,6 +16192,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         // Models that need specific instantiation should be handled in the
         // switch statement
         case LLM_ARCH_BERT:
+        case LLM_ARCH_DREAM:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:

You would need to extend llama_context::encode() to support logits extraction:

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a2bb82e93..c89cfd236 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -736,6 +736,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int64_t n_embd = hparams.n_embd;
+    const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
@@ -795,10 +796,22 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
+    auto * t_logits = res->get_logits();
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
+    // extract logits
+    if (t_logits && n_outputs > 0) {
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(logits != nullptr);
+
+        if (n_outputs) {
+            ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_outputs*n_vocab*sizeof(float));
+        }
+    }
+
     // extract embeddings
-    if (t_embd) {
+    if (cparams.embeddings && t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
@@ -2881,7 +2894,7 @@ llama_token * llama_diffusion_generate(llama_context * ctx, const llama_token *
         }
 
         // Clear KV cache, TODO: can we configure to not use kv-cache?
-        llama_memory_clear(llama_get_memory(ctx), true);
+        //llama_memory_clear(llama_get_memory(ctx), true);
 
         llama_batch batch = llama_batch_init(max_length, 0, 1);
         batch.n_tokens    = max_length;

Btw, using -fa should speed up the decoding.

Multi-threading the sampling over the mask positions should improve the performance significantly compared to the current implementation.


This is definitely interesting. I think we should support it and try to keep the changes to libllama to a minimum for now. GPU sampling will be done separately - we can probably spec this soon and start working on it.
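
A minimal sketch of the multi-threading suggestion above, assuming the same hypothetical per-position sample_at() helper as earlier and that positions within a step can be sampled independently:

#include <algorithm>
#include <thread>
#include <vector>

// sample all currently masked positions in parallel; `masked` holds their indices,
// `out` receives the sampled token per position
static void sample_masked_parallel(const float * logits, const std::vector<int> & masked,
                                   std::vector<llama_token> & out, int n_threads) {
    std::vector<std::thread> workers;
    const size_t chunk = (masked.size() + n_threads - 1) / n_threads;
    for (int t = 0; t < n_threads; ++t) {
        workers.emplace_back([&, t]() {
            const size_t begin = t * chunk;
            const size_t end   = std::min(masked.size(), begin + chunk);
            for (size_t k = begin; k < end; ++k) {
                float conf; // confidence ignored here for brevity
                out[masked[k]] = sample_at(logits, masked[k], &conf);
            }
        });
    }
    for (auto & w : workers) {
        w.join();
    }
}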

@am17an
Collaborator Author

am17an commented Jul 12, 2025

I'm making the changes and I plan to keep only llm_arch_is_diffusion and llama_vocab_get_mask in the API and move everything else to tools/diffusion

@ggerganov
Member

> I'm making the changes and I plan to keep only llm_arch_is_diffusion and llama_vocab_get_mask in the API and move everything else to tools/diffusion

For now, it should be examples/diffusion. To be promoted to a tool, it needs to mature because tools are shipped with packages and thus require more effort to maintain.

@am17an
Collaborator Author

am17an commented Jul 13, 2025

Actually, it looks like 87% of the time in the GPU case is spent transferring the logits from the GPU to the CPU; in the CPU case, the decode step is the bottleneck. So I don't think multithreading the sampler is going to help here.

Timing Summary for dream-7b-q8

on GPU

Total decode time: 1521.95ms (avg: 11.89ms/step)
Total logits transfer time: 13247.66ms (avg: 103.50ms/step)
Total sampling time: 420.27ms (avg: 3.28ms/step)
Total generation time: 15189.87ms  
Decode/Logits Transfer/Sampling ratio: 0.10/0.87/0.03

on CPU

Total decode time: 132348.18ms (avg: 1033.97ms/step)
Total logits transfer time: 1.67ms (avg: 0.01ms/step)
Total sampling time: 419.97ms (avg: 3.28ms/step)
Total generation time: 132769.82ms
Decode/Logits Transfer/Sampling ratio: 1.00/0.00/0.00

@ggerganov
Member

Most likely you didn't measure the time correctly. With CUDA, llama_decode is asynchronous, so your "Total logits transfer time" actually includes the decode time. Add a llama_synchronize() before getting the logits to make sure you are measuring just the transfer.

@am17an
Collaborator Author

am17an commented Jul 13, 2025

> Most likely you didn't measure the time correctly. With CUDA, llama_decode is asynchronous, so your "Total logits transfer time" actually includes the decode time. Add a llama_synchronize() before getting the logits to make sure you are measuring just the transfer.

Okay, I re-did the calculation and also timed around llama_synchronize, so I think my original point stands in that the transfer is still the bottleneck, unless I'm still doing something wrong:

Total decode time: 1293.82ms (avg: 10.11ms/step)
Total logits time: 1.93ms (avg: 0.02ms/step)
Total sampling time: 419.34ms (avg: 3.28ms/step)
Total sync time: 13386.34ms (avg: 104.58ms/step)

@ggerganov
Member

ggerganov commented Jul 13, 2025

No, let me explain better:

// measure this for the total decode + GPU->CPU transfer time
{
  llama_decode()
  llama_synchronize()
}

You cannot measure the transfer alone. If you want to do that, you have to modify this copy in llama_context.cpp to not be asynchronous:

    // extract logits
    if (t_logits && n_outputs > 0) {
        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
        GGML_ASSERT(backend_res != nullptr);
        GGML_ASSERT(logits != nullptr);

        if (n_outputs) {
            ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_outputs*n_vocab*sizeof(float));
        }
    }

For example:

llama_synchronize()

// measure this section to get the GPU->CPU transfer time
{
    ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_outputs*n_vocab*sizeof(float));
    llama_synchronize()
}

@am17an
Collaborator Author

am17an commented Jul 13, 2025

Ok, I did this

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index fced6564..3cff75a3 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -806,7 +806,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
         GGML_ASSERT(logits != nullptr);

         if (n_outputs) {
+            auto t1 = ggml_time_us();
             ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_outputs*n_vocab*sizeof(float));
+            llama_synchronize(this);
+            auto t2 = ggml_time_us();
+            LLAMA_LOG_INFO("logits time: %0.2fms\n", (t2 - t1)/1000.0);
         }
     }

The average here is 105ms/step (total step time being 115ms)

@ggerganov
Member

You are missing a sync before the first ggml_time_us() call.

@am17an
Collaborator Author

am17an commented Jul 13, 2025

Ah okay, thank you for your patience! The logits transfer is now ~15ms, which makes sense. The logits buffer is about 311 MB (512 x 152064 x 4 B), so we're achieving about 20 GB/s, which is consistent with PCIe speeds.
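
For reference, the arithmetic behind that figure, using only the numbers quoted above:

// ~311 MB of logits over ~15 ms gives roughly 20 GB/s
const size_t bytes = 512ull * 152064 * sizeof(float);   // = 311427072 bytes, ~311 MB
const double gbps  = (bytes / 1e9) / 0.015;             // ~20.8 GB/s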

@am17an
Collaborator Author

am17an commented Jul 13, 2025

Turns out it was the sampling all along - I was copying a huge vector + doing allocation 🤦‍♂️
Before: Total generation time: 132769.82ms
After: Total generation time: 19338.11ms
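
The kind of fix described above boils down to reusing one pre-allocated candidates buffer instead of rebuilding an n_vocab-sized vector per sampled position. A sketch, not the exact change in this PR:

// allocate once, outside the per-step/per-position loops
std::vector<llama_token_data> candidates(n_vocab);

// refill in place each time; no per-call allocation or copy of the logits
static void fill_candidates(const float * pos_logits, int32_t n_vocab,
                            std::vector<llama_token_data> & candidates) {
    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
        candidates[token_id] = { token_id, pos_logits[token_id], 0.0f }; // id, logit, p
    }
}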

@ggerganov
Member

ggerganov commented Jul 13, 2025

Could you upload a Q4_0 variant (I can't fit the Q8_0 on my machine)? Nvm, I'll quantize the F16.

On your RTX 3090, is it faster using the F16 model?

@am17an
Collaborator Author

am17an commented Jul 13, 2025

> Could you upload a Q4_0 variant (I can't fit the Q8_0 on my machine)? Nvm, I'll quantize the F16.
>
> On your RTX 3090, is it faster using the F16 model?

It's slightly faster using the F16 version, but I can't say for sure - there's some variability in the results, and both of them look the same.

@CISC
Collaborator

CISC commented Jul 13, 2025

Very nice, just tested DiffuCoder-7B-cpGRPO, works (but requires --diffusion-steps 512 to output anything but No code provided)!

Edit: Ok, these settings (from README.md) work too:

--temp 0.4 --diffusion-algorithm 3 --diffusion-alg-temp 0.0 --top_p 0.95 --diffusion-steps 256

Comment on lines 177 to 170
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
    candidates[token_id].logit = pos_logits[token_id];
    candidates[token_id].p = 0.0f;
    candidates[token_id].id = token_id;
}
Collaborator Author

@am17an am17an Jul 13, 2025

@ggerganov this is the loop which takes the most time in sampling. If the sampler array were SoA (llama_token_data_array even has a comment saying so), we could possibly just pass the logits directly here, which would speed up the sampling by 60-70% on my machine. But maybe there is a better way.

Collaborator Author

Maybe not, because it seems like a necessary cost to keep each logit paired with its token_id, as this array gets sorted etc.
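
For context, this is the existing AoS layout from llama.h that the discussion refers to; sorting by logit/p only works because each entry carries its token id along:

typedef struct llama_token_data {
    llama_token id;    // token id
    float       logit; // log-odds of the token
    float       p;     // probability of the token
} llama_token_data;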

@ggerganov ggerganov self-requested a review July 13, 2025 15:49
Collaborator

@ngxson ngxson left a comment

Maybe it's worth adding a README somewhere to avoid maintenance burden in the future.

common/arg.cpp Outdated
string_format("maximum generation length (default: %d)", params.diffusion.max_length),
[](common_params & params, int value) { params.diffusion.max_length = value; })
.set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg({ "--diffusion-visual" },
Collaborator

Suggested change
add_opt(common_arg({ "--diffusion-visual" },
add_opt(common_arg({ "--diffusion-progressive" },

Using "visual" as wording may be confusing, as user may expect some short of image output

except FileNotFoundError:
self._set_vocab_gpt2()

def set_gguf_parameters(self):
Collaborator

maybe we should add a GGUF metadata key to say that the model is diffusion-based

self.gguf_writer.add_is_diffusion(True)

Then if a user tries to run it with llama-cli or llama-server, show an error:

You must run this model using llama-diffusion-cli
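
A hypothetical sketch of such a guard; llama_model_is_diffusion is an assumed name, not an existing API (the thread only mentions an internal llm_arch_is_diffusion helper):

// in llama-cli / llama-server startup, after the model is loaded
if (llama_model_is_diffusion(model)) { // assumed helper, not part of the current API
    fprintf(stderr, "error: this is a diffusion model, run it with llama-diffusion-cli instead\n");
    return 1;
}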

@am17an am17an requested review from ggerganov and CISC July 15, 2025 05:22
@am17an am17an force-pushed the add_dream_lm branch 2 times, most recently from d27f0f3 to da800c8 Compare July 15, 2025 05:25
@am17an am17an force-pushed the add_dream_lm branch 2 times, most recently from 56f5641 to 4b28884 Compare July 16, 2025 05:51
@am17an am17an merged commit ab14019 into ggml-org:master Jul 16, 2025
48 of 51 checks passed