Commit 92f7f0a

ggml: add conv3d op (#15182)
* add conv3d
* bump GGML_OP_COUNT
1 parent b1ab918 commit 92f7f0a

6 files changed: +345 −2 lines

ggml/include/ggml.h

Lines changed: 18 additions & 0 deletions
@@ -512,6 +512,7 @@ extern "C" {
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
         GGML_OP_CONV_2D,
+        GGML_OP_CONV_3D,
         GGML_OP_CONV_2D_DW,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
@@ -1940,6 +1941,23 @@ extern "C" {
             int                   d0,  // dilation dimension 0
             int                   d1); // dilation dimension 1
 
+    GGML_API struct ggml_tensor * ggml_conv_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]
+            struct ggml_tensor  * b,   // input  [W, H, D, C * N]
+            int                   s0,  // stride
+            int                   s1,
+            int                   s2,
+            int                   p0,  // padding
+            int                   p1,
+            int                   p2,
+            int                   d0,  // dilation
+            int                   d1,
+            int                   d2,
+            int                   n_channels,
+            int                   n_batch,
+            int                   n_channels_out);
+
     enum ggml_op_pool {
         GGML_OP_POOL_MAX,
         GGML_OP_POOL_AVG,
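
For orientation, a minimal call sketch (not part of the diff; the context ctx and all sizes below are illustrative assumptions):

    // kernel [KW, KH, KD, IC * OC] and input [W, H, D, IC * N], both F32
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,  3,  3,  3, 4 * 8);  // IC = 4, OC = 8
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 16, 16, 4 * 1);  // IC = 4, N = 1
    // stride/padding/dilation of 1 in all three dims keeps the 16 x 16 x 16 extent; ne[3] = OC * N = 8
    struct ggml_tensor * out = ggml_conv_3d(ctx, kernel, input,
                                            1, 1, 1,    // s0, s1, s2
                                            1, 1, 1,    // p0, p1, p2
                                            1, 1, 1,    // d0, d1, d2
                                            4, 1, 8);   // n_channels, n_batch, n_channels_out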

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_conv_2d(params, tensor);
             } break;
+        case GGML_OP_CONV_3D:
+            {
+                ggml_compute_forward_conv_3d(params, tensor);
+            } break;
         case GGML_OP_CONV_2D_DW:
             {
                 ggml_compute_forward_conv_2d_dw(params, tensor);
@@ -2252,6 +2256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
        case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_3D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -2773,6 +2778,7 @@ struct ggml_cplan ggml_graph_plan(
                     }
                 } break;
             case GGML_OP_CONV_2D:
+            case GGML_OP_CONV_3D:
                 {
                     cur = GGML_IM2COL_WORK_SIZE;
                 } break;
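
CONV_3D is dispatched and task-scheduled like the other convolution ops, and it reuses the fixed GGML_IM2COL_WORK_SIZE scratch buffer that CONV_2D already requests in ggml_graph_plan; the CPU kernel later derives how many output patches it can im2col per pass from that budget. A rough sketch of that arithmetic (a hypothetical helper mirroring the computation in ops.cpp, not code from this commit):

    // One patch needs kw*kh*kd*ic kernel-typed values for its im2col row plus oc floats
    // of GEMM output, so the work buffer bounds how many patches are processed at once.
    static int64_t conv3d_patches_per_pass(size_t wsize, int64_t kw, int64_t kh, int64_t kd,
                                           int64_t ic, int64_t oc, size_t type_size) {
        const size_t  space_per_patch = (size_t)(kw * kh * kd * ic) * type_size + (size_t) oc * sizeof(float);
        const int64_t batch_size      = (int64_t)(wsize / space_per_patch);
        return batch_size > 8 ? (batch_size / 8) * 8 : batch_size;   // round down to a multiple of 8
    }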

ggml/src/ggml-cpu/ops.cpp

Lines changed: 142 additions & 0 deletions
@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
     ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void  *       knl_data = kernel->data;
+    float *       dst_data = (float *) dst->data;
+
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+    const int64_t knl_n_total       = knl_n_per_channel * c;
+    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *) element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start      = params->ith * permute_per_thread;
+        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p          = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
+                const float   value   = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
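
The CPU path lowers conv_3d to im2col plus one matrix multiplication per patch batch: each output voxel becomes one im2col row of knl_w*knl_h*knl_d*c values, and ggml_call_mul_mat then computes, in effect, a (patch_n_in_batch x knl_n_total) by (knl_n_total x oc) product in a single call. As a worked example of the flat-index decomposition used above (sizes are illustrative, not taken from the commit):

    // With dst_w = 4, dst_h = 3, dst_d = 2 (24 output voxels per image) and flat patch index p = 30:
    const int64_t dst_w = 4, dst_h = 3, dst_d = 2;
    const int64_t p          = 30;
    const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);   // 30 / 24 = 1 -> second image in the batch
    const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);   // 30 % 24 = 6
    const int64_t dst_z      = p_in_batch / (dst_w * dst_h);  // 6 / 12 = 0
    const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);  // 6 % 12 = 6
    const int64_t dst_y      = p_in_depth / dst_w;            // 6 / 4 = 1
    const int64_t dst_x      = p_in_depth % dst_w;            // 6 % 4 = 2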

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 54 additions & 2 deletions
@@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "IM2COL",
     "IM2COL_BACK",
     "CONV_2D",
+    "CONV_3D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
@@ -1017,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
+static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "im2col(x)",
     "im2col_back(x)",
     "conv_2d(x)",
+    "conv_3d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
@@ -1119,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
+static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4480,6 +4482,56 @@ struct ggml_tensor * ggml_conv_2d_direct(
     return result;
 }
 
+// ggml_conv_3d
+
+struct ggml_tensor * ggml_conv_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   s2,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   d0,
+        int                   d1,
+        int                   d2,
+        int                   c,
+        int                   n,
+        int                   oc) {
+
+    GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
+    GGML_ASSERT(b->ne[3] == (int64_t) c * n);
+
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
+    ne[3] = (int64_t) oc * n;
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_i32(result, 0,  s0);
+    ggml_set_op_params_i32(result, 1,  s1);
+    ggml_set_op_params_i32(result, 2,  s2);
+    ggml_set_op_params_i32(result, 3,  p0);
+    ggml_set_op_params_i32(result, 4,  p1);
+    ggml_set_op_params_i32(result, 5,  p2);
+    ggml_set_op_params_i32(result, 6,  d0);
+    ggml_set_op_params_i32(result, 7,  d1);
+    ggml_set_op_params_i32(result, 8,  d2);
+    ggml_set_op_params_i32(result, 9,  c);
+    ggml_set_op_params_i32(result, 10, n);
+    ggml_set_op_params_i32(result, 11, oc);
+
+    result->op     = GGML_OP_CONV_3D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
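
ggml_calc_conv_output_size is the pre-existing helper also used by the 2D convolution path; it is assumed to apply the standard dilated-convolution size formula per dimension. For reference only (a sketch, not code from this commit):

    // out = (in + 2*p - d*(k - 1) - 1) / s + 1, applied independently to W, H and D
    static int64_t conv_out_size(int64_t in, int64_t k, int s, int p, int d) {
        return (in + 2*p - d*(k - 1) - 1) / s + 1;   // e.g. in = 16, k = 3, s = p = d = 1 -> 16
    }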
