Commit 29ec316

add support for applying lora to quantized tensors
1 parent: e91ce4f

File tree

1 file changed: +22 −9 lines

lora.hpp

Lines changed: 22 additions & 9 deletions
@@ -12,6 +12,8 @@ struct LoraModel : public GGMLRunner {
     ModelLoader model_loader;
     bool load_failed = false;
     bool applied = false;
+    std::vector<int> zero_index_vec = {0};
+    ggml_tensor* zero_index = NULL;
 
     LoraModel(ggml_backend_t backend,
               ggml_type wtype,
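Note: the two members added here back the dequantization helper introduced in the next hunk. zero_index_vec pins the single row index 0 in host memory, and zero_index is the I32 tensor that carries it to the backend via set_backend_tensor_data().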
@@ -68,9 +70,19 @@ struct LoraModel : public GGMLRunner {
         return true;
     }
 
+    ggml_tensor* to_f32(ggml_context* ctx, ggml_tensor* a) {
+        auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a));
+        out      = ggml_get_rows(ctx, out, zero_index);
+        out      = ggml_reshape(ctx, out, a);
+        return out;
+    }
+
     struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);
 
+        zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1);
+        set_backend_tensor_data(zero_index, zero_index_vec.data());
+
         std::set<std::string> applied_lora_tensors;
         for (auto it : model_tensors) {
             std::string k_tensor = it.first;
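Note: to_f32() dequantizes an entire tensor by flattening it to a single row and selecting that row with ggml_get_rows, whose result is F32 for float and quantized inputs; reshaping back restores the original dimensions. Below is a minimal CPU-only sketch of the same trick, assuming a recent ggml build; the standalone main(), the F16 type standing in for a quantized one, and all sizes and values are illustrative, not part of this commit.

#include "ggml.h"
#include <cstdint>
#include <cstdio>

int main() {
    struct ggml_init_params params = {/*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ NULL, /*no_alloc*/ false};
    struct ggml_context* ctx = ggml_init(params);

    // An F16 weight stands in for a quantized tensor here; ggml_get_rows
    // converts whichever row type it reads into an F32 result.
    struct ggml_tensor* w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4, 2);
    ggml_fp16_t* wd = (ggml_fp16_t*)w->data;
    for (int i = 0; i < (int)ggml_nelements(w); i++) {
        wd[i] = ggml_fp32_to_fp16((float)i);
    }

    // Single zero index: after flattening, "row 0" is the whole tensor.
    struct ggml_tensor* zero_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
    ((int32_t*)zero_index->data)[0] = 0;

    struct ggml_tensor* flat = ggml_reshape_1d(ctx, w, ggml_nelements(w));
    struct ggml_tensor* rows = ggml_get_rows(ctx, flat, zero_index); // F32 output
    struct ggml_tensor* out  = ggml_reshape(ctx, rows, w);           // back to w's shape

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    for (int i = 0; i < (int)ggml_nelements(out); i++) {
        printf("%g ", ((float*)out->data)[i]); // prints 0 1 2 3 4 5 6 7
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}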
@@ -141,15 +153,16 @@ struct LoraModel : public GGMLRunner {
             GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
             updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
             ggml_tensor* final_weight;
-            // if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
-            //     final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, weight->n_dims, weight->ne);
-            //     final_weight = ggml_cpy_inplace(compute_ctx, weight, final_weight);
-            //     final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
-            //     final_weight = ggml_cpy_inplace(compute_ctx, final_weight, weight);
-            // } else {
-            //     final_weight = ggml_add_inplace(compute_ctx, weight, updown);
-            // }
-            final_weight = ggml_add_inplace(compute_ctx, weight, updown);  // apply directly
+            if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
+                // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne);
+                // final_weight = ggml_cpy(compute_ctx, weight, final_weight);
+                final_weight = to_f32(compute_ctx, weight);
+                final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
+                final_weight = ggml_cpy(compute_ctx, final_weight, weight);
+            } else {
+                final_weight = ggml_add_inplace(compute_ctx, weight, updown);
+            }
+            // final_weight = ggml_add_inplace(compute_ctx, weight, updown);  // apply directly
             ggml_build_forward_expand(gf, final_weight);
         }
 
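Note: with this change, quantized weights take the new branch: to_f32() dequantizes the weight, the scaled LoRA delta is added in F32, and ggml_cpy() writes the result back into the original quantized tensor, requantizing during the copy (ggml_cpy converts between tensor types at compute time). A small sketch of that requantize-on-copy step, again assuming a recent CPU-only ggml build, with illustrative sizes and values:

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {/*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ NULL, /*no_alloc*/ false};
    struct ggml_context* ctx = ggml_init(params);

    // Patched weights in F32, e.g. the output of to_f32() + ggml_add_inplace().
    struct ggml_tensor* f32 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 32);
    for (int i = 0; i < 32; i++) {
        ((float*)f32->data)[i] = 0.1f * (float)i;
    }

    // Quantized destination (one Q8_0 block covers 32 values).
    struct ggml_tensor* q = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, 32);

    // ggml_cpy converts types on compute, so this requantizes the F32 data
    // into q's buffer, just as the diff writes final_weight back into the
    // original quantized weight tensor.
    struct ggml_tensor* cpy = ggml_cpy(ctx, f32, q);

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cpy);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("wrote %d values as %s\n", (int)ggml_nelements(q), ggml_type_name(q->type));
    ggml_free(ctx);
    return 0;
}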