Skip to content

Commit b52f00e

Browse files
Alexei Starovoitov authored and davem330 committed
x86: bpf_jit: implement bpf_tail_call() helper
bpf_tail_call() arguments: ctx - context pointer; jmp_table - one of the BPF_MAP_TYPE_PROG_ARRAY maps, used as the jump table; index - index in the jump table. In this implementation the x64 JIT bypasses stack unwind and jumps into the callee program after its prologue, so the callee program reuses the same stack. The logic can be roughly expressed in C like: u32 tail_call_cnt; void *jumptable[2] = { &&label1, &&label2 }; int bpf_prog1(void *ctx) { label1: ... } int bpf_prog2(void *ctx) { label2: ... } int bpf_prog1(void *ctx) { ... if (tail_call_cnt++ < MAX_TAIL_CALL_CNT) goto *jumptable[index]; ... and pass my 'ctx' to callee ... ... fall through if no entry in jumptable ... } Note that 'skip current program epilogue and next program prologue' is an optimization. Other JITs don't have to do it the same way. From a safety point of view it's valid as well, since programs always initialize the stack before use, so any residue in the stack left by the current program is not going to be read. The same verifier checks are done for the calls from the kernel into all bpf programs. Signed-off-by: Alexei Starovoitov <ast@plumgrid.com> Signed-off-by: David S. Miller <davem@davemloft.net>
1 parent 04fd61a commit b52f00e

File tree

1 file changed

+126
-24
lines changed

1 file changed

+126
-24
lines changed

arch/x86/net/bpf_jit_comp.c

Lines changed: 126 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <linux/filter.h>
1313
#include <linux/if_vlan.h>
1414
#include <asm/cacheflush.h>
15+
#include <linux/bpf.h>
1516

1617
int bpf_jit_enable __read_mostly;
1718

@@ -37,7 +38,8 @@ static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
3738
return ptr + len;
3839
}
3940

40-
#define EMIT(bytes, len) do { prog = emit_code(prog, bytes, len); } while (0)
41+
#define EMIT(bytes, len) \
42+
do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)
4143

4244
#define EMIT1(b1) EMIT(b1, 1)
4345
#define EMIT2(b1, b2) EMIT((b1) + ((b2) << 8), 2)
@@ -186,31 +188,31 @@ struct jit_context {
186188
#define BPF_MAX_INSN_SIZE 128
187189
#define BPF_INSN_SAFETY 64
188190

189-
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
190-
int oldproglen, struct jit_context *ctx)
191+
#define STACKSIZE \
192+
(MAX_BPF_STACK + \
193+
32 /* space for rbx, r13, r14, r15 */ + \
194+
8 /* space for skb_copy_bits() buffer */)
195+
196+
#define PROLOGUE_SIZE 51
197+
198+
/* emit x64 prologue code for BPF program and check its size.
199+
* bpf_tail_call helper will skip it while jumping into another program
200+
*/
201+
static void emit_prologue(u8 **pprog)
191202
{
192-
struct bpf_insn *insn = bpf_prog->insnsi;
193-
int insn_cnt = bpf_prog->len;
194-
bool seen_ld_abs = ctx->seen_ld_abs | (oldproglen == 0);
195-
bool seen_exit = false;
196-
u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
197-
int i;
198-
int proglen = 0;
199-
u8 *prog = temp;
200-
int stacksize = MAX_BPF_STACK +
201-
32 /* space for rbx, r13, r14, r15 */ +
202-
8 /* space for skb_copy_bits() buffer */;
203+
u8 *prog = *pprog;
204+
int cnt = 0;
203205

204206
EMIT1(0x55); /* push rbp */
205207
EMIT3(0x48, 0x89, 0xE5); /* mov rbp,rsp */
206208

207-
/* sub rsp, stacksize */
208-
EMIT3_off32(0x48, 0x81, 0xEC, stacksize);
209+
/* sub rsp, STACKSIZE */
210+
EMIT3_off32(0x48, 0x81, 0xEC, STACKSIZE);
209211

210212
/* all classic BPF filters use R6(rbx) save it */
211213

212214
/* mov qword ptr [rbp-X],rbx */
213-
EMIT3_off32(0x48, 0x89, 0x9D, -stacksize);
215+
EMIT3_off32(0x48, 0x89, 0x9D, -STACKSIZE);
214216

215217
/* bpf_convert_filter() maps classic BPF register X to R7 and uses R8
216218
* as temporary, so all tcpdump filters need to spill/fill R7(r13) and
@@ -221,16 +223,112 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
221223
*/
222224

223225
/* mov qword ptr [rbp-X],r13 */
224-
EMIT3_off32(0x4C, 0x89, 0xAD, -stacksize + 8);
226+
EMIT3_off32(0x4C, 0x89, 0xAD, -STACKSIZE + 8);
225227
/* mov qword ptr [rbp-X],r14 */
226-
EMIT3_off32(0x4C, 0x89, 0xB5, -stacksize + 16);
228+
EMIT3_off32(0x4C, 0x89, 0xB5, -STACKSIZE + 16);
227229
/* mov qword ptr [rbp-X],r15 */
228-
EMIT3_off32(0x4C, 0x89, 0xBD, -stacksize + 24);
230+
EMIT3_off32(0x4C, 0x89, 0xBD, -STACKSIZE + 24);
229231

230232
/* clear A and X registers */
231233
EMIT2(0x31, 0xc0); /* xor eax, eax */
232234
EMIT3(0x4D, 0x31, 0xED); /* xor r13, r13 */
233235

236+
/* clear tail_cnt: mov qword ptr [rbp-X], rax */
237+
EMIT3_off32(0x48, 0x89, 0x85, -STACKSIZE + 32);
238+
239+
BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
240+
*pprog = prog;
241+
}
242+
243+
/* generate the following code:
244+
* ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
245+
* if (index >= array->map.max_entries)
246+
* goto out;
247+
* if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
248+
* goto out;
249+
* prog = array->prog[index];
250+
* if (prog == NULL)
251+
* goto out;
252+
* goto *(prog->bpf_func + prologue_size);
253+
* out:
254+
*/
255+
static void emit_bpf_tail_call(u8 **pprog)
256+
{
257+
u8 *prog = *pprog;
258+
int label1, label2, label3;
259+
int cnt = 0;
260+
261+
/* rdi - pointer to ctx
262+
* rsi - pointer to bpf_array
263+
* rdx - index in bpf_array
264+
*/
265+
266+
/* if (index >= array->map.max_entries)
267+
* goto out;
268+
*/
269+
EMIT4(0x48, 0x8B, 0x46, /* mov rax, qword ptr [rsi + 16] */
270+
offsetof(struct bpf_array, map.max_entries));
271+
EMIT3(0x48, 0x39, 0xD0); /* cmp rax, rdx */
272+
#define OFFSET1 44 /* number of bytes to jump */
273+
EMIT2(X86_JBE, OFFSET1); /* jbe out */
274+
label1 = cnt;
275+
276+
/* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
277+
* goto out;
278+
*/
279+
EMIT2_off32(0x8B, 0x85, -STACKSIZE + 36); /* mov eax, dword ptr [rbp - 516] */
280+
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
281+
#define OFFSET2 33
282+
EMIT2(X86_JA, OFFSET2); /* ja out */
283+
label2 = cnt;
284+
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
285+
EMIT2_off32(0x89, 0x85, -STACKSIZE + 36); /* mov dword ptr [rbp - 516], eax */
286+
287+
/* prog = array->prog[index]; */
288+
EMIT4(0x48, 0x8D, 0x44, 0xD6); /* lea rax, [rsi + rdx * 8 + 0x50] */
289+
EMIT1(offsetof(struct bpf_array, prog));
290+
EMIT3(0x48, 0x8B, 0x00); /* mov rax, qword ptr [rax] */
291+
292+
/* if (prog == NULL)
293+
* goto out;
294+
*/
295+
EMIT4(0x48, 0x83, 0xF8, 0x00); /* cmp rax, 0 */
296+
#define OFFSET3 10
297+
EMIT2(X86_JE, OFFSET3); /* je out */
298+
label3 = cnt;
299+
300+
/* goto *(prog->bpf_func + prologue_size); */
301+
EMIT4(0x48, 0x8B, 0x40, /* mov rax, qword ptr [rax + 32] */
302+
offsetof(struct bpf_prog, bpf_func));
303+
EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE); /* add rax, prologue_size */
304+
305+
/* now we're ready to jump into next BPF program
306+
* rdi == ctx (1st arg)
307+
* rax == prog->bpf_func + prologue_size
308+
*/
309+
EMIT2(0xFF, 0xE0); /* jmp rax */
310+
311+
/* out: */
312+
BUILD_BUG_ON(cnt - label1 != OFFSET1);
313+
BUILD_BUG_ON(cnt - label2 != OFFSET2);
314+
BUILD_BUG_ON(cnt - label3 != OFFSET3);
315+
*pprog = prog;
316+
}
317+
318+
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
319+
int oldproglen, struct jit_context *ctx)
320+
{
321+
struct bpf_insn *insn = bpf_prog->insnsi;
322+
int insn_cnt = bpf_prog->len;
323+
bool seen_ld_abs = ctx->seen_ld_abs | (oldproglen == 0);
324+
bool seen_exit = false;
325+
u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
326+
int i, cnt = 0;
327+
int proglen = 0;
328+
u8 *prog = temp;
329+
330+
emit_prologue(&prog);
331+
234332
if (seen_ld_abs) {
235333
/* r9d : skb->len - skb->data_len (headlen)
236334
* r10 : skb->data
@@ -739,6 +837,10 @@ xadd: if (is_imm8(insn->off))
739837
}
740838
break;
741839

840+
case BPF_JMP | BPF_CALL | BPF_X:
841+
emit_bpf_tail_call(&prog);
842+
break;
843+
742844
/* cond jump */
743845
case BPF_JMP | BPF_JEQ | BPF_X:
744846
case BPF_JMP | BPF_JNE | BPF_X:
@@ -891,13 +993,13 @@ xadd: if (is_imm8(insn->off))
891993
/* update cleanup_addr */
892994
ctx->cleanup_addr = proglen;
893995
/* mov rbx, qword ptr [rbp-X] */
894-
EMIT3_off32(0x48, 0x8B, 0x9D, -stacksize);
996+
EMIT3_off32(0x48, 0x8B, 0x9D, -STACKSIZE);
895997
/* mov r13, qword ptr [rbp-X] */
896-
EMIT3_off32(0x4C, 0x8B, 0xAD, -stacksize + 8);
998+
EMIT3_off32(0x4C, 0x8B, 0xAD, -STACKSIZE + 8);
897999
/* mov r14, qword ptr [rbp-X] */
898-
EMIT3_off32(0x4C, 0x8B, 0xB5, -stacksize + 16);
1000+
EMIT3_off32(0x4C, 0x8B, 0xB5, -STACKSIZE + 16);
8991001
/* mov r15, qword ptr [rbp-X] */
900-
EMIT3_off32(0x4C, 0x8B, 0xBD, -stacksize + 24);
1002+
EMIT3_off32(0x4C, 0x8B, 0xBD, -STACKSIZE + 24);
9011003

9021004
EMIT1(0xC9); /* leave */
9031005
EMIT1(0xC3); /* ret */

0 commit comments

Comments
 (0)