ggml:metal Add POOL2D op and fix IM2COL in Metal backend for running MobileVLM_V2. #9943
Conversation
Well done!
ggml/src/ggml-metal.m
Outdated
Suggested change:
-    GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32,
-    GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
applied in 008484799146274784a5df692fd7b7508805be83
ggml/src/ggml-metal.m
Outdated
Suggested change:
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
applied in 008484799146274784a5df692fd7b7508805be83
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;

const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup;

const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
Would it make sense to keep just the _ext variant of the kernel? Does the old kernel have a significant advantage?
Summary: to keep backward compatibility and performance for smaller computations, I suggest keeping the old kernel.
Details:
I did some investigation after you left this comment.
The new kernel (a.k.a. _ext) shows a performance degradation for small N:
$ test-backend-ops perf -o IM2COL -b Metal
NOTE: 9 × 113 = 1017 stays within the 1024-thread limit, while 9 × 114 = 1026 exceeds it (9 = KH × KW for the 3 × 3 kernel).
original + _ext, M
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 229348 runs - 4.38 us/run - 99 kB/run - 0.77 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 196584 runs - 5.24 us/run - 199 kB/run - 1.51 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 147438 runs - 6.82 us/run - 351 kB/run - 2.73 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 106483 runs - 9.67 us/run - 703 kB/run - 5.34 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.79 us/run - 354 kB/run - 2.57 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.74 us/run - 354 kB/run - 2.58 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 49146 runs - 21.43 us/run - 622 kB/run - 4.61 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 32764 runs - 35.12 us/run - 1244 kB/run - 8.45 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.60 us/run - 3186 kB/run - 20.37 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.71 us/run - 6372 kB/run - 21.00 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 143.44 us/run - 6372 kB/run - 21.18 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 283.55 us/run - 12744 kB/run - 21.43 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 175.97 us/run - 7965 kB/run - 21.58 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 352.93 us/run - 15930 kB/run - 21.52 GB/s
only _ext, M - worst case.
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 73719 runs - 14.48 us/run - 99 kB/run - 0.73 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 49146 runs - 24.31 us/run - 199 kB/run - 1.30 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 17.90 us/run - 351 kB/run - 2.68 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 40955 runs - 28.89 us/run - 703 kB/run - 4.64 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.54 us/run - 354 kB/run - 2.61 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.06 us/run - 354 kB/run - 2.68 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 20.35 us/run - 622 kB/run - 4.17 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 32764 runs - 34.79 us/run - 1244 kB/run - 8.53 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.19 us/run - 3186 kB/run - 20.48 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.72 us/run - 6372 kB/run - 21.00 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 142.86 us/run - 6372 kB/run - 21.27 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 283.70 us/run - 12744 kB/run - 21.42 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 175.84 us/run - 7965 kB/run - 21.60 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 350.67 us/run - 15930 kB/run - 21.66 GB/s
only _ext, MIN(N, M)
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 131056 runs - 7.78 us/run - 99 kB/run - 0.76 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.24 us/run - 199 kB/run - 1.54 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.58 us/run - 351 kB/run - 2.63 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 17.85 us/run - 703 kB/run - 5.37 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.80 us/run - 354 kB/run - 2.61 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.72 us/run - 354 kB/run - 2.62 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 65528 runs - 16.90 us/run - 622 kB/run - 4.39 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 40955 runs - 28.84 us/run - 1244 kB/run - 8.23 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.75 us/run - 3186 kB/run - 20.32 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.31 us/run - 6372 kB/run - 21.06 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 143.02 us/run - 6372 kB/run - 21.25 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 286.72 us/run - 12744 kB/run - 21.20 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 176.62 us/run - 7965 kB/run - 21.51 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 351.21 us/run - 15930 kB/run - 21.63 GB/s
original + _ext, MIN(N, M) - best case IMO.
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 221157 runs - 4.53 us/run - 99 kB/run - 0.78 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 196584 runs - 5.14 us/run - 199 kB/run - 1.54 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 147438 runs - 6.80 us/run - 351 kB/run - 2.74 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 114674 runs - 9.23 us/run - 703 kB/run - 5.19 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 81910 runs - 12.44 us/run - 354 kB/run - 2.72 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.55 us/run - 354 kB/run - 2.66 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 17.50 us/run - 622 kB/run - 4.85 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 40955 runs - 28.67 us/run - 1244 kB/run - 8.28 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.53 us/run - 3186 kB/run - 20.38 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.56 us/run - 6372 kB/run - 21.02 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 143.35 us/run - 6372 kB/run - 21.20 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 283.69 us/run - 12744 kB/run - 21.42 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 176.14 us/run - 7965 kB/run - 21.56 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 351.41 us/run - 15930 kB/run - 21.62 GB/s
I guess this comes from wasted threads in the _ext kernel (those with tpitg_0 >= N) when N is smaller than M (maxTotalThreadsPerThreadgroup), and it gets worse as the M - N gap grows.
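For illustration, the wasted work corresponds to an early-exit guard of roughly this shape inside the _ext kernel (a sketch; the exact kernel code may differ):

    // threads whose first-dimension index exceeds N exit immediately; when the
    // threadgroup is sized to M >> N, most of the threadgroup does nothing
    if (tpitg_0 >= N) {
        return;
    }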
The official Apple documentation describes a newer API that spreads threads to exactly fit the given grid. However, it is only supported from Metal3/Apple4 onward (see the Metal feature set tables under the keyword "nonuniform"), so I didn't pursue it, because it would narrow the set of devices llama.cpp can run on.
I applied "original + _ext, MIN(N, M)" to this PR, which is the best of those configurations. Thanks for your review; your comment led me to this extra optimization. :-^)
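Concretely, the resulting dispatch decision looks roughly like this (a sketch reusing the names from the snippets above; IC, the number of input channels, is assumed here, and the exact threadgroup geometry in the PR may differ):

    const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup;
    const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > M;

    if (!is_gt_mttpt) {
        // small case: the original kernel, one N x KH x KW threadgroup per channel
        [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW)
                threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
    } else {
        // large case: the _ext kernel, capped at MIN(N, M) threads per
        // threadgroup and tiling N over ceil(N / ntg) threadgroups
        const uint64_t ntg = MIN((uint64_t)N, M);
        [encoder dispatchThreadgroups:MTLSizeMake(((N + ntg - 1) / ntg) * IC, OH, OW)
                threadsPerThreadgroup:MTLSizeMake(ntg, 1, 1)];
    }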
ggml/src/ggml-metal.metal
Outdated
Suggested change:
-        device const float* src0,
-        device float* dst,
-        constant int32_t& k0,
-        constant int32_t& k1,
-        constant int32_t& s0,
-        constant int32_t& s1,
-        constant int32_t& p0,
-        constant int32_t& p1,
-        constant int64_t& IH,
-        constant int64_t& IW,
-        constant int64_t& OH,
-        constant int64_t& OW,
-        constant int64_t& parallel_elements,
-        uint gid[[thread_position_in_grid]]) {
+        device const float * src0,
+        device float * dst,
+        constant int32_t & k0,
+        constant int32_t & k1,
+        constant int32_t & s0,
+        constant int32_t & s1,
+        constant int32_t & p0,
+        constant int32_t & p1,
+        constant int64_t & IH,
+        constant int64_t & IW,
+        constant int64_t & OH,
+        constant int64_t & OW,
+        constant int64_t & parallel_elements,
+        uint gid[[thread_position_in_grid]]) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
applied in 008484799146274784a5df692fd7b7508805be83
ggml/src/ggml-metal.m
Outdated
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, avg_pool_2d_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, max_pool_2d_f32, true);
The kernel names also need to be updated to match the new naming convention:
Suggested change:
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, avg_pool_2d_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, max_pool_2d_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
My bad.
applied in bb9949b
Here are a few more minor suggestions to apply:
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index ed4d8326..0267b002 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -854,7 +854,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_POOL_1D:
return false;
case GGML_OP_POOL_2D:
- return true;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
@@ -2554,6 +2553,8 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2620,7 +2621,7 @@ static void ggml_metal_encode_node(
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
if (is_gt_mttpt) {
- [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
+ [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
@@ -3034,9 +3035,10 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_POOL_2D:
{
+ GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
- const int32_t* opts = dst->op_params;
+ const int32_t * opts = dst->op_params;
enum ggml_op_pool op = opts[0];
id<MTLComputePipelineState> pipeline = nil;
@@ -3063,7 +3065,7 @@ static void ggml_metal_encode_node(
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
- const int64_t N = dst->ne[3];
+ const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 1fb05cce..71b58be1 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -6479,15 +6479,16 @@ kernel void kernel_pool_2d_max_f32(
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;
- device const float* i_ptr = src0 + nc * I_HW;
- device float* o_ptr = dst + nc * O_HW;
+ device const float * i_ptr = src0 + nc * I_HW;
+ device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1;
- const int bh = MAX(0, start_h);
+ const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
- const int bw = MAX(0, start_w);
+ const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);
+
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
@@ -6495,23 +6496,24 @@ kernel void kernel_pool_2d_max_f32(
res = MAX(res, i_ptr[i * IW + j]);
}
}
+
o_ptr[cur_oh * OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
- device const float* src0,
- device float* dst,
- constant int32_t& k0,
- constant int32_t& k1,
- constant int32_t& s0,
- constant int32_t& s1,
- constant int32_t& p0,
- constant int32_t& p1,
- constant int64_t& IH,
- constant int64_t& IW,
- constant int64_t& OH,
- constant int64_t& OW,
- constant int64_t& parallel_elements,
+ device const float * src0,
+ device float * dst,
+ constant int32_t & k0,
+ constant int32_t & k1,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int64_t & IH,
+ constant int64_t & IW,
+ constant int64_t & OH,
+ constant int64_t & OW,
+ constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {
if (gid >= parallel_elements) {
@@ -6525,17 +6527,18 @@ kernel void kernel_pool_2d_avg_f32(
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;
- device const float* i_ptr = src0 + nc * I_HW;
- device float* o_ptr = dst + nc * O_HW;
+ device const float * i_ptr = src0 + nc * I_HW;
+ device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1;
- const int bh = MAX(0, start_h);
+ const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
- const int bw = MAX(0, start_w);
+ const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (k0 * k1);
+
float res = 0;
for (int i = bh; i < eh; i += 1) {
@@ -6544,5 +6547,6 @@ kernel void kernel_pool_2d_avg_f32(
res += cur * scale;
}
}
+
o_ptr[cur_oh * OW + cur_ow] = res;
}
applied in 746e79e. 👍
* add pool_2d
* fix im2col and add unittest for N>=1024
* add tests for N % 1024 != 0
* remove trailing whitespaces
* apply suggestions
* apply more optimization - original IM2COL kernel + _ext with MIN()
* apply review: change kernel name of pool_2d
* apply review
* fix more formatting and enhance readability

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
Summary
F32 POOL2D operations (AVG, MAX) are added.
Because MobileVLM_V2 needs im2col with N = 2048, which is larger than the maxTotalThreadsPerThreadgroup of the M3 Max, the IM2COL op in the Metal backend is fixed to support large N, and tests are added for it.
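The POOL_2D kernels use one thread per output element: each thread decodes its flat grid index into (nc, oh, ow) as in the kernels above, guarded by the gid >= parallel_elements check. A sketch of the matching host-side dispatch (names follow the diff above; the actual code may differ):

    // total number of output elements; each thread computes exactly one
    const int64_t parallel_elements = N * OC * OH * OW;
    const int64_t n_threads = MIN(parallel_elements,
                                  (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
    const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;

    [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1)
            threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];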
I ran
GG_BUILD_METAL=1 bash ci/run.sh ./tmp/results ./tmp/mnt
with Python 3.10.14 on macOS Sonoma 14.5 (Apple M3 Max), and it completed without failure.
I have read the contributing guidelines.
Self-reported review complexity:
Running MobileVLM_V2
Test model : MobileVLM V2 1.7B (https://huggingface.co/ZiangWu/MobileVLM_V2-1.7B-GGUF)
Test image : https://avatars.githubusercontent.com/u/1991296?v=4
full test log
test-backend-ops
I added test cases for large N; a sketch of their shape follows.
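A hypothetical illustration of the kind of cases added, following the existing test_im2col helper in tests/test-backend-ops.cpp (the exact values in the PR may differ; the parameter tuples mirror the perf log entries above):

    // im2col cases at and beyond the 1024-thread threadgroup limit
    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16,
        {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16,
        {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));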
full test-backend-ops log
performance test
I did this by adding perf tests with the below patch, which is not included in this PR:
full performance test log
I hope this helps you. :)