Forráskód Böngészése

Merge branch 'x86-jit'

Daniel Borkmann says:

====================
Couple of minor improvements to the x64 JIT I had still around from
pre merge window in order to shrink the image size further. Added
test cases for kselftests too as well as running Cilium workloads on
them w/o issues.
====================

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Alexei Starovoitov 7 éve
szülő
commit
7d72637eb3
2 módosított fájl, 221 hozzáadás és 87 törlés
  1. 132 87
      arch/x86/net/bpf_jit_comp.c
  2. 89 0
      tools/testing/selftests/bpf/test_verifier.c

+ 132 - 87
arch/x86/net/bpf_jit_comp.c

@@ -60,7 +60,12 @@ static bool is_imm8(int value)
 
 static bool is_simm32(s64 value)
 {
-	return value == (s64) (s32) value;
+	return value == (s64)(s32)value;
+}
+
+static bool is_uimm32(u64 value)
+{
+	return value == (u64)(u32)value;
 }
 
 /* mov dst, src */
@@ -211,7 +216,7 @@ struct jit_context {
 /* emit x64 prologue code for BPF program and check it's size.
  * bpf_tail_call helper will skip it while jumping into another program
  */
-static void emit_prologue(u8 **pprog, u32 stack_depth)
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
 {
 	u8 *prog = *pprog;
 	int cnt = 0;
@@ -246,18 +251,21 @@ static void emit_prologue(u8 **pprog, u32 stack_depth)
 	/* mov qword ptr [rbp+24],r15 */
 	EMIT4(0x4C, 0x89, 0x7D, 24);
 
-	/* Clear the tail call counter (tail_call_cnt): for eBPF tail calls
-	 * we need to reset the counter to 0. It's done in two instructions,
-	 * resetting rax register to 0 (xor on eax gets 0 extended), and
-	 * moving it to the counter location.
-	 */
+	if (!ebpf_from_cbpf) {
+		/* Clear the tail call counter (tail_call_cnt): for eBPF tail
+		 * calls we need to reset the counter to 0. It's done in two
+		 * instructions, resetting rax register to 0, and moving it
+		 * to the counter location.
+		 */
+
+		/* xor eax, eax */
+		EMIT2(0x31, 0xc0);
+		/* mov qword ptr [rbp+32], rax */
+		EMIT4(0x48, 0x89, 0x45, 32);
 
-	/* xor eax, eax */
-	EMIT2(0x31, 0xc0);
-	/* mov qword ptr [rbp+32], rax */
-	EMIT4(0x48, 0x89, 0x45, 32);
+		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
+	}
 
-	BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
 	*pprog = prog;
 }
 
@@ -355,6 +363,86 @@ static void emit_load_skb_data_hlen(u8 **pprog)
 	*pprog = prog;
 }
 
+static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
+			   u32 dst_reg, const u32 imm32)
+{
+	u8 *prog = *pprog;
+	u8 b1, b2, b3;
+	int cnt = 0;
+
+	/* optimization: if imm32 is positive, use 'mov %eax, imm32'
+	 * (which zero-extends imm32) to save 2 bytes.
+	 */
+	if (sign_propagate && (s32)imm32 < 0) {
+		/* 'mov %rax, imm32' sign extends imm32 */
+		b1 = add_1mod(0x48, dst_reg);
+		b2 = 0xC7;
+		b3 = 0xC0;
+		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
+		goto done;
+	}
+
+	/* optimization: if imm32 is zero, use 'xor %eax, %eax'
+	 * to save 3 bytes.
+	 */
+	if (imm32 == 0) {
+		if (is_ereg(dst_reg))
+			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
+		b2 = 0x31; /* xor */
+		b3 = 0xC0;
+		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
+		goto done;
+	}
+
+	/* mov %eax, imm32 */
+	if (is_ereg(dst_reg))
+		EMIT1(add_1mod(0x40, dst_reg));
+	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+done:
+	*pprog = prog;
+}
+
+static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
+			   const u32 imm32_hi, const u32 imm32_lo)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
+		/* For emitting plain u32, where sign bit must not be
+		 * propagated LLVM tends to load imm64 over mov32
+		 * directly, so save couple of bytes by just doing
+		 * 'mov %eax, imm32' instead.
+		 */
+		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
+	} else {
+		/* movabsq %rax, imm64 */
+		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
+		EMIT(imm32_lo, 4);
+		EMIT(imm32_hi, 4);
+	}
+
+	*pprog = prog;
+}
+
+static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	if (is64) {
+		/* mov dst, src */
+		EMIT_mov(dst_reg, src_reg);
+	} else {
+		/* mov32 dst, src */
+		if (is_ereg(dst_reg) || is_ereg(src_reg))
+			EMIT1(add_2mod(0x40, dst_reg, src_reg));
+		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+	}
+
+	*pprog = prog;
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		  int oldproglen, struct jit_context *ctx)
 {
@@ -368,7 +456,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 	int proglen = 0;
 	u8 *prog = temp;
 
-	emit_prologue(&prog, bpf_prog->aux->stack_depth);
+	emit_prologue(&prog, bpf_prog->aux->stack_depth,
+		      bpf_prog_was_classic(bpf_prog));
 
 	if (seen_ld_abs)
 		emit_load_skb_data_hlen(&prog);
@@ -377,7 +466,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		const s32 imm32 = insn->imm;
 		u32 dst_reg = insn->dst_reg;
 		u32 src_reg = insn->src_reg;
-		u8 b1 = 0, b2 = 0, b3 = 0;
+		u8 b2 = 0, b3 = 0;
 		s64 jmp_offset;
 		u8 jmp_cond;
 		bool reload_skb_data;
@@ -413,16 +502,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
 			break;
 
-			/* mov dst, src */
 		case BPF_ALU64 | BPF_MOV | BPF_X:
-			EMIT_mov(dst_reg, src_reg);
-			break;
-
-			/* mov32 dst, src */
 		case BPF_ALU | BPF_MOV | BPF_X:
-			if (is_ereg(dst_reg) || is_ereg(src_reg))
-				EMIT1(add_2mod(0x40, dst_reg, src_reg));
-			EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+			emit_mov_reg(&prog,
+				     BPF_CLASS(insn->code) == BPF_ALU64,
+				     dst_reg, src_reg);
 			break;
 
 			/* neg dst */
@@ -485,58 +569,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 			break;
 
 		case BPF_ALU64 | BPF_MOV | BPF_K:
-			/* optimization: if imm32 is positive,
-			 * use 'mov eax, imm32' (which zero-extends imm32)
-			 * to save 2 bytes
-			 */
-			if (imm32 < 0) {
-				/* 'mov rax, imm32' sign extends imm32 */
-				b1 = add_1mod(0x48, dst_reg);
-				b2 = 0xC7;
-				b3 = 0xC0;
-				EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
-				break;
-			}
-
 		case BPF_ALU | BPF_MOV | BPF_K:
-			/* optimization: if imm32 is zero, use 'xor <dst>,<dst>'
-			 * to save 3 bytes.
-			 */
-			if (imm32 == 0) {
-				if (is_ereg(dst_reg))
-					EMIT1(add_2mod(0x40, dst_reg, dst_reg));
-				b2 = 0x31; /* xor */
-				b3 = 0xC0;
-				EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
-				break;
-			}
-
-			/* mov %eax, imm32 */
-			if (is_ereg(dst_reg))
-				EMIT1(add_1mod(0x40, dst_reg));
-			EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
+				       dst_reg, imm32);
 			break;
 
 		case BPF_LD | BPF_IMM | BPF_DW:
-			/* optimization: if imm64 is zero, use 'xor <dst>,<dst>'
-			 * to save 7 bytes.
-			 */
-			if (insn[0].imm == 0 && insn[1].imm == 0) {
-				b1 = add_2mod(0x48, dst_reg, dst_reg);
-				b2 = 0x31; /* xor */
-				b3 = 0xC0;
-				EMIT3(b1, b2, add_2reg(b3, dst_reg, dst_reg));
-
-				insn++;
-				i++;
-				break;
-			}
-
-			/* movabsq %rax, imm64 */
-			EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
-			EMIT(insn[0].imm, 4);
-			EMIT(insn[1].imm, 4);
-
+			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
 			insn++;
 			i++;
 			break;
@@ -593,36 +632,38 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		case BPF_ALU | BPF_MUL | BPF_X:
 		case BPF_ALU64 | BPF_MUL | BPF_K:
 		case BPF_ALU64 | BPF_MUL | BPF_X:
-			EMIT1(0x50); /* push rax */
-			EMIT1(0x52); /* push rdx */
+		{
+			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
+
+			if (dst_reg != BPF_REG_0)
+				EMIT1(0x50); /* push rax */
+			if (dst_reg != BPF_REG_3)
+				EMIT1(0x52); /* push rdx */
 
 			/* mov r11, dst_reg */
 			EMIT_mov(AUX_REG, dst_reg);
 
 			if (BPF_SRC(insn->code) == BPF_X)
-				/* mov rax, src_reg */
-				EMIT_mov(BPF_REG_0, src_reg);
+				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
 			else
-				/* mov rax, imm32 */
-				EMIT3_off32(0x48, 0xC7, 0xC0, imm32);
+				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);
 
-			if (BPF_CLASS(insn->code) == BPF_ALU64)
+			if (is64)
 				EMIT1(add_1mod(0x48, AUX_REG));
 			else if (is_ereg(AUX_REG))
 				EMIT1(add_1mod(0x40, AUX_REG));
 			/* mul(q) r11 */
 			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
 
-			/* mov r11, rax */
-			EMIT_mov(AUX_REG, BPF_REG_0);
-
-			EMIT1(0x5A); /* pop rdx */
-			EMIT1(0x58); /* pop rax */
-
-			/* mov dst_reg, r11 */
-			EMIT_mov(dst_reg, AUX_REG);
+			if (dst_reg != BPF_REG_3)
+				EMIT1(0x5A); /* pop rdx */
+			if (dst_reg != BPF_REG_0) {
+				/* mov dst_reg, rax */
+				EMIT_mov(dst_reg, BPF_REG_0);
+				EMIT1(0x58); /* pop rax */
+			}
 			break;
-
+		}
 			/* shifts */
 		case BPF_ALU | BPF_LSH | BPF_K:
 		case BPF_ALU | BPF_RSH | BPF_K:
@@ -640,7 +681,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 			case BPF_RSH: b3 = 0xE8; break;
 			case BPF_ARSH: b3 = 0xF8; break;
 			}
-			EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
+
+			if (imm32 == 1)
+				EMIT2(0xD1, add_1reg(b3, dst_reg));
+			else
+				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
 			break;
 
 		case BPF_ALU | BPF_LSH | BPF_X:

+ 89 - 0
tools/testing/selftests/bpf/test_verifier.c

@@ -11140,6 +11140,95 @@ static struct bpf_test tests[] = {
 		.result = REJECT,
 		.prog_type = BPF_PROG_TYPE_TRACEPOINT,
 	},
+	{
+		"jit: lsh, rsh, arsh by 1",
+		.insns = {
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_MOV64_IMM(BPF_REG_1, 0xff),
+			BPF_ALU64_IMM(BPF_LSH, BPF_REG_1, 1),
+			BPF_ALU32_IMM(BPF_LSH, BPF_REG_1, 1),
+			BPF_JMP_IMM(BPF_JEQ, BPF_REG_1, 0x3fc, 1),
+			BPF_EXIT_INSN(),
+			BPF_ALU64_IMM(BPF_RSH, BPF_REG_1, 1),
+			BPF_ALU32_IMM(BPF_RSH, BPF_REG_1, 1),
+			BPF_JMP_IMM(BPF_JEQ, BPF_REG_1, 0xff, 1),
+			BPF_EXIT_INSN(),
+			BPF_ALU64_IMM(BPF_ARSH, BPF_REG_1, 1),
+			BPF_JMP_IMM(BPF_JEQ, BPF_REG_1, 0x7f, 1),
+			BPF_EXIT_INSN(),
+			BPF_MOV64_IMM(BPF_REG_0, 2),
+			BPF_EXIT_INSN(),
+		},
+		.result = ACCEPT,
+		.retval = 2,
+	},
+	{
+		"jit: mov32 for ldimm64, 1",
+		.insns = {
+			BPF_MOV64_IMM(BPF_REG_0, 2),
+			BPF_LD_IMM64(BPF_REG_1, 0xfeffffffffffffffULL),
+			BPF_ALU64_IMM(BPF_RSH, BPF_REG_1, 32),
+			BPF_LD_IMM64(BPF_REG_2, 0xfeffffffULL),
+			BPF_JMP_REG(BPF_JEQ, BPF_REG_1, BPF_REG_2, 1),
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_EXIT_INSN(),
+		},
+		.result = ACCEPT,
+		.retval = 2,
+	},
+	{
+		"jit: mov32 for ldimm64, 2",
+		.insns = {
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_LD_IMM64(BPF_REG_1, 0x1ffffffffULL),
+			BPF_LD_IMM64(BPF_REG_2, 0xffffffffULL),
+			BPF_JMP_REG(BPF_JEQ, BPF_REG_1, BPF_REG_2, 1),
+			BPF_MOV64_IMM(BPF_REG_0, 2),
+			BPF_EXIT_INSN(),
+		},
+		.result = ACCEPT,
+		.retval = 2,
+	},
+	{
+		"jit: various mul tests",
+		.insns = {
+			BPF_LD_IMM64(BPF_REG_2, 0xeeff0d413122ULL),
+			BPF_LD_IMM64(BPF_REG_0, 0xfefefeULL),
+			BPF_LD_IMM64(BPF_REG_1, 0xefefefULL),
+			BPF_ALU64_REG(BPF_MUL, BPF_REG_0, BPF_REG_1),
+			BPF_JMP_REG(BPF_JEQ, BPF_REG_0, BPF_REG_2, 2),
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_EXIT_INSN(),
+			BPF_LD_IMM64(BPF_REG_3, 0xfefefeULL),
+			BPF_ALU64_REG(BPF_MUL, BPF_REG_3, BPF_REG_1),
+			BPF_JMP_REG(BPF_JEQ, BPF_REG_3, BPF_REG_2, 2),
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_EXIT_INSN(),
+			BPF_MOV32_REG(BPF_REG_2, BPF_REG_2),
+			BPF_LD_IMM64(BPF_REG_0, 0xfefefeULL),
+			BPF_ALU32_REG(BPF_MUL, BPF_REG_0, BPF_REG_1),
+			BPF_JMP_REG(BPF_JEQ, BPF_REG_0, BPF_REG_2, 2),
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_EXIT_INSN(),
+			BPF_LD_IMM64(BPF_REG_3, 0xfefefeULL),
+			BPF_ALU32_REG(BPF_MUL, BPF_REG_3, BPF_REG_1),
+			BPF_JMP_REG(BPF_JEQ, BPF_REG_3, BPF_REG_2, 2),
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_EXIT_INSN(),
+			BPF_LD_IMM64(BPF_REG_0, 0x952a7bbcULL),
+			BPF_LD_IMM64(BPF_REG_1, 0xfefefeULL),
+			BPF_LD_IMM64(BPF_REG_2, 0xeeff0d413122ULL),
+			BPF_ALU32_REG(BPF_MUL, BPF_REG_2, BPF_REG_1),
+			BPF_JMP_REG(BPF_JEQ, BPF_REG_2, BPF_REG_0, 2),
+			BPF_MOV64_IMM(BPF_REG_0, 1),
+			BPF_EXIT_INSN(),
+			BPF_MOV64_IMM(BPF_REG_0, 2),
+			BPF_EXIT_INSN(),
+		},
+		.result = ACCEPT,
+		.retval = 2,
+	},
+
 };
 
 static int probe_filter_length(const struct bpf_insn *fp)