|
@@ -270,11 +270,6 @@ static inline struct sk_psock *sk_psock(const struct sock *sk)
|
|
|
return rcu_dereference_sk_user_data(sk);
|
|
|
}
|
|
|
|
|
|
-static inline bool sk_has_psock(struct sock *sk)
|
|
|
-{
|
|
|
- return sk_psock(sk) != NULL && sk->sk_prot->recvmsg == tcp_bpf_recvmsg;
|
|
|
-}
|
|
|
-
|
|
|
static inline void sk_psock_queue_msg(struct sk_psock *psock,
|
|
|
struct sk_msg *msg)
|
|
|
{
|
|
@@ -374,6 +369,26 @@ static inline bool sk_psock_test_state(const struct sk_psock *psock,
|
|
|
return test_bit(bit, &psock->state);
|
|
|
}
|
|
|
|
|
|
+static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
|
|
|
+{
|
|
|
+ struct sk_psock *psock;
|
|
|
+
|
|
|
+ rcu_read_lock();
|
|
|
+ psock = sk_psock(sk);
|
|
|
+ if (psock) {
|
|
|
+ if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
|
|
|
+ psock = ERR_PTR(-EBUSY);
|
|
|
+ goto out;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (!refcount_inc_not_zero(&psock->refcnt))
|
|
|
+ psock = ERR_PTR(-EBUSY);
|
|
|
+ }
|
|
|
+out:
|
|
|
+ rcu_read_unlock();
|
|
|
+ return psock;
|
|
|
+}
|
|
|
+
|
|
|
static inline struct sk_psock *sk_psock_get(struct sock *sk)
|
|
|
{
|
|
|
struct sk_psock *psock;
|