
Conversation

@ggerganov (Member) commented Jun 24, 2025

target #14285

Overview

Improve multi-sequence decoding performance by avoiding the cross-sequence attention compute.

Description

One significant drawback of the unified KV cache is that it leads to a lot of unnecessary computation in the attention when the unified buffer is shared between many large independent sequences. The reason is that we have to view this buffer contiguously, and therefore we end up computing large portions of "cross-sequence attention" which we then simply discard. For example, with 32 independent sequences of similar length, each query only needs the tokens of its own sequence - roughly 1/32 of the unified KV buffer - so most of the attention compute is wasted.

With this change, we add an option to split the unified KV cache buffer into multiple buffers - one for each sequence. This decouples the sequences from each other and improves the performance and memory usage of the attention when more than one sequence is used. To achieve that, when the batch reaches the attention, we split it into multiple "streams":

llama.cpp/src/llama-graph.cpp

Lines 1035 to 1044 in c96c48c

// split the batch into streams if needed
const auto n_stream = k->ne[3];
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);

Each stream has its own KV cache buffer and thus no longer "sees" the rest of the other streams - it attends only to the tokens that belong to the same stream.
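For illustration only, here is a small standalone ggml sketch (not part of the PR) that applies the same reshape/permute to a dummy Q tensor. The dimensions and the assumption that Q enters the attention as [n_embd_head, n_head, n_tokens, 1] are mine, chosen to match the per-stream layout described below for ggml_flash_attn_ext():

// Standalone sketch (not from the PR): shows how the reshape/permute above turns a flat
// batch of Q rows into [n_embd_head, n_tokens/n_stream, n_head, n_stream], i.e. one slice
// per stream. All dimensions here are illustrative assumptions.
#include "ggml.h"
#include <cstdio>

int main() {
    const int64_t n_embd_head = 128; // head dimension            (assumed)
    const int64_t n_head      = 8;   // number of attention heads (assumed)
    const int64_t n_tokens    = 64;  // tokens in the ubatch      (assumed)
    const int64_t n_stream    = 4;   // one stream per sequence   (assumed)

    ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    // Q as it is assumed to enter the attention: [n_embd_head, n_head, n_tokens, 1]
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);

    // split the token dimension into (n_tokens/n_stream, n_stream) ...
    q = ggml_reshape_4d(ctx, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
    // ... and permute so that the stream dimension becomes the last one:
    // [n_embd_head, n_tokens/n_stream, n_head, n_stream]
    q = ggml_permute(ctx, q, 0, 2, 1, 3);

    printf("q: [%lld, %lld, %lld, %lld]\n",
        (long long) q->ne[0], (long long) q->ne[1], (long long) q->ne[2], (long long) q->ne[3]);

    ggml_free(ctx);
    return 0;
}

With these assumed shapes the program prints q: [128, 16, 8, 4] - each of the 4 streams attends only to its own 16 tokens.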

With this approach we now have 2 modes:

  • The vanilla "unified" approach which we always used until now - all sequences are assigned to a single stream
  • The new "split" approach - each sequence is assigned to a separate stream

The new "split" mode is enabled by default. However it requires the LLAMA_SET_ROWS=1 environment variable to be set. Otherwise, a warning will be printed and the context will fallback to "unified" mode. In the future, after there is enough ggml_set_rows() coverage in the backends (#14661) this will become the default mode.

To force the old "unified" mode, use --kv-unified CLI arg.

API Changes

  • Add bool llama_context_params::kv_unified. Default is false

llama.cpp/include/llama.h

Lines 336 to 340 in fb8150d

// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
};
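For reference, a minimal usage sketch (not from the PR) of the new parameter from the public API. The model path and the chosen values are placeholders, and - per the note above - the split mode currently also needs LLAMA_SET_ROWS=1 in the environment:

// Minimal sketch: creating a context that uses one KV cache stream per sequence.
// Model path and parameter values are illustrative placeholders.
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max  = 4;      // decode up to 4 independent sequences
    cparams.kv_unified = false;  // default: split mode, one stream per sequence
    // cparams.kv_unified = true; // force the old unified buffer (same as --kv-unified)

    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... decode with llama_batch entries spread across seq_id 0..3 ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}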

Testing

Use LLAMA_SET_ROWS=1 llama-[command] ...

Qwen 2.5 Coder 3B Q8_0, M2 Ultra

# master
make -j && ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.00.604.032 I llama_kv_cache_unified:      Metal KV buffer size =  4680.00 MiB
0.00.953.209 I llama_kv_cache_unified: size = 4680.00 MiB (133120 cells,  36 layers, 32 seqs), K (f16): 2340.00 MiB, V (f16): 2340.00 MiB
0.01.016.945 I llama_context:      Metal compute buffer size =  1624.05 MiB
0.01.016.947 I llama_context:        CPU compute buffer size =  1056.05 MiB
0.01.016.947 I llama_context: graph nodes  = 1195
0.01.016.947 I llama_context: graph splits = 2
main: n_kv_max = 133120, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.403 |   729.71 |    1.403 |   729.66 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.381 |   741.44 |    1.381 |   741.37 |
|   512 |     32 |   32 |  17408 |    5.320 |  3079.72 |    2.052 |   498.98 |    7.372 |  2361.33 |
|  1024 |     32 |   32 |  33792 |   11.632 |  2817.15 |    2.715 |   377.16 |   14.347 |  2355.40 |
|  2048 |     32 |   32 |  66560 |   27.419 |  2390.20 |    4.052 |   252.73 |   31.470 |  2115.00 |
|  4096 |     32 |   32 | 132096 |   71.549 |  1831.92 |    6.664 |   153.66 |   78.213 |  1688.93 |


# PR
make -j && LLAMA_SET_ROWS=1 ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.00.584.467 I llama_kv_cache_unified:      Metal KV buffer size =  4896.00 MiB
0.00.952.799 I llama_kv_cache_unified: size = 4896.00 MiB (  4352 cells,  36 layers, 32/32 seqs), K (f16): 2448.00 MiB, V (f16): 2448.00 MiB
0.01.002.436 I llama_context:      Metal compute buffer size =  1219.00 MiB
0.01.002.438 I llama_context:        CPU compute buffer size =    50.05 MiB
0.01.002.438 I llama_context: graph nodes  = 1231
0.01.002.438 I llama_context: graph splits = 2
main: n_kv_max = 139264, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.339 |   764.92 |    1.339 |   764.85 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.332 |   768.79 |    1.332 |   768.69 |
|   512 |     32 |   32 |  17408 |    4.903 |  3341.42 |    1.499 |   682.93 |    6.403 |  2718.84 |
|  1024 |     32 |   32 |  33792 |   10.057 |  3258.12 |    1.569 |   652.46 |   11.627 |  2906.40 |
|  2048 |     32 |   32 |  66560 |   21.213 |  3089.47 |    1.754 |   583.79 |   22.967 |  2898.10 |
|  4096 |     32 |   32 | 132096 |   46.713 |  2805.91 |    2.107 |   486.09 |   48.819 |  2705.81 |

Gemma 3 4B Q8_0, M2 Ultra

# master
make -j && ./bin/llama-batched-bench -m ../models/gemma-3-4b/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.01.609.907 I llama_kv_cache_unified_iswa: creating non-SWA KV cache, size = 133120 cells
0.01.703.014 I llama_kv_cache_unified:      Metal KV buffer size =  2600.00 MiB
0.01.902.274 I llama_kv_cache_unified: size = 2600.00 MiB (133120 cells,   5 layers, 32 seqs), K (f16): 1300.00 MiB, V (f16): 1300.00 MiB
0.01.902.278 I llama_kv_cache_unified_iswa: creating     SWA KV cache, size = 34816 cells
0.02.040.114 I llama_kv_cache_unified:      Metal KV buffer size =  3944.00 MiB
0.02.325.408 I llama_kv_cache_unified: size = 3944.00 MiB ( 34816 cells,  29 layers, 32 seqs), K (f16): 1972.00 MiB, V (f16): 1972.00 MiB
0.02.403.614 I llama_context:      Metal compute buffer size =  2068.00 MiB
0.02.403.616 I llama_context:        CPU compute buffer size =  1332.09 MiB
0.02.403.617 I llama_context: graph nodes  = 1335
0.02.403.617 I llama_context: graph splits = 2
main: n_kv_max = 133120, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.843 |   555.52 |    1.844 |   555.44 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.800 |   569.00 |    1.800 |   568.94 |
|   512 |     32 |   32 |  17408 |    6.341 |  2583.88 |    3.601 |   284.33 |    9.942 |  1750.90 |
|  1024 |     32 |   32 |  33792 |   13.832 |  2369.03 |    5.442 |   188.18 |   19.273 |  1753.29 |
|  2048 |     32 |   32 |  66560 |   31.034 |  2111.78 |    6.343 |   161.43 |   37.377 |  1780.77 |
|  4096 |     32 |   32 | 132096 |   69.326 |  1890.65 |    7.456 |   137.33 |   76.783 |  1720.39 |

# PR
make -j && LLAMA_SET_ROWS=1 ./bin/llama-batched-bench -m ../models/gemma-3-4b/ggml-model-q8_0.gguf -c 133120 -b 2048 -ub 2048 -npp 0,0,512,1024,2048,4096 -ntg 32 -npl 32 -fa

0.00.505.130 I llama_kv_cache_unified_iswa: creating non-SWA KV cache, size = 4352 cells
0.00.603.948 I llama_kv_cache_unified:      Metal KV buffer size =  2720.00 MiB
0.00.813.515 I llama_kv_cache_unified: size = 2720.00 MiB (  4352 cells,   5 layers, 32/32 seqs), K (f16): 1360.00 MiB, V (f16): 1360.00 MiB
0.00.813.520 I llama_kv_cache_unified_iswa: creating     SWA KV cache, size = 3072 cells
0.01.198.824 I llama_kv_cache_unified:      Metal KV buffer size = 11136.00 MiB
0.01.986.031 I llama_kv_cache_unified: size = 11136.00 MiB (  3072 cells,  29 layers, 32/32 seqs), K (f16): 5568.00 MiB, V (f16): 5568.00 MiB
0.02.059.335 I llama_context:      Metal compute buffer size =  2068.00 MiB
0.02.059.340 I llama_context:        CPU compute buffer size =    78.09 MiB
0.02.059.340 I llama_context: graph nodes  = 1369
0.02.059.340 I llama_context: graph splits = 2
main: n_kv_max = 139264, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.577 |   649.36 |    1.577 |   649.26 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.568 |   652.99 |    1.568 |   652.86 |
|   512 |     32 |   32 |  17408 |    5.884 |  2784.73 |    1.769 |   578.77 |    7.653 |  2274.73 |
|  1024 |     32 |   32 |  33792 |   12.261 |  2672.46 |    1.874 |   546.44 |   14.135 |  2390.61 |
|  2048 |     32 |   32 |  66560 |   25.831 |  2537.12 |    1.962 |   522.01 |   27.793 |  2394.89 |
|  4096 |     32 |   32 | 132096 |   54.077 |  2423.79 |    2.065 |   496.00 |   56.142 |  2352.90 |

Using a more real-world example with llama-parallel:

# master
make -j && ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 32 -ns 128 -s 1 -c 16384 -fa

# PR
make -j && LLAMA_SET_ROWS=1 ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 32 -ns 128 -s 1 -c 4096 -fa

TODO

  • FA path
  • Non-FA path
  • Metal FA
  • Metal non-FA
  • CPU FA
  • CPU non-FA
  • ggml_soft_max_ext() support for virtual sequences
  • llama_memory_seq_cp support for virtual sequences
  • iSWA
  • split_equal support sequential ids
  • CUDA
  • Vulkan
  • etc.
  • more consistent sequence/virtual sequence naming
  • better term than "virtual sequence"?
  • env LLAMA_HT become regular compute parameter
  • Fix n_ctx meaning (total vs per-sequence)
  • Check input batch for no coupled sequences when HT is on
  • Require n_embd_v_gqa(il) == const when FA is off (no longer needed)
  • Save/load state

Next PRs

  • Optimize parallel encoding via (split_equal + padding) and stream split [TAG_NO_CACHE_PAD]
  • Disable and remove the defrag code when ggml_set_rows() is fully adopted
  • Add option to llama-parallel to use different RNG seeds for the different clients

@github-actions github-actions bot added the examples, ggml, and Apple Metal labels on Jun 24, 2025
@JohannesGaessler (Collaborator)

Right now I am comparatively less busy with my PhD so it would be a good time for me to write CUDA code that is still missing, if there is any.

@ggerganov (Member, Author) commented Jun 24, 2025

For now, these are the necessary CUDA changes:

  • Add ggml_set_rows() support (need PR towards ggml : add ggml_set_rows #14274, can already start implementing this)
  • Extend ggml_flash_attn_ext() to support n_seq dim if it does not yet:
// old
    // q:    [n_embd_k, n_batch,     n_head,    1]
    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

// new - supports `n_seq` dimension:
    // q:    [n_embd_k, n_batch,     n_head,    n_seq]
    // k:    [n_embd_k, n_kv,        n_head_kv, n_seq]
    // v:    [n_embd_v, n_kv,        n_head_kv, n_seq] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, n_seq,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   n_seq] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

CPU might also need to be extended (not sure yet)

  • Extend ggml_soft_max_ext to support n_seq dim if it does not yet in a similar way. Also not sure about the CPU state.

Edit: the CPU versions of ggml_soft_max_ext() and ggml_flash_attn_ext() are now correct and can be used as a reference.

@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from ab2a2bb to 1b74b9d Compare June 24, 2025 17:24
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch 3 times, most recently from c246784 to 06bb08a Compare June 27, 2025 14:35
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch 3 times, most recently from 82277da to 4534123 Compare June 30, 2025 14:08
@ggerganov ggerganov mentioned this pull request Jul 1, 2025
5 tasks
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch from 2f577c5 to 30b4d4e Compare July 2, 2025 12:49
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from 6179578 to dfceb01 Compare July 2, 2025 18:20
Base automatically changed from gg/kv-cache-use-set-rows to master July 3, 2025 07:53
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch 2 times, most recently from eb5856c to ee0f729 Compare July 3, 2025 08:12
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from ee0f729 to deae7cd Compare July 3, 2025 08:53
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch 2 times, most recently from 988d0cd to dbcfcaa Compare July 3, 2025 12:11
v_cells[s].resize(kv_size);
}

// by default, all sequence ids are mapped to the 0th virtual sequence
@compilade (Collaborator) commented Jul 3, 2025

I'd like to understand the purpose of virtual sequences.

  • Is it to make the unified cache not unified?
    • Should it be a separate cache type instead?
  • Why is n_seq_virt a number and not a bool of whether or not the cache is unified?
    • Is it to eventually allow n_seq_max % n_seq_virt == 0 for a partially-unified cache?
  • Are virtual sequences intended to be used with other types of caches eventually (e.g. recurrent)?
    • The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding).

@ggerganov (Member, Author) commented Jul 3, 2025

Today I found a better term instead of "virtual sequences": "streams". So I'll use "streams" here and will update the code later today or tomorrow.

> Is it to make the unified cache not unified?

Roughly yes. The user will be able to select between unified (i.e. single stream) or non-unified (multiple streams). Each mode has advantages in different scenarios. Single stream is good when the sequences share large common prefixes. Multiple streams are good when the sequences are mostly or completely independent from each other.

The first iteration will support 1 stream (i.e. same as master, vanilla unified KV cache) and n_seq_max streams. The latter means that each sequence id is assigned to a separate stream.

In theory, we could assign multiple sequence ids to the same stream to get a partially-unified KV cache, but this would need extra work and it might not have any useful applications. So out of scope for now.

> Should it be a separate cache type instead?

There is too much similar logic. Still thinking about it, but most likely it will end up in the same cache type.

> The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding)

Yes.

Comment on lines 73 to 75
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
Collaborator

Why are sequential seq_ids required when virtual sequences are used?

Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?

I wonder if there could be a way to avoid this requirement with ggml_get_rows and/or ggml_mul_mat_id. Might not be worth the extra indirection, though.

Member Author

> Why are sequential seq_ids required when virtual sequences are used?
>
> Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?

Yes, we make a view of the KV cache across the streams here:

ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

    const uint64_t kv_size = get_size();

    return ggml_view_4d(ctx, k,
            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
            ggml_row_size(k->type, hparams.n_embd_head_k),
            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*sinfo.s0);
}

The ns variable is the number of streams that participate in the current ubatch; their stream indices are in the range [s0, s1].

> I wonder if there could be a way to avoid this requirement with ggml_get_rows and/or ggml_mul_mat_id. Might not be worth the extra indirection, though.

It should be possible, but I'm not sure if it would be worth it - both in performance and in complexity. We can explore it, though.
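To make the stride math in get_k() above concrete, here is a small standalone sketch (not from the PR, with made-up dimensions) that carves the same kind of [s0, s1] stream slice out of a dummy K buffer laid out as [n_embd_k_gqa, kv_size, n_stream]:

// Standalone sketch: selecting streams [s0, s1] of a dummy per-stream K buffer.
// All dimensions are illustrative assumptions.
#include "ggml.h"
#include <cstdio>

int main() {
    const int64_t n_embd_head_k = 128;
    const int64_t n_head_kv     = 4;
    const int64_t n_embd_k_gqa  = n_embd_head_k*n_head_kv;
    const int64_t kv_size       = 1024;  // cells per stream
    const int64_t n_stream      = 8;

    ggml_init_params params = { /*.mem_size =*/ 32*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    // full K buffer: one [n_embd_k_gqa, kv_size] plane per stream
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_k_gqa, kv_size, n_stream);

    const int64_t s0   = 2, s1 = 5;  // streams participating in the ubatch
    const int64_t ns   = s1 - s0 + 1;
    const int64_t n_kv = 256;        // cells actually used in this ubatch

    // same stride pattern as get_k(): [n_embd_head_k, n_head_kv, n_kv, ns], offset to stream s0
    ggml_tensor * k_view = ggml_view_4d(ctx, k,
            n_embd_head_k, n_head_kv, n_kv, ns,
            ggml_row_size(k->type, n_embd_head_k),
            ggml_row_size(k->type, n_embd_k_gqa),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*s0);

    printf("k_view: [%lld, %lld, %lld, %lld]\n",
        (long long) k_view->ne[0], (long long) k_view->ne[1], (long long) k_view->ne[2], (long long) k_view->ne[3]);

    ggml_free(ctx);
    return 0;
}

Because the slice is a plain byte-offset view with a fixed per-stream stride, the participating streams have to form one contiguous index range - which is exactly why the batch splitter asks for sequential sequence ids.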

@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*(1 + n_seq_virt)*n_layer_cache*ggml_tensor_overhead()),
Collaborator

Is the 1 + intended? Why was it added?

Member Author

For the per-stream views of the KV cache:

std::vector<ggml_tensor *> k_seq;
std::vector<ggml_tensor *> v_seq;

for (uint32_t s = 0; s < n_seq_virt; ++s) {
    k_seq.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
    v_seq.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
}

These are used to implement llama_memory_seq_cp(). This operation is no longer just assigning ids - it performs an actual copy of the buffers in memory when we use multiple streams. Using these helper views, the operation is quite simple to implement:

bool is_full = true;

if (p0 > 0 && p0 + 1 < (int) get_size()) {
    is_full = false;
}

if (p1 > 0 && p1 + 1 < (int) get_size()) {
    is_full = false;
}

GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");

//LLAMA_LOG_WARN("%s: copying KV buffer from %d (virt = %d) to %d (virt = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);

for (uint32_t il = 0; il < layers.size(); ++il) {
    const auto & layer = layers[il];

    ggml_backend_tensor_copy(layer.k_seq[s0], layer.k_seq[s1]);
    ggml_backend_tensor_copy(layer.v_seq[s0], layer.v_seq[s1]);

    // TODO: do we need synchronization here?
}

// TODO: support this:
GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift");

v_cells[s1].reset();

for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
    if (v_cells[s0].seq_has(i, seq_id_src)) {
        v_cells[s1].pos_set(i, v_cells[s0].pos_get(i));
        v_cells[s1].seq_add(i, seq_id_dst);
    }
}

v_heads[s1] = v_heads[s0];

//for (uint32_t s = 0; s < n_seq_virt; ++s) {
//    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
//}
}

Though we cannot copy partial sequences when using multiple streams.
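For context, a minimal sketch (not from the PR) of what this restriction means on the caller side, assuming the llama_memory_* API from llama.h: with split streams only full-range copies are supported, so the position range is passed as -1/-1 (whole sequence).

// Minimal sketch: copying one sequence's KV data into another sequence's stream.
// With split streams, partial-range copies are not supported, hence p0 = p1 = -1.
#include "llama.h"

void fork_sequence(llama_context * ctx) {
    llama_memory_t mem = llama_get_memory(ctx);

    // copy everything from seq 0 into seq 1 (full-buffer copy between streams)
    llama_memory_seq_cp(mem, /*seq_id_src =*/ 0, /*seq_id_dst =*/ 1, /*p0 =*/ -1, /*p1 =*/ -1);
}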

Comment on lines 498 to 508
// accept only increasing sequence ids
if (sequential) {
    add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
Collaborator

What about decreasing sequence ids? Is the requirement that they are increasing, or that the included seq_ids should be in a contiguous range?

(decreasing sequence ids might not really happen often in practice though)

Member Author

Decreasing would also work - we just need a contiguous range. We could either add this, if there is an elegant way to search for it, or add a batch pre-processing step to move the complexity to a higher level, or simply delegate it to the user by warning when the batch is not arranged optimally.

@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from dbcfcaa to 33dcc3c Compare July 4, 2025 07:04
@ggerganov (Member, Author)

There was indeed a bug - will be fixed with #14733.

@CISC (Collaborator) commented Jul 17, 2025

@ggerganov It looks like the mask is not correctly padded with parallel processing:

LLAMA_SET_ROWS=1 ./llama-cli -m LFM2-1.2B-bf16.gguf -t 8 [...] --parallel 2
[...]
llama_kv_cache_unified: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to 512
llama_kv_cache_unified:        CPU KV buffer size =    48,00 MiB
llama_kv_cache_unified: size =   48,00 MiB (  4096 cells,   6 layers,  2/ 1 seqs), K (f16):   24,00 MiB, V (f16):   24,00 MiB
llama_memory_recurrent:        CPU RS buffer size =     0,31 MiB
llama_memory_recurrent: size =    0,31 MiB (     2 cells,  16 layers,  2 seqs), R (f32):    0,31 MiB, S (f32):    0,00 MiB
llama.cpp/ggml/src/ggml.c:3740: GGML_ASSERT(mask->ne[1] >= a->ne[1]) failed

Edit: With -fa it fails in ggml_flash_attn_ext instead of ggml_soft_max_ext:

llama.cpp/ggml/src/ggml.c:4768: GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big") failed

@ggerganov (Member, Author)

@CISC The llama_memory_hybrid constructor has to respect the cparams.kv_unified value. Currently it is hardcoded to 1 (i.e. true). This patch fixes the issue, but a proper fix should be implemented to handle both unified and split modes:

diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index d8e2086c8..ab6470bdf 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -31,21 +31,21 @@ llama_memory_hybrid::llama_memory_hybrid(
     hparams(model.hparams),
     mem_attn(new llama_kv_cache_unified(
         model,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,
         type_k,
         type_v,
         v_trans,
         offload,
-        1,
+        false,
         kv_size,
         n_seq_max,
         n_pad,
         n_swa,
         swa_type
     )),
     mem_recr(new llama_memory_recurrent(
         model,
         filter_recr == nullptr ?
             [&](int32_t il) { return hparams.is_recurrent(il); }

@rujialiu commented Jul 18, 2025

#14712.

I tested again after #14712 was merged. I can confirm that my use case works without concurrent access. The VRAM usage is reduced from 11 GB to 10 GB at startup, and I've used it for a while with Cline without any problem. Then I used Claude Code to test concurrent access and, after some time, got an assertion failure. Here is some log from the end:

srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  3 | task 1855 | processing task
slot update_slots: id  3 | task 1855 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 13802
slot update_slots: id  3 | task 1855 | kv cache rm [4, end)
slot update_slots: id  3 | task 1855 | prompt processing progress, n_past = 2052, n_tokens = 2048, progress = 0.148384
slot update_slots: id  3 | task 1855 | kv cache rm [2052, end)
slot update_slots: id  3 | task 1855 | prompt processing progress, n_past = 4100, n_tokens = 2048, progress = 0.296769
slot update_slots: id  3 | task 1855 | kv cache rm [4100, end)
slot update_slots: id  3 | task 1855 | prompt processing progress, n_past = 6148, n_tokens = 2048, progress = 0.445153
slot update_slots: id  3 | task 1855 | kv cache rm [6148, end)
slot update_slots: id  3 | task 1855 | prompt processing progress, n_past = 8196, n_tokens = 2048, progress = 0.593537
slot update_slots: id  3 | task 1855 | kv cache rm [8196, end)
slot update_slots: id  3 | task 1855 | prompt processing progress, n_past = 10244, n_tokens = 2048, progress = 0.741921
slot update_slots: id  3 | task 1855 | kv cache rm [10244, end)
slot update_slots: id  3 | task 1855 | prompt processing progress, n_past = 12292, n_tokens = 2048, progress = 0.890306
slot update_slots: id  3 | task 1855 | kv cache rm [12292, end)
slot update_slots: id  3 | task 1855 | prompt processing progress, n_past = 13802, n_tokens = 1510, progress = 0.999710
slot update_slots: id  3 | task 1855 | prompt done, n_past = 13802, n_tokens = 1510
slot      release: id  3 | task 1855 | stop processing: n_past = 13841, truncated = 0
slot print_timing: id  3 | task 1855 |
prompt eval time =   14600.69 ms / 13798 tokens (    1.06 ms per token,   945.02 tokens per second)
       eval time =    3494.51 ms /    40 tokens (   87.36 ms per token,    11.45 tokens per second)
      total time =   18095.19 ms / 13838 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  2 | task 1902 | processing task
slot update_slots: id  2 | task 1902 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 297
slot update_slots: id  2 | task 1902 | kv cache rm [2, end)
slot update_slots: id  2 | task 1902 | prompt processing progress, n_past = 297, n_tokens = 295, progress = 0.993266
slot update_slots: id  2 | task 1902 | prompt done, n_past = 297, n_tokens = 295
srv  params_from_: Chat format: Generic
slot launch_slot_: id  3 | task 1904 | processing task
slot update_slots: id  3 | task 1904 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 13960
slot update_slots: id  3 | task 1904 | kv cache rm [13802, end)
slot update_slots: id  3 | task 1904 | prompt processing progress, n_past = 13960, n_tokens = 159, progress = 0.011318
slot update_slots: id  3 | task 1904 | prompt done, n_past = 13960, n_tokens = 159
slot      release: id  2 | task 1902 | stop processing: n_past = 319, truncated = 0
slot print_timing: id  2 | task 1902 |
prompt eval time =     381.36 ms /   295 tokens (    1.29 ms per token,   773.55 tokens per second)
       eval time =   12568.56 ms /    23 tokens (  546.46 ms per token,     1.83 tokens per second)
      total time =   12949.92 ms /   318 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
slot      release: id  3 | task 1904 | stop processing: n_past = 14016, truncated = 0
slot print_timing: id  3 | task 1904 |
prompt eval time =     910.12 ms /   158 tokens (    5.76 ms per token,   173.60 tokens per second)
       eval time =   14617.52 ms /    57 tokens (  256.45 ms per token,     3.90 tokens per second)
      total time =   15527.64 ms /   215 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  2 | task 1962 | processing task
slot update_slots: id  2 | task 1962 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 461
slot update_slots: id  2 | task 1962 | kv cache rm [259, end)
slot update_slots: id  2 | task 1962 | prompt processing progress, n_past = 461, n_tokens = 202, progress = 0.438178
slot update_slots: id  2 | task 1962 | prompt done, n_past = 461, n_tokens = 202
srv  params_from_: Chat format: Generic
slot launch_slot_: id  3 | task 1964 | processing task
slot update_slots: id  3 | task 1964 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 14301
slot update_slots: id  3 | task 1964 | kv cache rm [13965, end)
slot update_slots: id  3 | task 1964 | prompt processing progress, n_past = 14301, n_tokens = 337, progress = 0.023495
slot update_slots: id  3 | task 1964 | prompt done, n_past = 14301, n_tokens = 337
slot      release: id  2 | task 1962 | stop processing: n_past = 483, truncated = 0
slot print_timing: id  2 | task 1962 |
prompt eval time =     279.07 ms /   202 tokens (    1.38 ms per token,   723.84 tokens per second)
       eval time =   12650.91 ms /    23 tokens (  550.04 ms per token,     1.82 tokens per second)
      total time =   12929.98 ms /   225 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
slot      release: id  3 | task 1964 | stop processing: n_past = 14345, truncated = 0
slot print_timing: id  3 | task 1964 |
prompt eval time =     933.07 ms /   336 tokens (    2.78 ms per token,   360.10 tokens per second)
       eval time =   13704.92 ms /    45 tokens (  304.55 ms per token,     3.28 tokens per second)
      total time =   14637.99 ms /   381 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  1 | task 2010 | processing task
slot update_slots: id  1 | task 2010 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 656
slot update_slots: id  1 | task 2010 | kv cache rm [2, end)
slot update_slots: id  1 | task 2010 | prompt processing progress, n_past = 656, n_tokens = 654, progress = 0.996951
slot update_slots: id  1 | task 2010 | prompt done, n_past = 656, n_tokens = 654
slot      release: id  1 | task 2010 | stop processing: n_past = 657, truncated = 0
slot print_timing: id  1 | task 2010 |
prompt eval time =     759.08 ms /   654 tokens (    1.16 ms per token,   861.57 tokens per second)
       eval time =      80.50 ms /     2 tokens (   40.25 ms per token,    24.85 tokens per second)
      total time =     839.57 ms /   656 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  2 | task 2013 | processing task
slot update_slots: id  2 | task 2013 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 3122
slot update_slots: id  2 | task 2013 | kv cache rm [258, end)
slot update_slots: id  2 | task 2013 | prompt processing progress, n_past = 2306, n_tokens = 2048, progress = 0.655990
srv  params_from_: Chat format: Generic
slot launch_slot_: id  3 | task 2015 | processing task
slot update_slots: id  2 | task 2013 | kv cache rm [2306, end)
slot update_slots: id  2 | task 2013 | prompt processing progress, n_past = 3122, n_tokens = 816, progress = 0.917361
slot update_slots: id  2 | task 2013 | prompt done, n_past = 3122, n_tokens = 816
slot update_slots: id  3 | task 2015 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 17474
slot update_slots: id  3 | task 2015 | kv cache rm [14306, end)
slot update_slots: id  3 | task 2015 | prompt processing progress, n_past = 15538, n_tokens = 2048, progress = 0.070505
D:\llamacpp-cuda\ggml\src\ggml-cuda\fattn-common.cuh:748: GGML_ASSERT(ggml_is_contiguously_allocated(K)) failed

The command line is almost the same, except for the server port number.

@ggerganov (Member, Author)

@rujialiu Thanks for reporting these - very useful information. The CUDA backend needs to be updated to handle quantized KV cache with the new split KV mode (#14756).

Btw, are you interested in writing a tutorial for using Claude Code with llama-server? Even if it is something very simple, I think it would be useful information, and I'll be interested in trying it.

@rujialiu

> @rujialiu Thanks for reporting these - very useful information. The CUDA backend needs to be updated to handle quantized KV cache with the new split KV mode (#14756).
>
> Btw, are you interested in writing a tutorial for using Claude Code with llama-server? Even if it is something very simple, I think it would be useful information, and I'll be interested in trying it.

Yes, I'd love to share my experience! Recently I tried a lot of approaches to agentic coding with llama-server, including Cline, Roo Code, and OpenHands. Maybe a more general topic like "agentic coding with llama-server"? Probably more people would be interested.

@ggerganov (Member, Author)

@rujialiu #14820 should fix the issue with quantized split KV cache. Let me know if you give it a try.

@rujialiu

> @rujialiu #14820 should fix the issue with quantized split KV cache. Let me know if you give it a try.

Since the code is not merged, I manually tried that branch, with the same command line (q4_0 kv cache):

  • The overall test process lasted ~1 hour (with Claude Code 1.0.56)
  • At startup, VRAM usage is slightly higher (10.4GB) than master (9.8GB) but still much lower than LLAMA_SET_ROWS=0 (10.9GB). However, after a few minutes, it rises to 10.9GB (but stops growing further)
  • Tokens/s seems much more stable compared to master (i.e. this PR), which theoretically improves upon LLAMA_SET_ROWS=0, but the difference is not obvious in my use case.

I also tested the q8 KV cache for a while (only a few minutes). No crash or assertion failure.

BTW: I'm planning to try Qwen3-Coder-480B-A35B-Instruct soon, try to tune what to offload to CPU and see how your new changes work. @ggerganov

@rujialiu commented Jul 23, 2025

OK, I see #14820 is replaced by #14822.
Do you want me to test that? @ggerganov

@ggerganov (Member, Author)

Yes, but make sure to apply the patch that I posted there to remove the extra alloc. Otherwise it will assert.

@rujialiu commented Jul 23, 2025

> Yes, but make sure to apply the patch that I posted there to remove the extra alloc. Otherwise it will assert.

Unfortunately, it exited without any message after a few minutes (the session is ~30 minutes, but I think there is a gap of >15 minutes without access). At startup, VRAM usage was 9.8GB, and when it exited, it was 9.95GB:

srv  params_from_: Chat format: Content-only
slot launch_slot_: id  3 | task 0 | processing task
slot update_slots: id  3 | task 0 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 111
slot update_slots: id  3 | task 0 | kv cache rm [0, end)
slot update_slots: id  3 | task 0 | prompt processing progress, n_past = 111, n_tokens = 111, progress = 1.000000
slot update_slots: id  3 | task 0 | prompt done, n_past = 111, n_tokens = 111
srv  params_from_: Chat format: Generic
slot launch_slot_: id  2 | task 2 | processing task
slot update_slots: id  2 | task 2 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 13845
slot update_slots: id  2 | task 2 | kv cache rm [0, end)
slot update_slots: id  2 | task 2 | prompt processing progress, n_past = 2047, n_tokens = 2048, progress = 0.147851
slot update_slots: id  2 | task 2 | kv cache rm [2047, end)
slot update_slots: id  2 | task 2 | prompt processing progress, n_past = 4094, n_tokens = 2048, progress = 0.295702
slot update_slots: id  2 | task 2 | kv cache rm [4094, end)
slot update_slots: id  2 | task 2 | prompt processing progress, n_past = 6141, n_tokens = 2048, progress = 0.443554
slot update_slots: id  2 | task 2 | kv cache rm [6141, end)
slot update_slots: id  2 | task 2 | prompt processing progress, n_past = 8188, n_tokens = 2048, progress = 0.591405
slot update_slots: id  2 | task 2 | kv cache rm [8188, end)
slot update_slots: id  2 | task 2 | prompt processing progress, n_past = 10235, n_tokens = 2048, progress = 0.739256
slot update_slots: id  2 | task 2 | kv cache rm [10235, end)
slot update_slots: id  2 | task 2 | prompt processing progress, n_past = 12282, n_tokens = 2048, progress = 0.887107
slot update_slots: id  2 | task 2 | kv cache rm [12282, end)
slot update_slots: id  2 | task 2 | prompt processing progress, n_past = 13845, n_tokens = 1564, progress = 1.000000
slot update_slots: id  2 | task 2 | prompt done, n_past = 13845, n_tokens = 1564
slot      release: id  3 | task 0 | stop processing: n_past = 125, truncated = 0
slot print_timing: id  3 | task 0 |
prompt eval time =     421.31 ms /   111 tokens (    3.80 ms per token,   263.47 tokens per second)
       eval time =   17158.12 ms /    15 tokens ( 1143.87 ms per token,     0.87 tokens per second)
      total time =   17579.42 ms /   126 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
slot      release: id  2 | task 2 | stop processing: n_past = 13916, truncated = 0
slot print_timing: id  2 | task 2 |
prompt eval time =   13265.47 ms / 13845 tokens (    0.96 ms per token,  1043.69 tokens per second)
       eval time =    9388.10 ms /    72 tokens (  130.39 ms per token,     7.67 tokens per second)
      total time =   22653.57 ms / 13917 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 81 | processing task
slot update_slots: id  1 | task 81 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 10546
slot update_slots: id  1 | task 81 | kv cache rm [0, end)
slot update_slots: id  1 | task 81 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.194197
slot update_slots: id  1 | task 81 | kv cache rm [2048, end)
slot update_slots: id  1 | task 81 | prompt processing progress, n_past = 4096, n_tokens = 2048, progress = 0.388394
slot update_slots: id  1 | task 81 | kv cache rm [4096, end)
slot update_slots: id  1 | task 81 | prompt processing progress, n_past = 6144, n_tokens = 2048, progress = 0.582591
slot update_slots: id  1 | task 81 | kv cache rm [6144, end)
slot update_slots: id  1 | task 81 | prompt processing progress, n_past = 8192, n_tokens = 2048, progress = 0.776787
slot update_slots: id  1 | task 81 | kv cache rm [8192, end)
slot update_slots: id  1 | task 81 | prompt processing progress, n_past = 10240, n_tokens = 2048, progress = 0.970984
slot update_slots: id  1 | task 81 | kv cache rm [10240, end)
slot update_slots: id  1 | task 81 | prompt processing progress, n_past = 10546, n_tokens = 306, progress = 1.000000
slot update_slots: id  1 | task 81 | prompt done, n_past = 10546, n_tokens = 306
slot      release: id  1 | task 81 | stop processing: n_past = 10572, truncated = 0
slot print_timing: id  1 | task 81 |
prompt eval time =   10896.56 ms / 10546 tokens (    1.03 ms per token,   967.83 tokens per second)
       eval time =    2159.07 ms /    27 tokens (   79.97 ms per token,    12.51 tokens per second)
      total time =   13055.63 ms / 10573 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 114 | processing task
slot update_slots: id  1 | task 114 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 10783
slot update_slots: id  1 | task 114 | kv cache rm [10546, end)
slot update_slots: id  1 | task 114 | prompt processing progress, n_past = 10783, n_tokens = 237, progress = 0.021979
slot update_slots: id  1 | task 114 | prompt done, n_past = 10783, n_tokens = 237
slot      release: id  1 | task 114 | stop processing: n_past = 10826, truncated = 0
slot print_timing: id  1 | task 114 |
prompt eval time =     323.34 ms /   237 tokens (    1.36 ms per token,   732.98 tokens per second)
       eval time =    3567.18 ms /    44 tokens (   81.07 ms per token,    12.33 tokens per second)
      total time =    3890.52 ms /   281 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 159 | processing task
slot update_slots: id  1 | task 159 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 11022
slot update_slots: id  1 | task 159 | kv cache rm [10788, end)
slot update_slots: id  1 | task 159 | prompt processing progress, n_past = 11022, n_tokens = 234, progress = 0.021230
slot update_slots: id  1 | task 159 | prompt done, n_past = 11022, n_tokens = 234
slot      release: id  1 | task 159 | stop processing: n_past = 11070, truncated = 0
slot print_timing: id  1 | task 159 |
prompt eval time =     310.21 ms /   234 tokens (    1.33 ms per token,   754.32 tokens per second)
       eval time =    3981.41 ms /    49 tokens (   81.25 ms per token,    12.31 tokens per second)
      total time =    4291.62 ms /   283 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 209 | processing task
slot update_slots: id  1 | task 209 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 11354
slot update_slots: id  1 | task 209 | kv cache rm [11027, end)
slot update_slots: id  1 | task 209 | prompt processing progress, n_past = 11354, n_tokens = 327, progress = 0.028800
slot update_slots: id  1 | task 209 | prompt done, n_past = 11354, n_tokens = 327
slot      release: id  1 | task 209 | stop processing: n_past = 11402, truncated = 0
slot print_timing: id  1 | task 209 |
prompt eval time =     385.21 ms /   327 tokens (    1.18 ms per token,   848.89 tokens per second)
       eval time =    4009.58 ms /    49 tokens (   81.83 ms per token,    12.22 tokens per second)
      total time =    4394.79 ms /   376 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 259 | processing task
slot update_slots: id  1 | task 259 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 11845
slot update_slots: id  1 | task 259 | kv cache rm [11359, end)
slot update_slots: id  1 | task 259 | prompt processing progress, n_past = 11845, n_tokens = 486, progress = 0.041030
slot update_slots: id  1 | task 259 | prompt done, n_past = 11845, n_tokens = 486
slot      release: id  1 | task 259 | stop processing: n_past = 11897, truncated = 0
slot print_timing: id  1 | task 259 |
prompt eval time =     536.58 ms /   486 tokens (    1.10 ms per token,   905.74 tokens per second)
       eval time =    4352.62 ms /    53 tokens (   82.12 ms per token,    12.18 tokens per second)
      total time =    4889.19 ms /   539 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 313 | processing task
slot update_slots: id  1 | task 313 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 12306
slot update_slots: id  1 | task 313 | kv cache rm [11850, end)
slot update_slots: id  1 | task 313 | prompt processing progress, n_past = 12306, n_tokens = 456, progress = 0.037055
slot update_slots: id  1 | task 313 | prompt done, n_past = 12306, n_tokens = 456
slot      release: id  1 | task 313 | stop processing: n_past = 12357, truncated = 0
slot print_timing: id  1 | task 313 |
prompt eval time =     526.02 ms /   456 tokens (    1.15 ms per token,   866.88 tokens per second)
       eval time =    4241.52 ms /    52 tokens (   81.57 ms per token,    12.26 tokens per second)
      total time =    4767.54 ms /   508 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 366 | processing task
slot update_slots: id  1 | task 366 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 13884
slot update_slots: id  1 | task 366 | kv cache rm [12311, end)
slot update_slots: id  1 | task 366 | prompt processing progress, n_past = 13884, n_tokens = 1573, progress = 0.113296
slot update_slots: id  1 | task 366 | prompt done, n_past = 13884, n_tokens = 1573
slot      release: id  1 | task 366 | stop processing: n_past = 13935, truncated = 0
slot print_timing: id  1 | task 366 |
prompt eval time =    1731.19 ms /  1573 tokens (    1.10 ms per token,   908.62 tokens per second)
       eval time =    4324.18 ms /    52 tokens (   83.16 ms per token,    12.03 tokens per second)
      total time =    6055.37 ms /  1625 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 419 | processing task
slot update_slots: id  1 | task 419 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 16622
slot update_slots: id  1 | task 419 | kv cache rm [13889, end)
slot update_slots: id  1 | task 419 | prompt processing progress, n_past = 15937, n_tokens = 2048, progress = 0.123210
slot update_slots: id  1 | task 419 | kv cache rm [15937, end)
slot update_slots: id  1 | task 419 | prompt processing progress, n_past = 16622, n_tokens = 685, progress = 0.164421
slot update_slots: id  1 | task 419 | prompt done, n_past = 16622, n_tokens = 685
slot      release: id  1 | task 419 | stop processing: n_past = 16672, truncated = 0
slot print_timing: id  1 | task 419 |
prompt eval time =    3083.84 ms /  2733 tokens (    1.13 ms per token,   886.23 tokens per second)
       eval time =    4766.03 ms /    51 tokens (   93.45 ms per token,    10.70 tokens per second)
      total time =    7849.87 ms /  2784 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 472 | processing task
slot update_slots: id  1 | task 472 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 16941
slot update_slots: id  1 | task 472 | kv cache rm [16627, end)
slot update_slots: id  1 | task 472 | prompt processing progress, n_past = 16941, n_tokens = 314, progress = 0.018535
slot update_slots: id  1 | task 472 | prompt done, n_past = 16941, n_tokens = 314
slot      release: id  1 | task 472 | stop processing: n_past = 16991, truncated = 0
slot print_timing: id  1 | task 472 |
prompt eval time =     411.57 ms /   314 tokens (    1.31 ms per token,   762.93 tokens per second)
       eval time =    4823.22 ms /    51 tokens (   94.57 ms per token,    10.57 tokens per second)
      total time =    5234.79 ms /   365 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 524 | processing task
slot update_slots: id  1 | task 524 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 17512
slot update_slots: id  1 | task 524 | kv cache rm [16946, end)
slot update_slots: id  1 | task 524 | prompt processing progress, n_past = 17512, n_tokens = 566, progress = 0.032321
slot update_slots: id  1 | task 524 | prompt done, n_past = 17512, n_tokens = 566
slot      release: id  1 | task 524 | stop processing: n_past = 17567, truncated = 0
slot print_timing: id  1 | task 524 |
prompt eval time =     747.64 ms /   566 tokens (    1.32 ms per token,   757.05 tokens per second)
       eval time =    5473.98 ms /    56 tokens (   97.75 ms per token,    10.23 tokens per second)
      total time =    6221.62 ms /   622 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 581 | processing task
slot update_slots: id  1 | task 581 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 17657
slot update_slots: id  1 | task 581 | kv cache rm [17517, end)
slot update_slots: id  1 | task 581 | prompt processing progress, n_past = 17657, n_tokens = 140, progress = 0.007929
slot update_slots: id  1 | task 581 | prompt done, n_past = 17657, n_tokens = 140
slot      release: id  1 | task 581 | stop processing: n_past = 17698, truncated = 0
slot print_timing: id  1 | task 581 |
prompt eval time =     269.36 ms /   140 tokens (    1.92 ms per token,   519.74 tokens per second)
       eval time =    4112.60 ms /    42 tokens (   97.92 ms per token,    10.21 tokens per second)
      total time =    4381.97 ms /   182 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 624 | processing task
slot update_slots: id  1 | task 624 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 18258
slot update_slots: id  1 | task 624 | kv cache rm [17662, end)
slot update_slots: id  1 | task 624 | prompt processing progress, n_past = 18258, n_tokens = 596, progress = 0.032643
slot update_slots: id  1 | task 624 | prompt done, n_past = 18258, n_tokens = 596
slot      release: id  1 | task 624 | stop processing: n_past = 18331, truncated = 0
slot print_timing: id  1 | task 624 |
prompt eval time =     793.18 ms /   596 tokens (    1.33 ms per token,   751.40 tokens per second)
       eval time =    7408.72 ms /    74 tokens (  100.12 ms per token,     9.99 tokens per second)
      total time =    8201.90 ms /   670 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 699 | processing task
slot update_slots: id  0 | task 699 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 679
slot update_slots: id  0 | task 699 | kv cache rm [0, end)
slot update_slots: id  0 | task 699 | prompt processing progress, n_past = 679, n_tokens = 679, progress = 1.000000
slot update_slots: id  0 | task 699 | prompt done, n_past = 679, n_tokens = 679
slot      release: id  0 | task 699 | stop processing: n_past = 682, truncated = 0
slot print_timing: id  0 | task 699 |
prompt eval time =     874.46 ms /   679 tokens (    1.29 ms per token,   776.48 tokens per second)
       eval time =     283.10 ms /     4 tokens (   70.77 ms per token,    14.13 tokens per second)
      total time =    1157.56 ms /   683 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  3 | task 704 | processing task
slot update_slots: id  3 | task 704 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 419
slot update_slots: id  3 | task 704 | kv cache rm [2, end)
slot update_slots: id  3 | task 704 | prompt processing progress, n_past = 419, n_tokens = 417, progress = 0.995227
slot update_slots: id  3 | task 704 | prompt done, n_past = 419, n_tokens = 417
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 706 | processing task
slot update_slots: id  1 | task 706 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 18541
slot update_slots: id  1 | task 706 | kv cache rm [18263, end)
slot update_slots: id  1 | task 706 | prompt processing progress, n_past = 18541, n_tokens = 279, progress = 0.014994
slot update_slots: id  1 | task 706 | prompt done, n_past = 18541, n_tokens = 279
slot      release: id  3 | task 704 | stop processing: n_past = 432, truncated = 0
slot print_timing: id  3 | task 704 |
prompt eval time =     714.35 ms /   417 tokens (    1.71 ms per token,   583.75 tokens per second)
       eval time =    2876.78 ms /    14 tokens (  205.48 ms per token,     4.87 tokens per second)
      total time =    3591.14 ms /   431 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
slot      release: id  1 | task 706 | stop processing: n_past = 18615, truncated = 0
slot print_timing: id  1 | task 706 |
prompt eval time =     491.48 ms /   278 tokens (    1.77 ms per token,   565.64 tokens per second)
       eval time =    8789.93 ms /    75 tokens (  117.20 ms per token,     8.53 tokens per second)
      total time =    9281.41 ms /   353 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  1 | task 782 | processing task
slot update_slots: id  1 | task 782 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 18705
slot update_slots: id  1 | task 782 | kv cache rm [18546, end)
slot update_slots: id  1 | task 782 | prompt processing progress, n_past = 18705, n_tokens = 159, progress = 0.008500
slot update_slots: id  1 | task 782 | prompt done, n_past = 18705, n_tokens = 159
slot      release: id  1 | task 782 | stop processing: n_past = 19409, truncated = 0
slot print_timing: id  1 | task 782 |
prompt eval time =     274.58 ms /   159 tokens (    1.73 ms per token,   579.07 tokens per second)
       eval time =   74827.07 ms /   705 tokens (  106.14 ms per token,     9.42 tokens per second)
      total time =   75101.64 ms /   864 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Generic
slot launch_slot_: id  2 | task 1488 | processing task
slot update_slots: id  2 | task 1488 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 14763
slot update_slots: id  2 | task 1488 | kv cache rm [13845, end)
slot update_slots: id  2 | task 1488 | prompt processing progress, n_past = 14763, n_tokens = 918, progress = 0.062182
slot update_slots: id  2 | task 1488 | prompt done, n_past = 14763, n_tokens = 918
slot      release: id  2 | task 1488 | stop processing: n_past = 14891, truncated = 0
slot print_timing: id  2 | task 1488 |
prompt eval time =    1126.12 ms /   918 tokens (    1.23 ms per token,   815.19 tokens per second)
       eval time =   13434.76 ms /   129 tokens (  104.15 ms per token,     9.60 tokens per second)
      total time =   14560.87 ms /  1047 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 1618 | processing task
slot update_slots: id  0 | task 1618 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 83
slot update_slots: id  0 | task 1618 | kv cache rm [2, end)
slot update_slots: id  0 | task 1618 | prompt processing progress, n_past = 83, n_tokens = 81, progress = 0.975904
slot update_slots: id  0 | task 1618 | prompt done, n_past = 83, n_tokens = 81
srv  params_from_: Chat format: Generic
slot launch_slot_: id  2 | task 1620 | processing task
slot update_slots: id  2 | task 1620 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 14892
slot update_slots: id  2 | task 1620 | kv cache rm [14763, end)
slot update_slots: id  2 | task 1620 | prompt processing progress, n_past = 14892, n_tokens = 130, progress = 0.008662
slot update_slots: id  2 | task 1620 | prompt done, n_past = 14892, n_tokens = 130
slot      release: id  0 | task 1618 | stop processing: n_past = 96, truncated = 0
slot print_timing: id  0 | task 1618 |
prompt eval time =     401.52 ms /    81 tokens (    4.96 ms per token,   201.73 tokens per second)
       eval time =    2993.99 ms /    14 tokens (  213.86 ms per token,     4.68 tokens per second)
      total time =    3395.51 ms /    95 tokens
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
slot      release: id  2 | task 1620 | stop processing: n_past = 16677, truncated = 0
slot print_timing: id  2 | task 1620 |
prompt eval time =     389.01 ms /   129 tokens (    3.02 ms per token,   331.61 tokens per second)
       eval time =  189248.98 ms /  1786 tokens (  105.96 ms per token,     9.44 tokens per second)
      total time =  189637.99 ms /  1915 tokens

You can add some more debug prints temporarily in that branch so I can give you more information next time. @ggerganov

@rujialiu

I'm unable to make it exit a second time. I've been using it until now (though not very intensively). One hour has passed, and VRAM usage is now 10.07GB. I even explicitly asked Claude Code to spawn parallel tasks (and can indeed see the requests being handled by different slots). So it looks like a very rare problem. Are you able to reproduce this with e.g. llama-parallel? @ggerganov (And thanks for folding the logs - I forgot to do that.)

@ggerganov (Member, Author)

llama-parallel runs fine on my end. I think the process exiting as you describe is probably not an issue with llama-server, but something else on your system triggered it. Given that you don't observe it anymore, it seems like the changes are good.

@rujialiu

Yes. Now I think it might be related to peak memory. I can see that the peak RAM usage is 18GB greater than the current usage. After doing some math, I realized that at the "peak RAM" time it was very close to OOM. I think llama-server rarely allocates memory after startup, but when it does, it can fail to allocate? Anyway, I agree that the changes are good.

@ggerganov (Member, Author)

I believe when CUDA fails to allocate, it will print an error instead of silently exiting. Though not 100% sure.

@rujialiu

> I believe when CUDA fails to allocate, it will print an error instead of silently exiting. Though not 100% sure.

I actually mean the CPU failed to allocate RAM, not CUDA. In my use case, most RAM is used by the KV cache, I think (I used -nkvo). It looks like it not only allocates memory, but also frees a lot of memory after some idle time?

@prd-tuong-nguyen

I conducted an experiment using the L40S GPU and the Gemma 27B model (Q_4), and I noticed an unusual increase in latency.

With 1 CCU, the latency consistently hovered around 600ms. However, after experiencing a high load with 32 CCU, I checked the latency again using 1 CCU, and it had risen to approximately 1500ms.

Do you have any insights on this issue?

Here is my command:

Start command: /app/llama-server --port 8080 --host 0.0.0.0 --n-gpu-layers 100000000 --flash-attn --no-warmup --jinja --lora-init-without-apply --defrag-thold 0.1 --n-gpu-layers 63 --parallel 8 --flash-attn --ctx-size 32768 --ubatch-size 2048 --batch-size 2048 --metrics --slots --no-mmap --threads 16 --threads-batch 16

@ggerganov (Member, Author)

@prd-tuong-nguyen Most likely you didn't set LLAMA_SET_ROWS=1.

@prd-tuong-nguyen

@ggerganov I've already set it, but I've noticed that after a high load, my model consistently returns an empty response while the server logs indicate that it generates the maximum number of tokens (which is set to 36 in my configuration), and that’s why the latency increases. It may be a bug.

@ggerganov (Member, Author)

It's possible that you run out of context - try to add --no-context-shift. If the issue persists, open a separate issue with detailed repro steps and logs.

