Commit c8c07d6

ngxson and ggerganov authored
llama : fix empty batch causing llama_batch_allocr to crash (ggml-org#9966)
* llama : fix empty batch causing llama_batch_allocr to crash
* move batch_allocr inside decode/encode_internal
* fix build
* add GGML_ASSERT
* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 19d900a commit c8c07d6
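
The crash being fixed is an out-of-bounds write: before this change, the public llama_encode/llama_decode wrappers constructed llama_batch_allocr before the n_tokens == 0 check in the internal functions, so an empty batch reached logits[logits.size() - 1] on an empty vector. A minimal, self-contained sketch of that failure mode and of the kind of guard now used (fill_logits_default below is a hypothetical helper, not llama.cpp code; the real guard is GGML_ASSERT(batch.n_tokens > 0)):

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the default-logits path in llama_batch_allocr:
// request output only for the last token of the batch.
static void fill_logits_default(std::vector<int8_t> & logits, int32_t n_tokens) {
    // Without this check, n_tokens == 0 makes logits.size() - 1 underflow to a
    // huge index, so the write below is out of bounds (undefined behavior).
    assert(n_tokens > 0); // the real code uses GGML_ASSERT(batch.n_tokens > 0)

    logits.resize(n_tokens);
    logits[logits.size() - 1] = true; // mark only the last token for output
}

int main() {
    std::vector<int8_t> logits;
    fill_logits_default(logits, 4);   // fine: logits == {0, 0, 0, 1}
    std::printf("last = %d\n", (int) logits.back());
    // fill_logits_default(logits, 0); // would trip the assertion instead of corrupting memory
    return 0;
}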

1 file changed: +67 −61 lines changed

src/llama.cpp

Lines changed: 67 additions & 61 deletions
@@ -5177,6 +5177,57 @@ struct llama_model_loader {
     }
 };

+// temporary allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
+struct llama_batch_allocr {
+    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    struct llama_batch          batch;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
+        batch = in_batch;
+        GGML_ASSERT(batch.n_tokens > 0);
+        if (!batch.pos) {
+            // determine the last position in KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx.kv_self.cells) {
+                if (cell.has_seq_id(batch_default_seq_id)) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            last_pos++; // next position
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = i+last_pos;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens + 1);
+            seq_id[batch.n_tokens] = NULL;
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+    }
+};
+
 template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
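
For readers skimming the block added above: when the incoming batch (for example one produced by llama_batch_get_one) carries no pos array, the allocator continues positions from the last position that the default sequence already occupies in the KV cache. A rough, self-contained sketch of just that step, using a simplified stand-in cell type rather than the real KV-cache cell:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// Simplified stand-in for a KV-cache cell (only what this sketch needs).
struct kv_cell_sketch {
    llama_pos    pos;
    llama_seq_id seq_id;
};

// Mirror of the pos-default logic: continue from the last position that the
// default sequence (seq_id 0) already occupies in the cache.
static std::vector<llama_pos> default_positions(const std::vector<kv_cell_sketch> & cells, int32_t n_tokens) {
    llama_pos last_pos = -1;
    for (const auto & cell : cells) {
        if (cell.seq_id == 0) {
            last_pos = std::max(last_pos, cell.pos);
        }
    }
    last_pos++; // next free position

    std::vector<llama_pos> pos(n_tokens);
    for (int32_t i = 0; i < n_tokens; i++) {
        pos[i] = last_pos + i;
    }
    return pos;
}

int main() {
    std::vector<kv_cell_sketch> cells = { {0, 0}, {1, 0}, {2, 0} }; // 3 tokens already cached
    for (llama_pos p : default_positions(cells, 4)) {
        std::printf("%d ", p); // prints: 3 4 5 6
    }
    std::printf("\n");
    return 0;
}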
@@ -17095,16 +17146,20 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {

     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch.n_tokens;

-    if (n_tokens_all == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }

+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens_all = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
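
The hunk above (and the matching one for llama_encode_internal below) is the "move batch_allocr inside decode/encode_internal" part of the change: the empty-batch check now runs before any temporary allocation takes place. A simplified, hypothetical sketch of that ordering (decode_sketch and batch_sketch are illustrative names, not llama.cpp API):

#include <cstdint>
#include <cstdio>

// Hypothetical minimal batch: only the token count matters for this sketch.
struct batch_sketch {
    int32_t n_tokens;
};

// New ordering: reject an empty batch first, allocate temporary defaults after.
static int decode_sketch(const batch_sketch & inp_batch) {
    if (inp_batch.n_tokens == 0) {
        std::fprintf(stderr, "%s: n_tokens == 0\n", __func__);
        return -1; // the caller gets an error code instead of a crash
    }

    // ... this is where the real code constructs llama_batch_allocr(lctx, inp_batch)
    //     and uses the completed batch_allocr.batch from here on ...
    return 0;
}

int main() {
    batch_sketch empty = { 0 };
    batch_sketch ok    = { 8 };

    std::printf("empty batch -> %d\n", decode_sketch(empty)); // -1
    std::printf("valid batch -> %d\n", decode_sketch(ok));    // 0
    return 0;
}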
@@ -17409,17 +17464,20 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {

     lctx.is_encoding = true;

-    const uint32_t n_tokens = batch.n_tokens;
-
-    if (n_tokens == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }

+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -21090,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits)   free(batch.logits);
 }

-// temporary allocate memory for the input batch if needed
-static const llama_seq_id batch_default_seq_id = 0;
-struct llama_batch_allocr {
-    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
-    struct llama_batch          batch;
-    // optionally fulfill the batch returned by llama_batch_get_one
-    llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
-        batch = in_batch;
-        if (!batch.pos) {
-            // determine the last position in KV cache
-            llama_pos last_pos = -1;
-            for (const auto & cell : ctx->kv_self.cells) {
-                if (cell.has_seq_id(batch_default_seq_id)) {
-                    last_pos = std::max(last_pos, cell.pos);
-                }
-            }
-            last_pos++; // next position
-            pos.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                pos[i] = i+last_pos;
-            }
-            batch.pos = pos.data();
-        }
-        if (!batch.n_seq_id) {
-            n_seq_id.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                n_seq_id[i] = seq_id_0.size();
-            }
-            batch.n_seq_id = n_seq_id.data();
-        }
-        if (!batch.seq_id) {
-            seq_id.resize(batch.n_tokens + 1);
-            seq_id[batch.n_tokens] = NULL;
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                seq_id[i] = seq_id_0.data();
-            }
-            batch.seq_id = seq_id.data();
-        }
-        if (!batch.logits) {
-            logits.resize(batch.n_tokens);
-            logits[logits.size() - 1] = true;
-            batch.logits = logits.data();
-        }
-    }
-};
-
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_encode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -21155,8 +21162,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_decode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
