
Merge branch 'bpf-more-sock_ops-callbacks'

Lawrence Brakmo says:

====================
This patchset adds support for:

- direct R or R/W access to many tcp_sock fields
- passing up to 4 arguments to sock_ops BPF functions
- tcp_sock field bpf_sock_ops_cb_flags for controlling callbacks
- optionally calling sock_ops BPF program when RTO fires
- optionally calling sock_ops BPF program when packet is retransmitted
- optionally calling sock_ops BPF program when TCP state changes
- access to tclass and sk_txhash
- new selftest

v2: Fixed commit message 0/11. The commit is for "bpf-next" but the patch
    below used "bpf", so Patchwork didn't work correctly.
v3: Cleaned RTO callback as per Yuchung's comment
    Added BPF enum for TCP states as per Alexei's comment
v4: Fixed compile warnings related to detecting changes between TCP
    internal states and the BPF defined states.
v5: Fixed comment issues in some selftest files
    Fixed access issue with u64 fields in bpf_sock_ops struct
v6: Made fixes based on comments from Eric Dumazet:
    The field bpf_sock_ops_cb_flags was added in a hole on 64-bit kernels
    Field bpf_sock_ops_cb_flags is now set through a helper function
    which returns an error when a BPF program tries to set bits for
    callbacks that are not supported in the current kernel.
    Added a comment indicating that when adding fields to bpf_sock_ops_kern
    they should be added before the field named "temp" if they need to be
    cleared before calling the BPF function.
v7: Enforced that the fields "op" and "replylong[1] .. replylong[3]" are not
    writable, based on comments from Eric Dumazet and Alexei Starovoitov.
    Filled 32 bit hole in bpf_sock_ops struct with sk_txhash based on
    comments from Daniel Borkmann.
    Removed unused functions (tcp_call_bpf_1arg, tcp_call_bpf_4arg) based
    on comments from Daniel Borkmann.
v8: Add commit message 00/12
    Add Acked-by as appropriate
v9: Moved the bug fix to the front of the patchset
    Changed RETRANS_CB so it is always called (before it was only called if
    the retransmit succeeded). It is now called with an extra argument, the
    return value of tcp_transmit_skb (0 => success). Based on comments
    from Yuchung Cheng.
    Added support for reading 2 new fields, sacked_out and lost_out, based on
    comments from Yuchung Cheng.
v10: Moved the callback flags from include/uapi/linux/tcp.h to
     include/uapi/linux/bpf.h
     Cleaned up the test in selftest. Added a timeout so it always completes,
     even if the client is not communicating with the server. Made it faster
     by removing the sleeps. Made sure it works even when called back-to-back
     20 times.

Consists of the following patches:
[PATCH bpf-next v10 01/12] bpf: Only reply field should be writeable
[PATCH bpf-next v10 02/12] bpf: Make SOCK_OPS_GET_TCP size
[PATCH bpf-next v10 03/12] bpf: Make SOCK_OPS_GET_TCP struct
[PATCH bpf-next v10 04/12] bpf: Add write access to tcp_sock and sock
[PATCH bpf-next v10 05/12] bpf: Support passing args to sock_ops bpf
[PATCH bpf-next v10 06/12] bpf: Adds field bpf_sock_ops_cb_flags to
[PATCH bpf-next v10 07/12] bpf: Add sock_ops RTO callback
[PATCH bpf-next v10 08/12] bpf: Add support for reading sk_state and
[PATCH bpf-next v10 09/12] bpf: Add sock_ops R/W access to tclass
[PATCH bpf-next v10 10/12] bpf: Add BPF_SOCK_OPS_RETRANS_CB
[PATCH bpf-next v10 11/12] bpf: Add BPF_SOCK_OPS_STATE_CB
[PATCH bpf-next v10 12/12] bpf: add selftest for tcpbpf
====================

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Alexei Starovoitov, 7 years ago
commit 82f1e0f3ac
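
For illustration only, before the diffs: a minimal sock_ops program sketch
showing the user-facing pieces this series adds (requesting callbacks, the
writable fields, and the reply convention). It is not part of the patchset;
the program name and the sk_txhash value are made up, and the helper
declarations are assumed to come from the updated
tools/testing/selftests/bpf/bpf_helpers.h below.

#include <linux/bpf.h>
#include "bpf_helpers.h"

SEC("sockops")
int sockops_sketch(struct bpf_sock_ops *skops)
{
	int rv = -1;

	switch (skops->op) {
	case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
		/* Ask for the new RTO, retransmit and state-change
		 * callbacks. A positive return value carries the requested
		 * bits that the running kernel does not support.
		 */
		rv = bpf_sock_ops_cb_flags_set(skops,
					       BPF_SOCK_OPS_ALL_CB_FLAGS);
		/* reply and sk_txhash are the only writable fields */
		skops->sk_txhash = 0x12345;
		break;
	}
	skops->reply = rv;
	return 1;
}
char _license[] SEC("license") = "GPL";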

+ 10 - 0
include/linux/filter.h

@@ -1003,10 +1003,20 @@ struct bpf_sock_ops_kern {
 	struct	sock *sk;
 	u32	op;
 	union {
+		u32 args[4];
 		u32 reply;
 		u32 replylong[4];
 	};
 	u32	is_fullsock;
+	u64	temp;			/* temp and everything after is not
+					 * initialized to 0 before calling
+					 * the BPF program. New fields that
+					 * should be initialized to 0 should
+					 * be inserted before temp.
+					 * temp is scratch storage used by
+					 * sock_ops_convert_ctx_access
+					 * as temporary storage of a register.
+					 */
 };
 
 #endif /* __LINUX_FILTER_H__ */

+ 11 - 0
include/linux/tcp.h

@@ -335,6 +335,17 @@ struct tcp_sock {
 
 	int			linger2;
 
+
+/* Sock_ops bpf program related variables */
+#ifdef CONFIG_BPF
+	u8	bpf_sock_ops_cb_flags;  /* Control calling BPF programs
+					 * values defined in uapi/linux/tcp.h
+					 */
+#define BPF_SOCK_OPS_TEST_FLAG(TP, ARG) (TP->bpf_sock_ops_cb_flags & ARG)
+#else
+#define BPF_SOCK_OPS_TEST_FLAG(TP, ARG) 0
+#endif
+
 /* Receiver side RTT estimation */
 	struct {
 		u32	rtt_us;

+ 36 - 6
include/net/tcp.h

@@ -2006,12 +2006,12 @@ void tcp_cleanup_ulp(struct sock *sk);
  * program loaded).
  */
 #ifdef CONFIG_BPF
-static inline int tcp_call_bpf(struct sock *sk, int op)
+static inline int tcp_call_bpf(struct sock *sk, int op, u32 nargs, u32 *args)
 {
 	struct bpf_sock_ops_kern sock_ops;
 	int ret;
 
-	memset(&sock_ops, 0, sizeof(sock_ops));
+	memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp));
 	if (sk_fullsock(sk)) {
 		sock_ops.is_fullsock = 1;
 		sock_owned_by_me(sk);
@@ -2019,6 +2019,8 @@ static inline int tcp_call_bpf(struct sock *sk, int op)
 
 	sock_ops.sk = sk;
 	sock_ops.op = op;
+	if (nargs > 0)
+		memcpy(sock_ops.args, args, nargs * sizeof(*args));
 
 	ret = BPF_CGROUP_RUN_PROG_SOCK_OPS(&sock_ops);
 	if (ret == 0)
@@ -2027,18 +2029,46 @@ static inline int tcp_call_bpf(struct sock *sk, int op)
 		ret = -1;
 	return ret;
 }
+
+static inline int tcp_call_bpf_2arg(struct sock *sk, int op, u32 arg1, u32 arg2)
+{
+	u32 args[2] = {arg1, arg2};
+
+	return tcp_call_bpf(sk, op, 2, args);
+}
+
+static inline int tcp_call_bpf_3arg(struct sock *sk, int op, u32 arg1, u32 arg2,
+				    u32 arg3)
+{
+	u32 args[3] = {arg1, arg2, arg3};
+
+	return tcp_call_bpf(sk, op, 3, args);
+}
+
 #else
-static inline int tcp_call_bpf(struct sock *sk, int op)
+static inline int tcp_call_bpf(struct sock *sk, int op, u32 nargs, u32 *args)
 {
 	return -EPERM;
 }
+
+static inline int tcp_call_bpf_2arg(struct sock *sk, int op, u32 arg1, u32 arg2)
+{
+	return -EPERM;
+}
+
+static inline int tcp_call_bpf_3arg(struct sock *sk, int op, u32 arg1, u32 arg2,
+				    u32 arg3)
+{
+	return -EPERM;
+}
+
 #endif
 
 static inline u32 tcp_timeout_init(struct sock *sk)
 {
 	int timeout;
 
-	timeout = tcp_call_bpf(sk, BPF_SOCK_OPS_TIMEOUT_INIT);
+	timeout = tcp_call_bpf(sk, BPF_SOCK_OPS_TIMEOUT_INIT, 0, NULL);
 
 	if (timeout <= 0)
 		timeout = TCP_TIMEOUT_INIT;
@@ -2049,7 +2079,7 @@ static inline u32 tcp_rwnd_init_bpf(struct sock *sk)
 {
 	int rwnd;
 
-	rwnd = tcp_call_bpf(sk, BPF_SOCK_OPS_RWND_INIT);
+	rwnd = tcp_call_bpf(sk, BPF_SOCK_OPS_RWND_INIT, 0, NULL);
 
 	if (rwnd < 0)
 		rwnd = 0;
@@ -2058,7 +2088,7 @@ static inline u32 tcp_rwnd_init_bpf(struct sock *sk)
 
 static inline bool tcp_bpf_ca_needs_ecn(struct sock *sk)
 {
-	return (tcp_call_bpf(sk, BPF_SOCK_OPS_NEEDS_ECN) == 1);
+	return (tcp_call_bpf(sk, BPF_SOCK_OPS_NEEDS_ECN, 0, NULL) == 1);
 }
 
 #if IS_ENABLED(CONFIG_SMC)

+ 81 - 3
include/uapi/linux/bpf.h

@@ -642,6 +642,14 @@ union bpf_attr {
  *     @optlen: length of optval in bytes
  *     Return: 0 or negative error
  *
+ * int bpf_sock_ops_cb_flags_set(bpf_sock_ops, flags)
+ *     Set callback flags for sock_ops
+ *     @bpf_sock_ops: pointer to bpf_sock_ops_kern struct
+ *     @flags: flags value
+ *     Return: 0 for no error
+ *             -EINVAL if there is no full tcp socket
+ *             bits in flags that are not supported by current kernel
+ *
  * int bpf_skb_adjust_room(skb, len_diff, mode, flags)
  *     Grow or shrink room in sk_buff.
  *     @skb: pointer to skb
@@ -748,7 +756,8 @@ union bpf_attr {
 	FN(perf_event_read_value),	\
 	FN(perf_prog_read_value),	\
 	FN(getsockopt),			\
-	FN(override_return),
+	FN(override_return),		\
+	FN(sock_ops_cb_flags_set),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
@@ -952,8 +961,9 @@ struct bpf_map_info {
 struct bpf_sock_ops {
 	__u32 op;
 	union {
-		__u32 reply;
-		__u32 replylong[4];
+		__u32 args[4];		/* Optionally passed to bpf program */
+		__u32 reply;		/* Returned by bpf program	    */
+		__u32 replylong[4];	/* Optionally returned by bpf prog  */
 	};
 	__u32 family;
 	__u32 remote_ip4;	/* Stored in network byte order */
@@ -968,8 +978,39 @@ struct bpf_sock_ops {
 				 */
 	__u32 snd_cwnd;
 	__u32 srtt_us;		/* Averaged RTT << 3 in usecs */
+	__u32 bpf_sock_ops_cb_flags; /* flags defined in uapi/linux/tcp.h */
+	__u32 state;
+	__u32 rtt_min;
+	__u32 snd_ssthresh;
+	__u32 rcv_nxt;
+	__u32 snd_nxt;
+	__u32 snd_una;
+	__u32 mss_cache;
+	__u32 ecn_flags;
+	__u32 rate_delivered;
+	__u32 rate_interval_us;
+	__u32 packets_out;
+	__u32 retrans_out;
+	__u32 total_retrans;
+	__u32 segs_in;
+	__u32 data_segs_in;
+	__u32 segs_out;
+	__u32 data_segs_out;
+	__u32 lost_out;
+	__u32 sacked_out;
+	__u32 sk_txhash;
+	__u64 bytes_received;
+	__u64 bytes_acked;
 };
 
+/* Definitions for bpf_sock_ops_cb_flags */
+#define BPF_SOCK_OPS_RTO_CB_FLAG	(1<<0)
+#define BPF_SOCK_OPS_RETRANS_CB_FLAG	(1<<1)
+#define BPF_SOCK_OPS_STATE_CB_FLAG	(1<<2)
+#define BPF_SOCK_OPS_ALL_CB_FLAGS       0x7		/* Mask of all currently
+							 * supported cb flags
+							 */
+
 /* List of known BPF sock_ops operators.
  * New entries can only be added at the end
  */
@@ -1003,6 +1044,43 @@ enum {
 					 * a congestion threshold. RTTs above
 					 * this indicate congestion
 					 */
+	BPF_SOCK_OPS_RTO_CB,		/* Called when an RTO has triggered.
+					 * Arg1: value of icsk_retransmits
+					 * Arg2: value of icsk_rto
+					 * Arg3: whether RTO has expired
+					 */
+	BPF_SOCK_OPS_RETRANS_CB,	/* Called when skb is retransmitted.
+					 * Arg1: sequence number of 1st byte
+					 * Arg2: # segments
+					 * Arg3: return value of
+					 *       tcp_transmit_skb (0 => success)
+					 */
+	BPF_SOCK_OPS_STATE_CB,		/* Called when TCP changes state.
+					 * Arg1: old_state
+					 * Arg2: new_state
+					 */
+};
+
+/* List of TCP states. There is a build check in net/ipv4/tcp.c to detect
+ * changes between the TCP and BPF versions. Ideally this should never happen.
+ * If it does, we need to add code to convert them before calling
+ * the BPF sock_ops function.
+ */
+enum {
+	BPF_TCP_ESTABLISHED = 1,
+	BPF_TCP_SYN_SENT,
+	BPF_TCP_SYN_RECV,
+	BPF_TCP_FIN_WAIT1,
+	BPF_TCP_FIN_WAIT2,
+	BPF_TCP_TIME_WAIT,
+	BPF_TCP_CLOSE,
+	BPF_TCP_CLOSE_WAIT,
+	BPF_TCP_LAST_ACK,
+	BPF_TCP_LISTEN,
+	BPF_TCP_CLOSING,	/* Now a valid state */
+	BPF_TCP_NEW_SYN_RECV,
+
+	BPF_TCP_MAX_STATES	/* Leave at the end! */
 };
 
 #define TCP_BPF_IW		1001	/* Set TCP initial congestion window */
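
A rough sketch of how a program might consume the three new callbacks
documented above (the handler name is hypothetical; the argument meanings
follow the enum comments, and the bytes_acked access shows a read of one of
the new __u64 fields):

static inline int handle_new_callbacks(struct bpf_sock_ops *skops)
{
	switch (skops->op) {
	case BPF_SOCK_OPS_RTO_CB:
		/* args[0] = icsk_retransmits, args[1] = icsk_rto,
		 * args[2] = whether the RTO has expired
		 */
		return skops->args[2];
	case BPF_SOCK_OPS_RETRANS_CB:
		/* args[0] = seq of first byte, args[1] = # segments,
		 * args[2] = return value of tcp_transmit_skb (0 => success)
		 */
		return skops->args[2] == 0;
	case BPF_SOCK_OPS_STATE_CB:
		/* args[0] = old state, args[1] = new state (BPF_TCP_*) */
		if (skops->args[1] == BPF_TCP_CLOSE)
			return skops->bytes_acked > 0;	/* 8-byte read */
		break;
	}
	return 0;
}

These callbacks only fire if the corresponding BPF_SOCK_OPS_*_CB_FLAG bits
were set earlier via bpf_sock_ops_cb_flags_set().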

+ 269 - 21
net/core/filter.c

@@ -3232,6 +3232,29 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
 			ret = -EINVAL;
 		}
 #ifdef CONFIG_INET
+#if IS_ENABLED(CONFIG_IPV6)
+	} else if (level == SOL_IPV6) {
+		if (optlen != sizeof(int) || sk->sk_family != AF_INET6)
+			return -EINVAL;
+
+		val = *((int *)optval);
+		/* Only some options are supported */
+		switch (optname) {
+		case IPV6_TCLASS:
+			if (val < -1 || val > 0xff) {
+				ret = -EINVAL;
+			} else {
+				struct ipv6_pinfo *np = inet6_sk(sk);
+
+				if (val == -1)
+					val = 0;
+				np->tclass = val;
+			}
+			break;
+		default:
+			ret = -EINVAL;
+		}
+#endif
 	} else if (level == SOL_TCP &&
 		   sk->sk_prot->setsockopt == tcp_setsockopt) {
 		if (optname == TCP_CONGESTION) {
@@ -3241,7 +3264,8 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
 			strncpy(name, optval, min_t(long, optlen,
 						    TCP_CA_NAME_MAX-1));
 			name[TCP_CA_NAME_MAX-1] = 0;
-			ret = tcp_set_congestion_control(sk, name, false, reinit);
+			ret = tcp_set_congestion_control(sk, name, false,
+							 reinit);
 		} else {
 			struct tcp_sock *tp = tcp_sk(sk);
 
@@ -3307,6 +3331,22 @@ BPF_CALL_5(bpf_getsockopt, struct bpf_sock_ops_kern *, bpf_sock,
 		} else {
 			goto err_clear;
 		}
+#if IS_ENABLED(CONFIG_IPV6)
+	} else if (level == SOL_IPV6) {
+		struct ipv6_pinfo *np = inet6_sk(sk);
+
+		if (optlen != sizeof(int) || sk->sk_family != AF_INET6)
+			goto err_clear;
+
+		/* Only some options are supported */
+		switch (optname) {
+		case IPV6_TCLASS:
+			*((int *)optval) = (int)np->tclass;
+			break;
+		default:
+			goto err_clear;
+		}
+#endif
 	} else {
 		goto err_clear;
 	}
@@ -3328,6 +3368,33 @@ static const struct bpf_func_proto bpf_getsockopt_proto = {
 	.arg5_type	= ARG_CONST_SIZE,
 };
 
+BPF_CALL_2(bpf_sock_ops_cb_flags_set, struct bpf_sock_ops_kern *, bpf_sock,
+	   int, argval)
+{
+	struct sock *sk = bpf_sock->sk;
+	int val = argval & BPF_SOCK_OPS_ALL_CB_FLAGS;
+
+	if (!sk_fullsock(sk))
+		return -EINVAL;
+
+#ifdef CONFIG_INET
+	if (val)
+		tcp_sk(sk)->bpf_sock_ops_cb_flags = val;
+
+	return argval & (~BPF_SOCK_OPS_ALL_CB_FLAGS);
+#else
+	return -EINVAL;
+#endif
+}
+
+static const struct bpf_func_proto bpf_sock_ops_cb_flags_set_proto = {
+	.func		= bpf_sock_ops_cb_flags_set,
+	.gpl_only	= false,
+	.ret_type	= RET_INTEGER,
+	.arg1_type	= ARG_PTR_TO_CTX,
+	.arg2_type	= ARG_ANYTHING,
+};
+
 static const struct bpf_func_proto *
 bpf_base_func_proto(enum bpf_func_id func_id)
 {
@@ -3510,6 +3577,8 @@ static const struct bpf_func_proto *
 		return &bpf_setsockopt_proto;
 	case BPF_FUNC_getsockopt:
 		return &bpf_getsockopt_proto;
+	case BPF_FUNC_sock_ops_cb_flags_set:
+		return &bpf_sock_ops_cb_flags_set_proto;
 	case BPF_FUNC_sock_map_update:
 		return &bpf_sock_map_update_proto;
 	default:
@@ -3826,34 +3895,44 @@ void bpf_warn_invalid_xdp_action(u32 act)
 }
 EXPORT_SYMBOL_GPL(bpf_warn_invalid_xdp_action);
 
-static bool __is_valid_sock_ops_access(int off, int size)
+static bool sock_ops_is_valid_access(int off, int size,
+				     enum bpf_access_type type,
+				     struct bpf_insn_access_aux *info)
 {
+	const int size_default = sizeof(__u32);
+
 	if (off < 0 || off >= sizeof(struct bpf_sock_ops))
 		return false;
+
 	/* The verifier guarantees that size > 0. */
 	if (off % size != 0)
 		return false;
-	if (size != sizeof(__u32))
-		return false;
 
-	return true;
-}
-
-static bool sock_ops_is_valid_access(int off, int size,
-				     enum bpf_access_type type,
-				     struct bpf_insn_access_aux *info)
-{
 	if (type == BPF_WRITE) {
 		switch (off) {
-		case offsetof(struct bpf_sock_ops, op) ...
-		     offsetof(struct bpf_sock_ops, replylong[3]):
+		case offsetof(struct bpf_sock_ops, reply):
+		case offsetof(struct bpf_sock_ops, sk_txhash):
+			if (size != size_default)
+				return false;
 			break;
 		default:
 			return false;
 		}
+	} else {
+		switch (off) {
+		case bpf_ctx_range_till(struct bpf_sock_ops, bytes_received,
+					bytes_acked):
+			if (size != sizeof(__u64))
+				return false;
+			break;
+		default:
+			if (size != size_default)
+				return false;
+			break;
+		}
 	}
 
-	return __is_valid_sock_ops_access(off, size);
+	return true;
 }
 
 static int sk_skb_prologue(struct bpf_insn *insn_buf, bool direct_write,
@@ -4470,10 +4549,37 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type,
 					       is_fullsock));
 		break;
 
-/* Helper macro for adding read access to tcp_sock fields. */
-#define SOCK_OPS_GET_TCP32(FIELD_NAME)					      \
+	case offsetof(struct bpf_sock_ops, state):
+		BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_state) != 1);
+
+		*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
+						struct bpf_sock_ops_kern, sk),
+				      si->dst_reg, si->src_reg,
+				      offsetof(struct bpf_sock_ops_kern, sk));
+		*insn++ = BPF_LDX_MEM(BPF_B, si->dst_reg, si->dst_reg,
+				      offsetof(struct sock_common, skc_state));
+		break;
+
+	case offsetof(struct bpf_sock_ops, rtt_min):
+		BUILD_BUG_ON(FIELD_SIZEOF(struct tcp_sock, rtt_min) !=
+			     sizeof(struct minmax));
+		BUILD_BUG_ON(sizeof(struct minmax) <
+			     sizeof(struct minmax_sample));
+
+		*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
+						struct bpf_sock_ops_kern, sk),
+				      si->dst_reg, si->src_reg,
+				      offsetof(struct bpf_sock_ops_kern, sk));
+		*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+				      offsetof(struct tcp_sock, rtt_min) +
+				      FIELD_SIZEOF(struct minmax_sample, t));
+		break;
+
+/* Helper macro for adding read access to tcp_sock or sock fields. */
+#define SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)			      \
 	do {								      \
-		BUILD_BUG_ON(FIELD_SIZEOF(struct tcp_sock, FIELD_NAME) != 4); \
+		BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >		      \
+			     FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
 		*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(			      \
 						struct bpf_sock_ops_kern,     \
 						is_fullsock),		      \
@@ -4485,17 +4591,159 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type,
 						struct bpf_sock_ops_kern, sk),\
 				      si->dst_reg, si->src_reg,		      \
 				      offsetof(struct bpf_sock_ops_kern, sk));\
-		*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,        \
-				      offsetof(struct tcp_sock, FIELD_NAME)); \
+		*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(OBJ,		      \
+						       OBJ_FIELD),	      \
+				      si->dst_reg, si->dst_reg,		      \
+				      offsetof(OBJ, OBJ_FIELD));	      \
+	} while (0)
+
+/* Helper macro for adding write access to tcp_sock or sock fields.
+ * The macro is called with two registers, dst_reg which contains a pointer
+ * to ctx (context) and src_reg which contains the value that should be
+ * stored. However, we need an additional register since we cannot overwrite
+ * dst_reg because it may be used later in the program.
+ * Instead we "borrow" one of the other register. We first save its value
+ * into a new (temp) field in bpf_sock_ops_kern, use it, and then restore
+ * it at the end of the macro.
+ */
+#define SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)			      \
+	do {								      \
+		int reg = BPF_REG_9;					      \
+		BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >		      \
+			     FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
+		if (si->dst_reg == reg || si->src_reg == reg)		      \
+			reg--;						      \
+		if (si->dst_reg == reg || si->src_reg == reg)		      \
+			reg--;						      \
+		*insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, reg,		      \
+				      offsetof(struct bpf_sock_ops_kern,      \
+					       temp));			      \
+		*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(			      \
+						struct bpf_sock_ops_kern,     \
+						is_fullsock),		      \
+				      reg, si->dst_reg,			      \
+				      offsetof(struct bpf_sock_ops_kern,      \
+					       is_fullsock));		      \
+		*insn++ = BPF_JMP_IMM(BPF_JEQ, reg, 0, 2);		      \
+		*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(			      \
+						struct bpf_sock_ops_kern, sk),\
+				      reg, si->dst_reg,			      \
+				      offsetof(struct bpf_sock_ops_kern, sk));\
+		*insn++ = BPF_STX_MEM(BPF_FIELD_SIZEOF(OBJ, OBJ_FIELD),	      \
+				      reg, si->src_reg,			      \
+				      offsetof(OBJ, OBJ_FIELD));	      \
+		*insn++ = BPF_LDX_MEM(BPF_DW, reg, si->dst_reg,		      \
+				      offsetof(struct bpf_sock_ops_kern,      \
+					       temp));			      \
+	} while (0)
+
+#define SOCK_OPS_GET_OR_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ, TYPE)	      \
+	do {								      \
+		if (TYPE == BPF_WRITE)					      \
+			SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);	      \
+		else							      \
+			SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);	      \
 	} while (0)
 
 	case offsetof(struct bpf_sock_ops, snd_cwnd):
-		SOCK_OPS_GET_TCP32(snd_cwnd);
+		SOCK_OPS_GET_FIELD(snd_cwnd, snd_cwnd, struct tcp_sock);
 		break;
 
 	case offsetof(struct bpf_sock_ops, srtt_us):
-		SOCK_OPS_GET_TCP32(srtt_us);
+		SOCK_OPS_GET_FIELD(srtt_us, srtt_us, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, bpf_sock_ops_cb_flags):
+		SOCK_OPS_GET_FIELD(bpf_sock_ops_cb_flags, bpf_sock_ops_cb_flags,
+				   struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, snd_ssthresh):
+		SOCK_OPS_GET_FIELD(snd_ssthresh, snd_ssthresh, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, rcv_nxt):
+		SOCK_OPS_GET_FIELD(rcv_nxt, rcv_nxt, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, snd_nxt):
+		SOCK_OPS_GET_FIELD(snd_nxt, snd_nxt, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, snd_una):
+		SOCK_OPS_GET_FIELD(snd_una, snd_una, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, mss_cache):
+		SOCK_OPS_GET_FIELD(mss_cache, mss_cache, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, ecn_flags):
+		SOCK_OPS_GET_FIELD(ecn_flags, ecn_flags, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, rate_delivered):
+		SOCK_OPS_GET_FIELD(rate_delivered, rate_delivered,
+				   struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, rate_interval_us):
+		SOCK_OPS_GET_FIELD(rate_interval_us, rate_interval_us,
+				   struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, packets_out):
+		SOCK_OPS_GET_FIELD(packets_out, packets_out, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, retrans_out):
+		SOCK_OPS_GET_FIELD(retrans_out, retrans_out, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, total_retrans):
+		SOCK_OPS_GET_FIELD(total_retrans, total_retrans,
+				   struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, segs_in):
+		SOCK_OPS_GET_FIELD(segs_in, segs_in, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, data_segs_in):
+		SOCK_OPS_GET_FIELD(data_segs_in, data_segs_in, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, segs_out):
+		SOCK_OPS_GET_FIELD(segs_out, segs_out, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, data_segs_out):
+		SOCK_OPS_GET_FIELD(data_segs_out, data_segs_out,
+				   struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, lost_out):
+		SOCK_OPS_GET_FIELD(lost_out, lost_out, struct tcp_sock);
 		break;
+
+	case offsetof(struct bpf_sock_ops, sacked_out):
+		SOCK_OPS_GET_FIELD(sacked_out, sacked_out, struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, sk_txhash):
+		SOCK_OPS_GET_OR_SET_FIELD(sk_txhash, sk_txhash,
+					  struct sock, type);
+		break;
+
+	case offsetof(struct bpf_sock_ops, bytes_received):
+		SOCK_OPS_GET_FIELD(bytes_received, bytes_received,
+				   struct tcp_sock);
+		break;
+
+	case offsetof(struct bpf_sock_ops, bytes_acked):
+		SOCK_OPS_GET_FIELD(bytes_acked, bytes_acked, struct tcp_sock);
+		break;
+
 	}
 	return insn - insn_buf;
 }
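
To make the access rules enforced by sock_ops_is_valid_access() concrete,
here is a hypothetical fragment (function name invented) that a sock_ops
program could contain after this change:

static inline void access_rules_demo(struct bpf_sock_ops *skops)
{
	__u32 cwnd  = skops->snd_cwnd;     /* 4-byte read of a __u32 field */
	__u64 acked = skops->bytes_acked;  /* __u64 fields need a full 8-byte read */

	skops->sk_txhash = cwnd ^ (__u32)acked;  /* sk_txhash is writable */

	/* The verifier rejects, for example, a 2-byte load from snd_cwnd or
	 * a write to any field other than reply and sk_txhash.
	 */
}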

+ 25 - 1
net/ipv4/tcp.c

@@ -463,7 +463,7 @@ void tcp_init_transfer(struct sock *sk, int bpf_op)
 	tcp_mtup_init(sk);
 	icsk->icsk_af_ops->rebuild_header(sk);
 	tcp_init_metrics(sk);
-	tcp_call_bpf(sk, bpf_op);
+	tcp_call_bpf(sk, bpf_op, 0, NULL);
 	tcp_init_congestion_control(sk);
 	tcp_init_buffer_space(sk);
 }
@@ -2042,6 +2042,30 @@ void tcp_set_state(struct sock *sk, int state)
 {
 	int oldstate = sk->sk_state;
 
+	/* We defined a new enum for TCP states that are exported in BPF
+	 * so as not force the internal TCP states to be frozen. The
+	 * following checks will detect if an internal state value ever
+	 * differs from the BPF value. If this ever happens, then we will
+	 * need to remap the internal value to the BPF value before calling
+	 * tcp_call_bpf_2arg.
+	 */
+	BUILD_BUG_ON((int)BPF_TCP_ESTABLISHED != (int)TCP_ESTABLISHED);
+	BUILD_BUG_ON((int)BPF_TCP_SYN_SENT != (int)TCP_SYN_SENT);
+	BUILD_BUG_ON((int)BPF_TCP_SYN_RECV != (int)TCP_SYN_RECV);
+	BUILD_BUG_ON((int)BPF_TCP_FIN_WAIT1 != (int)TCP_FIN_WAIT1);
+	BUILD_BUG_ON((int)BPF_TCP_FIN_WAIT2 != (int)TCP_FIN_WAIT2);
+	BUILD_BUG_ON((int)BPF_TCP_TIME_WAIT != (int)TCP_TIME_WAIT);
+	BUILD_BUG_ON((int)BPF_TCP_CLOSE != (int)TCP_CLOSE);
+	BUILD_BUG_ON((int)BPF_TCP_CLOSE_WAIT != (int)TCP_CLOSE_WAIT);
+	BUILD_BUG_ON((int)BPF_TCP_LAST_ACK != (int)TCP_LAST_ACK);
+	BUILD_BUG_ON((int)BPF_TCP_LISTEN != (int)TCP_LISTEN);
+	BUILD_BUG_ON((int)BPF_TCP_CLOSING != (int)TCP_CLOSING);
+	BUILD_BUG_ON((int)BPF_TCP_NEW_SYN_RECV != (int)TCP_NEW_SYN_RECV);
+	BUILD_BUG_ON((int)BPF_TCP_MAX_STATES != (int)TCP_MAX_STATES);
+
+	if (BPF_SOCK_OPS_TEST_FLAG(tcp_sk(sk), BPF_SOCK_OPS_STATE_CB_FLAG))
+		tcp_call_bpf_2arg(sk, BPF_SOCK_OPS_STATE_CB, oldstate, state);
+
 	switch (state) {
 	case TCP_ESTABLISHED:
 		if (oldstate != TCP_ESTABLISHED)

+ 1 - 1
net/ipv4/tcp_nv.c

@@ -146,7 +146,7 @@ static void tcpnv_init(struct sock *sk)
 	 * within a datacenter, where we have reasonable estimates of
 	 * RTTs
 	 */
-	base_rtt = tcp_call_bpf(sk, BPF_SOCK_OPS_BASE_RTT);
+	base_rtt = tcp_call_bpf(sk, BPF_SOCK_OPS_BASE_RTT, 0, NULL);
 	if (base_rtt > 0) {
 		ca->nv_base_rtt = base_rtt;
 		ca->nv_lower_bound_rtt = (base_rtt * 205) >> 8; /* 80% */

+ 5 - 1
net/ipv4/tcp_output.c

@@ -2905,6 +2905,10 @@ int __tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs)
 		err = tcp_transmit_skb(sk, skb, 1, GFP_ATOMIC);
 	}
 
+	if (BPF_SOCK_OPS_TEST_FLAG(tp, BPF_SOCK_OPS_RETRANS_CB_FLAG))
+		tcp_call_bpf_3arg(sk, BPF_SOCK_OPS_RETRANS_CB,
+				  TCP_SKB_CB(skb)->seq, segs, err);
+
 	if (likely(!err)) {
 		TCP_SKB_CB(skb)->sacked |= TCPCB_EVER_RETRANS;
 		trace_tcp_retransmit_skb(sk, skb);
@@ -3469,7 +3473,7 @@ int tcp_connect(struct sock *sk)
 	struct sk_buff *buff;
 	int err;
 
-	tcp_call_bpf(sk, BPF_SOCK_OPS_TCP_CONNECT_CB);
+	tcp_call_bpf(sk, BPF_SOCK_OPS_TCP_CONNECT_CB, 0, NULL);
 
 	if (inet_csk(sk)->icsk_af_ops->rebuild_header(sk))
 		return -EHOSTUNREACH; /* Routing failure or similar. */

+ 7 - 0
net/ipv4/tcp_timer.c

@@ -213,11 +213,18 @@ static int tcp_write_timeout(struct sock *sk)
 						icsk->icsk_user_timeout);
 	}
 	tcp_fastopen_active_detect_blackhole(sk, expired);
+
+	if (BPF_SOCK_OPS_TEST_FLAG(tp, BPF_SOCK_OPS_RTO_CB_FLAG))
+		tcp_call_bpf_3arg(sk, BPF_SOCK_OPS_RTO_CB,
+				  icsk->icsk_retransmits,
+				  icsk->icsk_rto, (int)expired);
+
 	if (expired) {
 		/* Has it gone just too far? */
 		tcp_write_err(sk);
 		return 1;
 	}
+
 	return 0;
 }
 

+ 82 - 4
tools/include/uapi/linux/bpf.h

@@ -17,7 +17,7 @@
 #define BPF_ALU64	0x07	/* alu mode in double word width */
 
 /* ld/ldx fields */
-#define BPF_DW		0x18	/* double word */
+#define BPF_DW		0x18	/* double word (64-bit) */
 #define BPF_XADD	0xc0	/* exclusive add */
 
 /* alu/jmp fields */
@@ -642,6 +642,14 @@ union bpf_attr {
  *     @optlen: length of optval in bytes
  *     Return: 0 or negative error
  *
+ * int bpf_sock_ops_cb_flags_set(bpf_sock_ops, flags)
+ *     Set callback flags for sock_ops
+ *     @bpf_sock_ops: pointer to bpf_sock_ops_kern struct
+ *     @flags: flags value
+ *     Return: 0 for no error
+ *             -EINVAL if there is no full tcp socket
+ *             bits in flags that are not supported by current kernel
+ *
  * int bpf_skb_adjust_room(skb, len_diff, mode, flags)
  *     Grow or shrink room in sk_buff.
  *     @skb: pointer to skb
@@ -748,7 +756,8 @@ union bpf_attr {
 	FN(perf_event_read_value),	\
 	FN(perf_prog_read_value),	\
 	FN(getsockopt),			\
-	FN(override_return),
+	FN(override_return),		\
+	FN(sock_ops_cb_flags_set),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
@@ -952,8 +961,9 @@ struct bpf_map_info {
 struct bpf_sock_ops {
 	__u32 op;
 	union {
-		__u32 reply;
-		__u32 replylong[4];
+		__u32 args[4];		/* Optionally passed to bpf program */
+		__u32 reply;		/* Returned by bpf program	    */
+		__u32 replylong[4];	/* Optionally returned by bpf prog  */
 	};
 	__u32 family;
 	__u32 remote_ip4;	/* Stored in network byte order */
@@ -968,8 +978,39 @@ struct bpf_sock_ops {
 				 */
 	__u32 snd_cwnd;
 	__u32 srtt_us;		/* Averaged RTT << 3 in usecs */
+	__u32 bpf_sock_ops_cb_flags; /* flags defined in uapi/linux/tcp.h */
+	__u32 state;
+	__u32 rtt_min;
+	__u32 snd_ssthresh;
+	__u32 rcv_nxt;
+	__u32 snd_nxt;
+	__u32 snd_una;
+	__u32 mss_cache;
+	__u32 ecn_flags;
+	__u32 rate_delivered;
+	__u32 rate_interval_us;
+	__u32 packets_out;
+	__u32 retrans_out;
+	__u32 total_retrans;
+	__u32 segs_in;
+	__u32 data_segs_in;
+	__u32 segs_out;
+	__u32 data_segs_out;
+	__u32 lost_out;
+	__u32 sacked_out;
+	__u32 sk_txhash;
+	__u64 bytes_received;
+	__u64 bytes_acked;
 };
 
+/* Definitions for bpf_sock_ops_cb_flags */
+#define BPF_SOCK_OPS_RTO_CB_FLAG	(1<<0)
+#define BPF_SOCK_OPS_RETRANS_CB_FLAG	(1<<1)
+#define BPF_SOCK_OPS_STATE_CB_FLAG	(1<<2)
+#define BPF_SOCK_OPS_ALL_CB_FLAGS       0x7		/* Mask of all currently
+							 * supported cb flags
+							 */
+
 /* List of known BPF sock_ops operators.
  * New entries can only be added at the end
  */
@@ -1003,6 +1044,43 @@ enum {
 					 * a congestion threshold. RTTs above
 					 * this indicate congestion
 					 */
+	BPF_SOCK_OPS_RTO_CB,		/* Called when an RTO has triggered.
+					 * Arg1: value of icsk_retransmits
+					 * Arg2: value of icsk_rto
+					 * Arg3: whether RTO has expired
+					 */
+	BPF_SOCK_OPS_RETRANS_CB,	/* Called when skb is retransmitted.
+					 * Arg1: sequence number of 1st byte
+					 * Arg2: # segments
+					 * Arg3: return value of
+					 *       tcp_transmit_skb (0 => success)
+					 */
+	BPF_SOCK_OPS_STATE_CB,		/* Called when TCP changes state.
+					 * Arg1: old_state
+					 * Arg2: new_state
+					 */
+};
+
+/* List of TCP states. There is a build check in net/ipv4/tcp.c to detect
+ * changes between the TCP and BPF versions. Ideally this should never happen.
+ * If it does, we need to add code to convert them before calling
+ * the BPF sock_ops function.
+ */
+enum {
+	BPF_TCP_ESTABLISHED = 1,
+	BPF_TCP_SYN_SENT,
+	BPF_TCP_SYN_RECV,
+	BPF_TCP_FIN_WAIT1,
+	BPF_TCP_FIN_WAIT2,
+	BPF_TCP_TIME_WAIT,
+	BPF_TCP_CLOSE,
+	BPF_TCP_CLOSE_WAIT,
+	BPF_TCP_LAST_ACK,
+	BPF_TCP_LISTEN,
+	BPF_TCP_CLOSING,	/* Now a valid state */
+	BPF_TCP_NEW_SYN_RECV,
+
+	BPF_TCP_MAX_STATES	/* Leave at the end! */
 };
 
 #define TCP_BPF_IW		1001	/* Set TCP initial congestion window */

+ 2 - 2
tools/testing/selftests/bpf/Makefile

@@ -14,13 +14,13 @@ CFLAGS += -Wall -O2 -I$(APIDIR) -I$(LIBDIR) -I$(GENDIR) $(GENFLAGS) -I../../../i
 LDLIBS += -lcap -lelf -lrt
 
 TEST_GEN_PROGS = test_verifier test_tag test_maps test_lru_map test_lpm_map test_progs \
-	test_align test_verifier_log test_dev_cgroup
+	test_align test_verifier_log test_dev_cgroup test_tcpbpf_user
 
 TEST_GEN_FILES = test_pkt_access.o test_xdp.o test_l4lb.o test_tcp_estats.o test_obj_id.o \
 	test_pkt_md_access.o test_xdp_redirect.o test_xdp_meta.o sockmap_parse_prog.o     \
 	sockmap_verdict_prog.o dev_cgroup.o sample_ret0.o test_tracepoint.o \
 	test_l4lb_noinline.o test_xdp_noinline.o test_stacktrace_map.o \
-	sample_map_ret0.o
+	sample_map_ret0.o test_tcpbpf_kern.o
 
 TEST_PROGS := test_kmod.sh test_xdp_redirect.sh test_xdp_meta.sh \
 	test_offload.py

+ 2 - 0
tools/testing/selftests/bpf/bpf_helpers.h

@@ -71,6 +71,8 @@ static int (*bpf_setsockopt)(void *ctx, int level, int optname, void *optval,
 static int (*bpf_getsockopt)(void *ctx, int level, int optname, void *optval,
 			     int optlen) =
 	(void *) BPF_FUNC_getsockopt;
+static int (*bpf_sock_ops_cb_flags_set)(void *ctx, int flags) =
+	(void *) BPF_FUNC_sock_ops_cb_flags_set;
 static int (*bpf_sk_redirect_map)(void *ctx, void *map, int key, int flags) =
 	(void *) BPF_FUNC_sk_redirect_map;
 static int (*bpf_sock_map_update)(void *map, void *key, void *value,

+ 51 - 0
tools/testing/selftests/bpf/tcp_client.py

@@ -0,0 +1,51 @@
+#!/usr/bin/env python2
+#
+# SPDX-License-Identifier: GPL-2.0
+#
+
+import sys, os, os.path, getopt
+import socket, time
+import subprocess
+import select
+
+def read(sock, n):
+    buf = ''
+    while len(buf) < n:
+        rem = n - len(buf)
+        try: s = sock.recv(rem)
+        except (socket.error), e: return ''
+        buf += s
+    return buf
+
+def send(sock, s):
+    total = len(s)
+    count = 0
+    while count < total:
+        try: n = sock.send(s)
+        except (socket.error), e: n = 0
+        if n == 0:
+            return count;
+        count += n
+    return count
+
+
+serverPort = int(sys.argv[1])
+HostName = socket.gethostname()
+
+# create active socket
+sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+try:
+    sock.connect((HostName, serverPort))
+except socket.error as e:
+    sys.exit(1)
+
+buf = ''
+n = 0
+while n < 1000:
+    buf += '+'
+    n += 1
+
+sock.settimeout(1);
+n = send(sock, buf)
+n = read(sock, 500)
+sys.exit(0)

+ 83 - 0
tools/testing/selftests/bpf/tcp_server.py

@@ -0,0 +1,83 @@
+#!/usr/bin/env python2
+#
+# SPDX-License-Identifier: GPL-2.0
+#
+
+import sys, os, os.path, getopt
+import socket, time
+import subprocess
+import select
+
+def read(sock, n):
+    buf = ''
+    while len(buf) < n:
+        rem = n - len(buf)
+        try: s = sock.recv(rem)
+        except (socket.error), e: return ''
+        buf += s
+    return buf
+
+def send(sock, s):
+    total = len(s)
+    count = 0
+    while count < total:
+        try: n = sock.send(s)
+        except (socket.error), e: n = 0
+        if n == 0:
+            return count;
+        count += n
+    return count
+
+
+SERVER_PORT = 12877
+MAX_PORTS = 2
+
+serverPort = SERVER_PORT
+serverSocket = None
+
+HostName = socket.gethostname()
+
+# create passive socket
+serverSocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+host = socket.gethostname()
+
+try: serverSocket.bind((host, 0))
+except socket.error as msg:
+    print 'bind fails: ', msg
+
+sn = serverSocket.getsockname()
+serverPort = sn[1]
+
+cmdStr = ("./tcp_client.py %d &") % (serverPort)
+os.system(cmdStr)
+
+buf = ''
+n = 0
+while n < 500:
+    buf += '.'
+    n += 1
+
+serverSocket.listen(MAX_PORTS)
+readList = [serverSocket]
+
+while True:
+    readyRead, readyWrite, inError = \
+        select.select(readList, [], [], 2)
+
+    if len(readyRead) > 0:
+        waitCount = 0
+        for sock in readyRead:
+            if sock == serverSocket:
+                (clientSocket, address) = serverSocket.accept()
+                address = str(address[0])
+                readList.append(clientSocket)
+            else:
+                sock.settimeout(1);
+                s = read(sock, 1000)
+                n = send(sock, buf)
+                sock.close()
+                serverSocket.close()
+                sys.exit(0)
+    else:
+        print 'Select timeout!'
+        sys.exit(1)

+ 16 - 0
tools/testing/selftests/bpf/test_tcpbpf.h

@@ -0,0 +1,16 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#ifndef _TEST_TCPBPF_H
+#define _TEST_TCPBPF_H
+
+struct tcpbpf_globals {
+	__u32 event_map;
+	__u32 total_retrans;
+	__u32 data_segs_in;
+	__u32 data_segs_out;
+	__u32 bad_cb_test_rv;
+	__u32 good_cb_test_rv;
+	__u64 bytes_received;
+	__u64 bytes_acked;
+};
+#endif

+ 118 - 0
tools/testing/selftests/bpf/test_tcpbpf_kern.c

@@ -0,0 +1,118 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <stddef.h>
+#include <string.h>
+#include <linux/bpf.h>
+#include <linux/if_ether.h>
+#include <linux/if_packet.h>
+#include <linux/ip.h>
+#include <linux/in6.h>
+#include <linux/types.h>
+#include <linux/socket.h>
+#include <linux/tcp.h>
+#include <netinet/in.h>
+#include "bpf_helpers.h"
+#include "bpf_endian.h"
+#include "test_tcpbpf.h"
+
+struct bpf_map_def SEC("maps") global_map = {
+	.type = BPF_MAP_TYPE_ARRAY,
+	.key_size = sizeof(__u32),
+	.value_size = sizeof(struct tcpbpf_globals),
+	.max_entries = 2,
+};
+
+static inline void update_event_map(int event)
+{
+	__u32 key = 0;
+	struct tcpbpf_globals g, *gp;
+
+	gp = bpf_map_lookup_elem(&global_map, &key);
+	if (gp == NULL) {
+		struct tcpbpf_globals g = {0};
+
+		g.event_map |= (1 << event);
+		bpf_map_update_elem(&global_map, &key, &g,
+			    BPF_ANY);
+	} else {
+		g = *gp;
+		g.event_map |= (1 << event);
+		bpf_map_update_elem(&global_map, &key, &g,
+			    BPF_ANY);
+	}
+}
+
+int _version SEC("version") = 1;
+
+SEC("sockops")
+int bpf_testcb(struct bpf_sock_ops *skops)
+{
+	int rv = -1;
+	int bad_call_rv = 0;
+	int good_call_rv = 0;
+	int op;
+	int v = 0;
+
+	op = (int) skops->op;
+
+	update_event_map(op);
+
+	switch (op) {
+	case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
+		/* Test failure to set largest cb flag (assumes not defined) */
+		bad_call_rv = bpf_sock_ops_cb_flags_set(skops, 0x80);
+		/* Set callback */
+		good_call_rv = bpf_sock_ops_cb_flags_set(skops,
+						 BPF_SOCK_OPS_STATE_CB_FLAG);
+		/* Update results */
+		{
+			__u32 key = 0;
+			struct tcpbpf_globals g, *gp;
+
+			gp = bpf_map_lookup_elem(&global_map, &key);
+			if (!gp)
+				break;
+			g = *gp;
+			g.bad_cb_test_rv = bad_call_rv;
+			g.good_cb_test_rv = good_call_rv;
+			bpf_map_update_elem(&global_map, &key, &g,
+					    BPF_ANY);
+		}
+		break;
+	case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
+		/* Set callback */
+//		good_call_rv = bpf_sock_ops_cb_flags_set(skops,
+//						 BPF_SOCK_OPS_STATE_CB_FLAG);
+		skops->sk_txhash = 0x12345f;
+		v = 0xff;
+		rv = bpf_setsockopt(skops, SOL_IPV6, IPV6_TCLASS, &v,
+				    sizeof(v));
+		break;
+	case BPF_SOCK_OPS_RTO_CB:
+		break;
+	case BPF_SOCK_OPS_RETRANS_CB:
+		break;
+	case BPF_SOCK_OPS_STATE_CB:
+		if (skops->args[1] == BPF_TCP_CLOSE) {
+			__u32 key = 0;
+			struct tcpbpf_globals g, *gp;
+
+			gp = bpf_map_lookup_elem(&global_map, &key);
+			if (!gp)
+				break;
+			g = *gp;
+			g.total_retrans = skops->total_retrans;
+			g.data_segs_in = skops->data_segs_in;
+			g.data_segs_out = skops->data_segs_out;
+			g.bytes_received = skops->bytes_received;
+			g.bytes_acked = skops->bytes_acked;
+			bpf_map_update_elem(&global_map, &key, &g,
+					    BPF_ANY);
+		}
+		break;
+	default:
+		rv = -1;
+	}
+	skops->reply = rv;
+	return 1;
+}
+char _license[] SEC("license") = "GPL";

+ 126 - 0
tools/testing/selftests/bpf/test_tcpbpf_user.c

@@ -0,0 +1,126 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <errno.h>
+#include <signal.h>
+#include <string.h>
+#include <assert.h>
+#include <linux/perf_event.h>
+#include <linux/ptrace.h>
+#include <linux/bpf.h>
+#include <sys/ioctl.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <bpf/bpf.h>
+#include <bpf/libbpf.h>
+#include "bpf_util.h"
+#include <linux/perf_event.h>
+#include "test_tcpbpf.h"
+
+static int bpf_find_map(const char *test, struct bpf_object *obj,
+			const char *name)
+{
+	struct bpf_map *map;
+
+	map = bpf_object__find_map_by_name(obj, name);
+	if (!map) {
+		printf("%s:FAIL:map '%s' not found\n", test, name);
+		return -1;
+	}
+	return bpf_map__fd(map);
+}
+
+#define SYSTEM(CMD)						\
+	do {							\
+		if (system(CMD)) {				\
+			printf("system(%s) FAILS!\n", CMD);	\
+		}						\
+	} while (0)
+
+int main(int argc, char **argv)
+{
+	const char *file = "test_tcpbpf_kern.o";
+	struct tcpbpf_globals g = {0};
+	int cg_fd, prog_fd, map_fd;
+	bool debug_flag = false;
+	int error = EXIT_FAILURE;
+	struct bpf_object *obj;
+	char cmd[100], *dir;
+	struct stat buffer;
+	__u32 key = 0;
+	int pid;
+	int rv;
+
+	if (argc > 1 && strcmp(argv[1], "-d") == 0)
+		debug_flag = true;
+
+	dir = "/tmp/cgroupv2/foo";
+
+	if (stat(dir, &buffer) != 0) {
+		SYSTEM("mkdir -p /tmp/cgroupv2");
+		SYSTEM("mount -t cgroup2 none /tmp/cgroupv2");
+		SYSTEM("mkdir -p /tmp/cgroupv2/foo");
+	}
+	pid = (int) getpid();
+	sprintf(cmd, "echo %d >> /tmp/cgroupv2/foo/cgroup.procs", pid);
+	SYSTEM(cmd);
+
+	cg_fd = open(dir, O_DIRECTORY, O_RDONLY);
+	if (bpf_prog_load(file, BPF_PROG_TYPE_SOCK_OPS, &obj, &prog_fd)) {
+		printf("FAILED: load_bpf_file failed for: %s\n", file);
+		goto err;
+	}
+
+	rv = bpf_prog_attach(prog_fd, cg_fd, BPF_CGROUP_SOCK_OPS, 0);
+	if (rv) {
+		printf("FAILED: bpf_prog_attach: %d (%s)\n",
+		       error, strerror(errno));
+		goto err;
+	}
+
+	SYSTEM("./tcp_server.py");
+
+	map_fd = bpf_find_map(__func__, obj, "global_map");
+	if (map_fd < 0)
+		goto err;
+
+	rv = bpf_map_lookup_elem(map_fd, &key, &g);
+	if (rv != 0) {
+		printf("FAILED: bpf_map_lookup_elem returns %d\n", rv);
+		goto err;
+	}
+
+	if (g.bytes_received != 501 || g.bytes_acked != 1002 ||
+	    g.data_segs_in != 1 || g.data_segs_out != 1 ||
+	    (g.event_map ^ 0x47e) != 0 || g.bad_cb_test_rv != 0x80 ||
+		g.good_cb_test_rv != 0) {
+		printf("FAILED: Wrong stats\n");
+		if (debug_flag) {
+			printf("\n");
+			printf("bytes_received: %d (expecting 501)\n",
+			       (int)g.bytes_received);
+			printf("bytes_acked:    %d (expecting 1002)\n",
+			       (int)g.bytes_acked);
+			printf("data_segs_in:   %d (expecting 1)\n",
+			       g.data_segs_in);
+			printf("data_segs_out:  %d (expecting 1)\n",
+			       g.data_segs_out);
+			printf("event_map:      0x%x (at least 0x47e)\n",
+			       g.event_map);
+			printf("bad_cb_test_rv: 0x%x (expecting 0x80)\n",
+			       g.bad_cb_test_rv);
+			printf("good_cb_test_rv:0x%x (expecting 0)\n",
+			       g.good_cb_test_rv);
+		}
+		goto err;
+	}
+	printf("PASSED!\n");
+	error = 0;
+err:
+	bpf_prog_detach(cg_fd, BPF_CGROUP_SOCK_OPS);
+	return error;
+
+}