@@ -5177,6 +5177,57 @@ struct llama_model_loader {
5177
5177
}
5178
5178
};
5179
5179
5180
+ // temporary allocate memory for the input batch if needed
5181
+ static const llama_seq_id batch_default_seq_id = 0;
5182
+ struct llama_batch_allocr {
5183
+ std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
5184
+ std::vector<llama_pos> pos;
5185
+ std::vector<int32_t> n_seq_id;
5186
+ std::vector<llama_seq_id *> seq_id;
5187
+ std::vector<int8_t> logits;
5188
+ struct llama_batch batch;
5189
+ // optionally fulfill the batch returned by llama_batch_get_one
5190
+ llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
5191
+ batch = in_batch;
5192
+ GGML_ASSERT(batch.n_tokens > 0);
5193
+ if (!batch.pos) {
5194
+ // determine the last position in KV cache
5195
+ llama_pos last_pos = -1;
5196
+ for (const auto & cell : ctx.kv_self.cells) {
5197
+ if (cell.has_seq_id(batch_default_seq_id)) {
5198
+ last_pos = std::max(last_pos, cell.pos);
5199
+ }
5200
+ }
5201
+ last_pos++; // next position
5202
+ pos.resize(batch.n_tokens);
5203
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5204
+ pos[i] = i+last_pos;
5205
+ }
5206
+ batch.pos = pos.data();
5207
+ }
5208
+ if (!batch.n_seq_id) {
5209
+ n_seq_id.resize(batch.n_tokens);
5210
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5211
+ n_seq_id[i] = seq_id_0.size();
5212
+ }
5213
+ batch.n_seq_id = n_seq_id.data();
5214
+ }
5215
+ if (!batch.seq_id) {
5216
+ seq_id.resize(batch.n_tokens + 1);
5217
+ seq_id[batch.n_tokens] = NULL;
5218
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5219
+ seq_id[i] = seq_id_0.data();
5220
+ }
5221
+ batch.seq_id = seq_id.data();
5222
+ }
5223
+ if (!batch.logits) {
5224
+ logits.resize(batch.n_tokens);
5225
+ logits[logits.size() - 1] = true;
5226
+ batch.logits = logits.data();
5227
+ }
5228
+ }
5229
+ };
5230
+
5180
5231
template<>
5181
5232
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
5182
5233
uint32_t tmp;
@@ -17095,16 +17146,20 @@ static void llama_graph_compute(
17095
17146
//
17096
17147
static int llama_decode_internal(
17097
17148
llama_context & lctx,
17098
- llama_batch batch ) {
17149
+ llama_batch inp_batch ) {
17099
17150
17100
17151
lctx.is_encoding = false;
17101
- const uint32_t n_tokens_all = batch.n_tokens;
17102
17152
17103
- if (n_tokens_all == 0) {
17153
+ if (inp_batch.n_tokens == 0) {
17104
17154
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
17105
17155
return -1;
17106
17156
}
17107
17157
17158
+ // temporary allocate memory for the input batch if needed
17159
+ llama_batch_allocr batch_allocr(lctx, inp_batch);
17160
+ const llama_batch & batch = batch_allocr.batch;
17161
+ const uint32_t n_tokens_all = batch.n_tokens;
17162
+
17108
17163
const auto & model = lctx.model;
17109
17164
const auto & hparams = model.hparams;
17110
17165
const auto & cparams = lctx.cparams;
@@ -17409,17 +17464,20 @@ static int llama_decode_internal(
17409
17464
//
17410
17465
static int llama_encode_internal(
17411
17466
llama_context & lctx,
17412
- llama_batch batch ) {
17467
+ llama_batch inp_batch ) {
17413
17468
17414
17469
lctx.is_encoding = true;
17415
17470
17416
- const uint32_t n_tokens = batch.n_tokens;
17417
-
17418
- if (n_tokens == 0) {
17471
+ if (inp_batch.n_tokens == 0) {
17419
17472
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
17420
17473
return -1;
17421
17474
}
17422
17475
17476
+ // temporary allocate memory for the input batch if needed
17477
+ llama_batch_allocr batch_allocr(lctx, inp_batch);
17478
+ const llama_batch & batch = batch_allocr.batch;
17479
+ const uint32_t n_tokens = batch.n_tokens;
17480
+
17423
17481
const auto & model = lctx.model;
17424
17482
const auto & hparams = model.hparams;
17425
17483
const auto & cparams = lctx.cparams;
@@ -21090,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) {
21090
21148
if (batch.logits) free(batch.logits);
21091
21149
}
21092
21150
21093
- // temporary allocate memory for the input batch if needed
21094
- static const llama_seq_id batch_default_seq_id = 0;
21095
- struct llama_batch_allocr {
21096
- std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
21097
- std::vector<llama_pos> pos;
21098
- std::vector<int32_t> n_seq_id;
21099
- std::vector<llama_seq_id *> seq_id;
21100
- std::vector<int8_t> logits;
21101
- struct llama_batch batch;
21102
- // optionally fulfill the batch returned by llama_batch_get_one
21103
- llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
21104
- batch = in_batch;
21105
- if (!batch.pos) {
21106
- // determine the last position in KV cache
21107
- llama_pos last_pos = -1;
21108
- for (const auto & cell : ctx->kv_self.cells) {
21109
- if (cell.has_seq_id(batch_default_seq_id)) {
21110
- last_pos = std::max(last_pos, cell.pos);
21111
- }
21112
- }
21113
- last_pos++; // next position
21114
- pos.resize(batch.n_tokens);
21115
- for (int32_t i = 0; i < batch.n_tokens; i++) {
21116
- pos[i] = i+last_pos;
21117
- }
21118
- batch.pos = pos.data();
21119
- }
21120
- if (!batch.n_seq_id) {
21121
- n_seq_id.resize(batch.n_tokens);
21122
- for (int32_t i = 0; i < batch.n_tokens; i++) {
21123
- n_seq_id[i] = seq_id_0.size();
21124
- }
21125
- batch.n_seq_id = n_seq_id.data();
21126
- }
21127
- if (!batch.seq_id) {
21128
- seq_id.resize(batch.n_tokens + 1);
21129
- seq_id[batch.n_tokens] = NULL;
21130
- for (int32_t i = 0; i < batch.n_tokens; i++) {
21131
- seq_id[i] = seq_id_0.data();
21132
- }
21133
- batch.seq_id = seq_id.data();
21134
- }
21135
- if (!batch.logits) {
21136
- logits.resize(batch.n_tokens);
21137
- logits[logits.size() - 1] = true;
21138
- batch.logits = logits.data();
21139
- }
21140
- }
21141
- };
21142
-
21143
21151
int32_t llama_encode(
21144
21152
struct llama_context * ctx,
21145
21153
struct llama_batch batch) {
21146
- llama_batch_allocr batch_allocr(ctx, batch);
21147
- const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
21154
+ const int ret = llama_encode_internal(*ctx, batch);
21148
21155
if (ret != 0) {
21149
21156
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
21150
21157
}
@@ -21155,8 +21162,7 @@ int32_t llama_encode(
21155
21162
int32_t llama_decode(
21156
21163
struct llama_context * ctx,
21157
21164
struct llama_batch batch) {
21158
- llama_batch_allocr batch_allocr(ctx, batch);
21159
- const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
21165
+ const int ret = llama_decode_internal(*ctx, batch);
21160
21166
if (ret != 0) {
21161
21167
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
21162
21168
}
0 commit comments