@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
}

+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
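+    // op_params layout: strides (s0, s1, s2), padding (p0, p1, p2), dilation (d0, d1, d2),
+    // followed by the input channel count (c), batch size (n) and output channel count (oc)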
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void  *       knl_data = kernel->data;
+    float *       dst_data = (float *) dst->data;
+
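+    // each patch owns one im2col row of knl_n_total elements plus oc floats of GEMM output;
+    // patches are processed in batches sized so that a whole batch fits into the work buffer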
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+    const int64_t knl_n_total       = knl_n_per_channel * c;
+    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
+
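+        // im2col: each thread gathers the 3D receptive field of its output voxels into
+        // contiguous rows in tmp, writing zeros for samples outside the padded input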
+        for (int64_t p = patch_start; p < patch_end; ++p) {
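+            // decompose the flat patch index into (batch, dst_z, dst_y, dst_x)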
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx  = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *) src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
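+                            // store in the kernel's precision so the GEMM runs on
+                            // homogeneous F16 or F32 operands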
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *) element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
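+        // single GEMM per batch: [patch_n_in_batch x knl_n_total] x [knl_n_total x oc]
+        // yields every output channel for every patch gathered above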
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
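+        // scatter the GEMM rows back into dst, going from one row per patch to the
+        // tensor's (W, H, D, OC*N) layout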
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start      = params->ith * permute_per_thread;
+        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p          = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
+                const float   value   = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *) dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
// ggml_compute_forward_conv_transpose_2d

void ggml_compute_forward_conv_transpose_2d(
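Note, not part of the hunk above: a minimal sketch of the output-extent relation that a 3D convolution with these strides, paddings and dilations normally uses to size dst. The helper name is illustrative only (it is not an existing ggml function); it reuses the s/p/d/knl/src names from the diff.

    // standard convolution arithmetic; the effective kernel extent is d*(k - 1) + 1
    static inline int64_t conv_3d_out_size(int64_t in, int64_t k, int32_t s, int32_t p, int32_t d) {
        return (in + 2*p - d*(k - 1) - 1) / s + 1;
    }
    // e.g. dst_w == conv_3d_out_size(src_w, knl_w, s0, p0, d0), and likewise for the H and D axes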