Branch data Line data Source code
1 : : /* SPDX-License-Identifier: GPL-2.0 */
2 : : /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3 : :
4 : : #ifndef _LINUX_SKMSG_H
5 : : #define _LINUX_SKMSG_H
6 : :
7 : : #include <linux/bpf.h>
8 : : #include <linux/filter.h>
9 : : #include <linux/scatterlist.h>
10 : : #include <linux/skbuff.h>
11 : :
12 : : #include <net/sock.h>
13 : : #include <net/tcp.h>
14 : : #include <net/strparser.h>
15 : :
16 : : #define MAX_MSG_FRAGS MAX_SKB_FRAGS
17 : : #define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
18 : :
19 : : enum __sk_action {
20 : : __SK_DROP = 0,
21 : : __SK_PASS,
22 : : __SK_REDIRECT,
23 : : __SK_NONE,
24 : : };
25 : :
26 : : struct sk_msg_sg {
27 : : u32 start;
28 : : u32 curr;
29 : : u32 end;
30 : : u32 size;
31 : : u32 copybreak;
32 : : unsigned long copy;
33 : : /* The extra two elements:
34 : : * 1) used for chaining the front and sections when the list becomes
35 : : * partitioned (e.g. end < start). The crypto APIs require the
36 : : * chaining;
37 : : * 2) to chain tailer SG entries after the message.
38 : : */
39 : : struct scatterlist data[MAX_MSG_FRAGS + 2];
40 : : };
41 : : static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
42 : :
43 : : /* UAPI in filter.c depends on struct sk_msg_sg being first element. */
44 : : struct sk_msg {
45 : : struct sk_msg_sg sg;
46 : : void *data;
47 : : void *data_end;
48 : : u32 apply_bytes;
49 : : u32 cork_bytes;
50 : : u32 flags;
51 : : struct sk_buff *skb;
52 : : struct sock *sk_redir;
53 : : struct sock *sk;
54 : : struct list_head list;
55 : : };
56 : :
57 : : struct sk_psock_progs {
58 : : struct bpf_prog *msg_parser;
59 : : struct bpf_prog *skb_parser;
60 : : struct bpf_prog *skb_verdict;
61 : : };
62 : :
63 : : enum sk_psock_state_bits {
64 : : SK_PSOCK_TX_ENABLED,
65 : : };
66 : :
67 : : struct sk_psock_link {
68 : : struct list_head list;
69 : : struct bpf_map *map;
70 : : void *link_raw;
71 : : };
72 : :
73 : : struct sk_psock_parser {
74 : : struct strparser strp;
75 : : bool enabled;
76 : : void (*saved_data_ready)(struct sock *sk);
77 : : };
78 : :
79 : : struct sk_psock_work_state {
80 : : struct sk_buff *skb;
81 : : u32 len;
82 : : u32 off;
83 : : };
84 : :
85 : : struct sk_psock {
86 : : struct sock *sk;
87 : : struct sock *sk_redir;
88 : : u32 apply_bytes;
89 : : u32 cork_bytes;
90 : : u32 eval;
91 : : struct sk_msg *cork;
92 : : struct sk_psock_progs progs;
93 : : struct sk_psock_parser parser;
94 : : struct sk_buff_head ingress_skb;
95 : : struct list_head ingress_msg;
96 : : unsigned long state;
97 : : struct list_head link;
98 : : spinlock_t link_lock;
99 : : refcount_t refcnt;
100 : : void (*saved_unhash)(struct sock *sk);
101 : : void (*saved_close)(struct sock *sk, long timeout);
102 : : void (*saved_write_space)(struct sock *sk);
103 : : struct proto *sk_proto;
104 : : struct sk_psock_work_state work_state;
105 : : struct work_struct work;
106 : : union {
107 : : struct rcu_head rcu;
108 : : struct work_struct gc;
109 : : };
110 : : };
111 : :
112 : : int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
113 : : int elem_first_coalesce);
114 : : int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
115 : : u32 off, u32 len);
116 : : void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
117 : : int sk_msg_free(struct sock *sk, struct sk_msg *msg);
118 : : int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
119 : : void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
120 : : void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
121 : : u32 bytes);
122 : :
123 : : void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
124 : : void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
125 : :
126 : : int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
127 : : struct sk_msg *msg, u32 bytes);
128 : : int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
129 : : struct sk_msg *msg, u32 bytes);
130 : :
131 : : static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
132 : : {
133 : : WARN_ON(i == msg->sg.end && bytes);
134 : : }
135 : :
136 : : static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
137 : : {
138 : : if (psock->apply_bytes) {
139 : : if (psock->apply_bytes < bytes)
140 : : psock->apply_bytes = 0;
141 : : else
142 : : psock->apply_bytes -= bytes;
143 : : }
144 : : }
145 : :
146 : 0 : static inline u32 sk_msg_iter_dist(u32 start, u32 end)
147 : : {
148 : 0 : return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
149 : : }
150 : :
151 : : #define sk_msg_iter_var_prev(var) \
152 : : do { \
153 : : if (var == 0) \
154 : : var = NR_MSG_FRAG_IDS - 1; \
155 : : else \
156 : : var--; \
157 : : } while (0)
158 : :
159 : : #define sk_msg_iter_var_next(var) \
160 : : do { \
161 : : var++; \
162 : : if (var == NR_MSG_FRAG_IDS) \
163 : : var = 0; \
164 : : } while (0)
165 : :
166 : : #define sk_msg_iter_prev(msg, which) \
167 : : sk_msg_iter_var_prev(msg->sg.which)
168 : :
169 : : #define sk_msg_iter_next(msg, which) \
170 : : sk_msg_iter_var_next(msg->sg.which)
171 : :
172 : : static inline void sk_msg_clear_meta(struct sk_msg *msg)
173 : : {
174 : : memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
175 : : }
176 : :
177 : : static inline void sk_msg_init(struct sk_msg *msg)
178 : : {
179 : : BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
180 : : memset(msg, 0, sizeof(*msg));
181 : : sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
182 : : }
183 : :
184 : : static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
185 : : int which, u32 size)
186 : : {
187 : : dst->sg.data[which] = src->sg.data[which];
188 : : dst->sg.data[which].length = size;
189 : : dst->sg.size += size;
190 : : src->sg.data[which].length -= size;
191 : : src->sg.data[which].offset += size;
192 : : }
193 : :
194 : : static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
195 : : {
196 : : memcpy(dst, src, sizeof(*src));
197 : : sk_msg_init(src);
198 : : }
199 : :
200 : : static inline bool sk_msg_full(const struct sk_msg *msg)
201 : : {
202 : : return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
203 : : }
204 : :
205 : 0 : static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
206 : : {
207 [ # # # # ]: 0 : return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
208 : : }
209 : :
210 : 0 : static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
211 : : {
212 [ # # # # : 0 : return &msg->sg.data[which];
# # # # #
# # # # #
# # ]
213 : : }
214 : :
215 : 0 : static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
216 : : {
217 [ # # # # : 0 : return msg->sg.data[which];
# # # # ]
218 : : }
219 : :
220 : : static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
221 : : {
222 : : return sg_page(sk_msg_elem(msg, which));
223 : : }
224 : :
225 : : static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
226 : : {
227 : : return msg->flags & BPF_F_INGRESS;
228 : : }
229 : :
230 : 0 : static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
231 : : {
232 : 0 : struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
233 : :
234 [ # # ]: 0 : if (test_bit(msg->sg.start, &msg->sg.copy)) {
235 : 0 : msg->data = NULL;
236 : 0 : msg->data_end = NULL;
237 : : } else {
238 : 0 : msg->data = sg_virt(sge);
239 : 0 : msg->data_end = msg->data + sge->length;
240 : : }
241 : 0 : }
242 : :
243 : : static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
244 : : u32 len, u32 offset)
245 : : {
246 : : struct scatterlist *sge;
247 : :
248 : : get_page(page);
249 : : sge = sk_msg_elem(msg, msg->sg.end);
250 : : sg_set_page(sge, page, len, offset);
251 : : sg_unmark_end(sge);
252 : :
253 : : __set_bit(msg->sg.end, &msg->sg.copy);
254 : : msg->sg.size += len;
255 : : sk_msg_iter_next(msg, end);
256 : : }
257 : :
258 : : static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
259 : : {
260 : : do {
261 : : if (copy_state)
262 : : __set_bit(i, &msg->sg.copy);
263 : : else
264 : : __clear_bit(i, &msg->sg.copy);
265 : : sk_msg_iter_var_next(i);
266 : : if (i == msg->sg.end)
267 : : break;
268 : : } while (1);
269 : : }
270 : :
271 : : static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
272 : : {
273 : : sk_msg_sg_copy(msg, start, true);
274 : : }
275 : :
276 : : static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
277 : : {
278 : : sk_msg_sg_copy(msg, start, false);
279 : : }
280 : :
281 : : static inline struct sk_psock *sk_psock(const struct sock *sk)
282 : : {
283 : : return rcu_dereference_sk_user_data(sk);
284 : : }
285 : :
286 : : static inline void sk_psock_queue_msg(struct sk_psock *psock,
287 : : struct sk_msg *msg)
288 : : {
289 : : list_add_tail(&msg->list, &psock->ingress_msg);
290 : : }
291 : :
292 : : static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
293 : : {
294 : : return psock ? list_empty(&psock->ingress_msg) : true;
295 : : }
296 : :
297 : : static inline void sk_psock_report_error(struct sk_psock *psock, int err)
298 : : {
299 : : struct sock *sk = psock->sk;
300 : :
301 : : sk->sk_err = err;
302 : : sk->sk_error_report(sk);
303 : : }
304 : :
305 : : struct sk_psock *sk_psock_init(struct sock *sk, int node);
306 : :
307 : : int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
308 : : void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
309 : : void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
310 : :
311 : : int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
312 : : struct sk_msg *msg);
313 : :
314 : : static inline struct sk_psock_link *sk_psock_init_link(void)
315 : : {
316 : : return kzalloc(sizeof(struct sk_psock_link),
317 : : GFP_ATOMIC | __GFP_NOWARN);
318 : : }
319 : :
320 : : static inline void sk_psock_free_link(struct sk_psock_link *link)
321 : : {
322 : : kfree(link);
323 : : }
324 : :
325 : : struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
326 : : #if defined(CONFIG_BPF_STREAM_PARSER)
327 : : void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
328 : : #else
329 : : static inline void sk_psock_unlink(struct sock *sk,
330 : : struct sk_psock_link *link)
331 : : {
332 : : }
333 : : #endif
334 : :
335 : : void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
336 : :
337 : : static inline void sk_psock_cork_free(struct sk_psock *psock)
338 : : {
339 : : if (psock->cork) {
340 : : sk_msg_free(psock->sk, psock->cork);
341 : : kfree(psock->cork);
342 : : psock->cork = NULL;
343 : : }
344 : : }
345 : :
346 : : static inline void sk_psock_update_proto(struct sock *sk,
347 : : struct sk_psock *psock,
348 : : struct proto *ops)
349 : : {
350 : : psock->saved_unhash = sk->sk_prot->unhash;
351 : : psock->saved_close = sk->sk_prot->close;
352 : : psock->saved_write_space = sk->sk_write_space;
353 : :
354 : : psock->sk_proto = sk->sk_prot;
355 : : sk->sk_prot = ops;
356 : : }
357 : :
358 : : static inline void sk_psock_restore_proto(struct sock *sk,
359 : : struct sk_psock *psock)
360 : : {
361 : : sk->sk_prot->unhash = psock->saved_unhash;
362 : :
363 : : if (psock->sk_proto) {
364 : : struct inet_connection_sock *icsk = inet_csk(sk);
365 : : bool has_ulp = !!icsk->icsk_ulp_data;
366 : :
367 : : if (has_ulp) {
368 : : tcp_update_ulp(sk, psock->sk_proto,
369 : : psock->saved_write_space);
370 : : } else {
371 : : sk->sk_prot = psock->sk_proto;
372 : : sk->sk_write_space = psock->saved_write_space;
373 : : }
374 : : psock->sk_proto = NULL;
375 : : } else {
376 : : sk->sk_write_space = psock->saved_write_space;
377 : : }
378 : : }
379 : :
380 : : static inline void sk_psock_set_state(struct sk_psock *psock,
381 : : enum sk_psock_state_bits bit)
382 : : {
383 : : set_bit(bit, &psock->state);
384 : : }
385 : :
386 : : static inline void sk_psock_clear_state(struct sk_psock *psock,
387 : : enum sk_psock_state_bits bit)
388 : : {
389 : : clear_bit(bit, &psock->state);
390 : : }
391 : :
392 : : static inline bool sk_psock_test_state(const struct sk_psock *psock,
393 : : enum sk_psock_state_bits bit)
394 : : {
395 : : return test_bit(bit, &psock->state);
396 : : }
397 : :
398 : : static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
399 : : {
400 : : struct sk_psock *psock;
401 : :
402 : : rcu_read_lock();
403 : : psock = sk_psock(sk);
404 : : if (psock) {
405 : : if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
406 : : psock = ERR_PTR(-EBUSY);
407 : : goto out;
408 : : }
409 : :
410 : : if (!refcount_inc_not_zero(&psock->refcnt))
411 : : psock = ERR_PTR(-EBUSY);
412 : : }
413 : : out:
414 : : rcu_read_unlock();
415 : : return psock;
416 : : }
417 : :
418 : : static inline struct sk_psock *sk_psock_get(struct sock *sk)
419 : : {
420 : : struct sk_psock *psock;
421 : :
422 : : rcu_read_lock();
423 : : psock = sk_psock(sk);
424 : : if (psock && !refcount_inc_not_zero(&psock->refcnt))
425 : : psock = NULL;
426 : : rcu_read_unlock();
427 : : return psock;
428 : : }
429 : :
430 : : void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
431 : : void sk_psock_destroy(struct rcu_head *rcu);
432 : : void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
433 : :
434 : : static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
435 : : {
436 : : if (refcount_dec_and_test(&psock->refcnt))
437 : : sk_psock_drop(sk, psock);
438 : : }
439 : :
440 : : static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
441 : : {
442 : : if (psock->parser.enabled)
443 : : psock->parser.saved_data_ready(sk);
444 : : else
445 : : sk->sk_data_ready(sk);
446 : : }
447 : :
448 : : static inline void psock_set_prog(struct bpf_prog **pprog,
449 : : struct bpf_prog *prog)
450 : : {
451 : : prog = xchg(pprog, prog);
452 : : if (prog)
453 : : bpf_prog_put(prog);
454 : : }
455 : :
456 : : static inline void psock_progs_drop(struct sk_psock_progs *progs)
457 : : {
458 : : psock_set_prog(&progs->msg_parser, NULL);
459 : : psock_set_prog(&progs->skb_parser, NULL);
460 : : psock_set_prog(&progs->skb_verdict, NULL);
461 : : }
462 : :
463 : : #endif /* _LINUX_SKMSG_H */
|