
Conversation

**am17an** (Collaborator) commented Aug 28, 2025

This PR does two things:

  1. Fuse up to 8 adds. In principle it could fuse any binary op, but adds seem to be the only place it helps for now
  2. Fuse add with RMS norm if available

Results on a 4090

Master
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp256 | 6618.81 ± 52.21 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp512 | 8329.03 ± 57.80 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp1024 | 7761.23 ± 26.59 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp2048 | 7213.90 ± 13.25 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp4096 | 6406.37 ± 4.48 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp8192 | 5253.41 ± 3.44 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg128 | 218.32 ± 0.42 |

PR

Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp256 | 6639.05 ± 49.04 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp512 | 8357.07 ± 67.45 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp1024 | 7803.96 ± 25.31 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp2048 | 7248.48 ± 9.59 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp4096 | 6422.83 ± 4.76 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp8192 | 5253.89 ± 5.95 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg128 | 218.93 ± 0.47 |

Master

Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp256 | 4825.40 ± 37.22 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp512 | 6460.33 ± 26.83 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp1024 | 5993.38 ± 34.62 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp2048 | 5406.66 ± 17.89 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp4096 | 4547.56 ± 11.06 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp8192 | 3459.71 ± 3.10 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | tg128 | 183.70 ± 0.26 |

PR

Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp256 | 4931.39 ± 41.82 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp512 | 6519.48 ± 31.29 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp1024 | 6060.62 ± 31.74 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp2048 | 5463.69 ± 15.37 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp4096 | 4583.41 ± 9.40 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | pp8192 | 3479.89 ± 2.58 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | tg128 | 192.26 ± 0.61 |

*am17an force-pushed the rms_norm_fused_add branch from 49cf98e to 4d10578 on August 28, 2025 09:47*
**ggerganov** (Member):

Could you also post results with `-ub 256,512,1024,2048,4096 -b 4096 -p 4096`?

*github-actions bot added labels: **Nvidia GPU** (issues specific to Nvidia GPUs), **ggml** (changes relating to the ggml tensor library for machine learning) on Aug 28, 2025*
**JohannesGaessler** (Collaborator) left a review comment:


Isn't this missing an extra case for ggml_cuda_can_fuse to handle the fused add?

**JohannesGaessler** (Collaborator):

Sorry, I accidentally pressed the wrong button and submitted the review before I was done.

**am17an** (Collaborator, Author) commented Aug 28, 2025

> Isn't this missing an extra case for `ggml_cuda_can_fuse` to handle the fused add?

No, because the fused adds don't go through `ggml_cuda_can_fuse`; for them I call `ggml_can_fuse` directly.

**am17an** (Collaborator, Author) commented Aug 28, 2025

> Could you also post results with `-ub 256,512,1024,2048,4096 -b 4096 -p 4096`?

Results

master

| model | size | params | backend | ngl | n_batch | n_ubatch | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 256 | pp4096 | 3581.38 ± 42.05 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 256 | tg128 | 177.66 ± 1.79 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 512 | pp4096 | 4412.67 ± 5.48 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 512 | tg128 | 178.43 ± 1.66 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 1024 | pp4096 | 4799.10 ± 3.90 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 1024 | tg128 | 178.87 ± 0.75 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 2048 | pp4096 | 4345.95 ± 3.55 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 2048 | tg128 | 179.05 ± 0.96 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 4096 | pp4096 | 3555.53 ± 5.27 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 4096 | tg128 | 177.43 ± 2.22 |

this PR

| model | size | params | backend | ngl | n_batch | n_ubatch | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 256 | pp4096 | 3582.54 ± 61.03 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 256 | tg128 | 181.03 ± 5.67 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 512 | pp4096 | 4433.36 ± 31.03 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 512 | tg128 | 184.74 ± 2.65 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 1024 | pp4096 | 4822.68 ± 3.78 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 1024 | tg128 | 180.21 ± 6.35 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 2048 | pp4096 | 4335.25 ± 2.80 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 2048 | tg128 | 181.33 ± 3.36 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 4096 | pp4096 | 3562.96 ± 4.04 |
| qwen3moe 30B.A3B Q4_K - Medium | 17.28 GiB | 30.53 B | CUDA | 99 | 4096 | 4096 | tg128 | 185.85 ± 0.86 |

*am17an force-pushed the rms_norm_fused_add branch from 61f2b2a to 34ee6e6 on August 28, 2025 13:44*
*am17an force-pushed the rms_norm_fused_add branch from 34ee6e6 to b64ba1c on August 28, 2025 13:56*
**JohannesGaessler** (Collaborator):

Performance changes
| GPU | Model | Microbatch size | Test | t/s 5a0e3ef | t/s b64ba1c | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| MI50 | gpt-oss 20B MXFP4 MoE | 1 | pp4096 | 23.28 | 23.26 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 2 | pp4096 | 33.03 | 33.08 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 4 | pp4096 | 57.07 | 57.12 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 8 | pp4096 | 99.42 | 99.44 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 16 | pp4096 | 148.96 | 149.06 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 32 | pp4096 | 203.37 | 203.42 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 64 | pp4096 | 222.38 | 222.66 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 128 | pp4096 | 280.49 | 280.75 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 256 | pp4096 | 298.24 | 297.85 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 512 | pp4096 | 280.57 | 280.42 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 1024 | pp4096 | 286.32 | 286.28 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 2048 | pp4096 | 260.40 | 260.47 | 1.00 |
| MI50 | gpt-oss 20B MXFP4 MoE | 4096 | pp4096 | 210.58 | 210.68 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 1 | pp4096 | 35.11 | 36.08 | 1.03 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 2 | pp4096 | 52.38 | 53.29 | 1.02 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 4 | pp4096 | 70.57 | 71.37 | 1.01 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 8 | pp4096 | 64.51 | 64.80 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 16 | pp4096 | 78.00 | 78.37 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 32 | pp4096 | 71.09 | 71.23 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 64 | pp4096 | 75.88 | 75.98 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 128 | pp4096 | 84.83 | 84.93 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 256 | pp4096 | 92.35 | 91.98 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 512 | pp4096 | 88.54 | 88.25 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 1024 | pp4096 | 94.21 | 94.51 | 1.00 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 2048 | pp4096 | 94.37 | 95.01 | 1.01 |
| MI50 | qwen3moe 30B.A3B Q3_K_S | 4096 | pp4096 | 94.18 | 94.51 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 1 | pp4096 | 77.26 | 76.74 | 0.99 |
| P40 | gpt-oss 20B MXFP4 MoE | 2 | pp4096 | 94.00 | 94.08 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 4 | pp4096 | 142.28 | 143.21 | 1.01 |
| P40 | gpt-oss 20B MXFP4 MoE | 8 | pp4096 | 206.17 | 209.82 | 1.02 |
| P40 | gpt-oss 20B MXFP4 MoE | 16 | pp4096 | 351.68 | 353.00 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 32 | pp4096 | 428.56 | 429.07 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 64 | pp4096 | 519.27 | 521.15 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 128 | pp4096 | 747.82 | 753.56 | 1.01 |
| P40 | gpt-oss 20B MXFP4 MoE | 256 | pp4096 | 973.86 | 983.94 | 1.01 |
| P40 | gpt-oss 20B MXFP4 MoE | 512 | pp4096 | 1140.90 | 1148.81 | 1.01 |
| P40 | gpt-oss 20B MXFP4 MoE | 1024 | pp4096 | 1273.51 | 1276.24 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 2048 | pp4096 | 1281.68 | 1278.85 | 1.00 |
| P40 | gpt-oss 20B MXFP4 MoE | 4096 | pp4096 | 1181.50 | 1185.88 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 1 | pp4096 | 40.65 | 41.30 | 1.02 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 2 | pp4096 | 60.10 | 60.94 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 4 | pp4096 | 85.96 | 86.76 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 8 | pp4096 | 121.43 | 122.27 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 16 | pp4096 | 229.90 | 232.13 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 32 | pp4096 | 299.91 | 297.29 | 0.99 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 64 | pp4096 | 310.55 | 310.65 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 128 | pp4096 | 441.50 | 441.33 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 256 | pp4096 | 585.32 | 588.97 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 512 | pp4096 | 709.18 | 715.48 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 1024 | pp4096 | 816.83 | 822.88 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 2048 | pp4096 | 870.88 | 880.18 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 4096 | pp4096 | 896.42 | 903.11 | 1.01 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 1 | pp4096 | 182.88 | 186.29 | 1.02 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 2 | pp4096 | 175.46 | 176.52 | 1.01 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 4 | pp4096 | 299.10 | 297.95 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 8 | pp4096 | 483.69 | 483.82 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 16 | pp4096 | 773.09 | 773.50 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 32 | pp4096 | 1194.09 | 1199.63 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 64 | pp4096 | 1634.92 | 1638.15 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 128 | pp4096 | 2082.23 | 2074.87 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 256 | pp4096 | 3115.25 | 3102.94 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 512 | pp4096 | 4102.75 | 4115.83 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 1024 | pp4096 | 4866.67 | 4889.53 | 1.00 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 2048 | pp4096 | 5280.09 | 5312.56 | 1.01 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 4096 | pp4096 | 5166.07 | 5200.66 | 1.01 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 1 | pp4096 | 132.80 | 137.94 | 1.04 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 2 | pp4096 | 122.56 | 126.83 | 1.03 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 4 | pp4096 | 210.64 | 215.54 | 1.02 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 8 | pp4096 | 341.99 | 348.88 | 1.02 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 16 | pp4096 | 544.15 | 552.65 | 1.02 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 32 | pp4096 | 820.64 | 829.34 | 1.01 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 64 | pp4096 | 1189.53 | 1194.48 | 1.00 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 128 | pp4096 | 1457.13 | 1470.68 | 1.01 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 256 | pp4096 | 2284.14 | 2304.66 | 1.01 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 512 | pp4096 | 3060.20 | 3100.22 | 1.01 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 1024 | pp4096 | 3736.49 | 3807.04 | 1.02 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 2048 | pp4096 | 4145.49 | 4273.44 | 1.03 |
| RTX 3090 | qwen3moe 30B.A3B Q4_0 | 4096 | pp4096 | 4282.79 | 4408.95 | 1.03 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 1 | pp4096 | 267.15 | 270.76 | 1.01 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 2 | pp4096 | 278.45 | 280.76 | 1.01 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 4 | pp4096 | 485.28 | 488.53 | 1.01 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 8 | pp4096 | 810.72 | 815.68 | 1.01 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 16 | pp4096 | 1336.37 | 1344.19 | 1.01 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 32 | pp4096 | 2268.63 | 2278.60 | 1.00 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 64 | pp4096 | 3358.71 | 3367.00 | 1.00 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 128 | pp4096 | 4548.22 | 4548.72 | 1.00 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 256 | pp4096 | 7055.43 | 7061.93 | 1.00 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 512 | pp4096 | 9850.41 | 9860.10 | 1.00 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 1024 | pp4096 | 11078.32 | 11094.10 | 1.00 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 2048 | pp4096 | 10649.30 | 10565.77 | 0.99 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 4096 | pp4096 | 8950.94 | 8984.82 | 1.00 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 1 | pp4096 | 194.24 | 205.87 | 1.06 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 2 | pp4096 | 202.61 | 211.62 | 1.04 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 4 | pp4096 | 352.72 | 365.39 | 1.04 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 8 | pp4096 | 580.35 | 597.34 | 1.03 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 16 | pp4096 | 936.35 | 959.90 | 1.03 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 32 | pp4096 | 1519.77 | 1556.32 | 1.02 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 64 | pp4096 | 2456.92 | 2508.44 | 1.02 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 128 | pp4096 | 3331.02 | 3388.49 | 1.02 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 256 | pp4096 | 5387.09 | 5474.49 | 1.02 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 512 | pp4096 | 7484.01 | 7585.34 | 1.01 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 1024 | pp4096 | 8771.85 | 8870.23 | 1.01 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 2048 | pp4096 | 8893.68 | 8864.01 | 1.00 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 4096 | pp4096 | 8147.55 | 8185.93 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 1 | pp4096 | 70.54 | 71.87 | 1.02 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 2 | pp4096 | 70.65 | 71.18 | 1.01 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 4 | pp4096 | 96.06 | 96.42 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 8 | pp4096 | 121.69 | 122.06 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 16 | pp4096 | 222.86 | 223.68 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 32 | pp4096 | 248.06 | 248.62 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 64 | pp4096 | 286.22 | 286.85 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 128 | pp4096 | 372.44 | 373.28 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 256 | pp4096 | 436.94 | 437.58 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 512 | pp4096 | 451.83 | 452.73 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 1024 | pp4096 | 480.73 | 481.87 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 2048 | pp4096 | 450.47 | 451.91 | 1.00 |
| RX 6800 | gpt-oss 20B MXFP4 MoE | 4096 | pp4096 | 373.07 | 373.09 | 1.00 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 1 | pp4096 | 34.53 | 36.33 | 1.05 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 2 | pp4096 | 45.90 | 47.34 | 1.03 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 4 | pp4096 | 56.33 | 57.42 | 1.02 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 8 | pp4096 | 73.48 | 74.49 | 1.01 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 16 | pp4096 | 92.12 | 93.06 | 1.01 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 32 | pp4096 | 110.67 | 111.46 | 1.01 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 64 | pp4096 | 116.16 | 116.67 | 1.00 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 128 | pp4096 | 154.57 | 155.28 | 1.00 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 256 | pp4096 | 179.42 | 180.10 | 1.00 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 512 | pp4096 | 187.46 | 188.22 | 1.00 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 1024 | pp4096 | 202.85 | 203.17 | 1.00 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 2048 | pp4096 | 201.44 | 202.10 | 1.00 |
| RX 6800 | qwen3moe 30B.A3B Q3_K_S | 4096 | pp4096 | 199.23 | 200.27 | 1.01 |

*am17an merged commit 009b709 into ggml-org:master on Aug 29, 2025 (48 checks passed)*
*am17an deleted the rms_norm_fused_add branch on August 29, 2025 03:36*
**CISC** (Collaborator) commented Aug 29, 2025

There seems to be another issue: this PR introduces a segfault in `test-backend-ops`:
https://github.com/ggml-org/ci/blob/results/llama.cpp/00/9b709d6efd24820ac67765ed339a72dc797814/ggml-4-x86-cuda-v100/stdall#L15422

**am17an** (Collaborator, Author) commented Aug 29, 2025

> There seems to be another issue, this PR introduces a segfault in test-backend-ops: https://github.com/ggml-org/ci/blob/results/llama.cpp/00/9b709d6efd24820ac67765ed339a72dc797814/ggml-4-x86-cuda-v100/stdall#L15422

I thought this test would be run by the CI; I was wrong.

**CISC** (Collaborator) commented Aug 29, 2025

> > There seems to be another issue, this PR introduces a segfault in test-backend-ops: https://github.com/ggml-org/ci/blob/results/llama.cpp/00/9b709d6efd24820ac67765ed339a72dc797814/ggml-4-x86-cuda-v100/stdall#L15422
>
> I thought this test would be run with the CI, I was wrong

You have to add `ggml-ci` to your commit comment to do the full tests.

**CISC** (Collaborator) commented Aug 29, 2025

> You have to add `ggml-ci` to your commit comment to do the full tests.

I think it only works in the comment, not the title. :)

**am17an** (Collaborator, Author) commented Aug 29, 2025

> > You have to add `ggml-ci` to your commit comment to do the full tests.
>
> I think it only works in the comment, not the title. :)

Ok, thanks! Trying it in #15660. I think it would be good to auto-trigger `test-backend-ops`, at least for the backends affected by a PR.

**CISC** (Collaborator) commented Aug 29, 2025

> > > You have to add `ggml-ci` to your commit comment to do the full tests.
> >
> > I think it only works in the comment, not the title. :)
>
> Ok thanks! Trying it in #15660. I think it might be good to auto trigger test-backend-ops, at least for the affected backends in a PR

I saw, but it only works in the comment/description, so it wasn't triggered.

Edit: And yes, it would be useful to somehow auto-trigger ggml-ci for the backend in question, but I don't think that's easily doable.
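Concretely, the marker just needs to appear in the commit message body rather than the title. A minimal local sketch in a throwaway repo (the commit title and identities here are made up):

```shell
set -e
repo=$(mktemp -d)
cd "$repo"
git init -q
git config user.email dev@example.com
git config user.name dev
# "ggml-ci" goes in the message body (the second -m), not the title
git commit --allow-empty -q -m "CUDA: fuse add with rms_norm" -m "ggml-ci"
git log -1 --format=%b   # prints the body: ggml-ci
```

On the real PR branch this would be an amend and force-push rather than an empty commit.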

*qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request on Aug 30, 2025:*
* CUDA: fused add with rms_norm_mul

* Non-broadcast fuse works

* Add fused adds

* format

* Remove n_fuse from template params

* Address review comments

* Move template inside binbcast