Branch data Line data Source code
1 : : /* SPDX-License-Identifier: GPL-2.0 */
2 : : /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3 : :
4 : : #ifndef _LINUX_SKMSG_H
5 : : #define _LINUX_SKMSG_H
6 : :
7 : : #include <linux/bpf.h>
8 : : #include <linux/filter.h>
9 : : #include <linux/scatterlist.h>
10 : : #include <linux/skbuff.h>
11 : :
12 : : #include <net/sock.h>
13 : : #include <net/tcp.h>
14 : : #include <net/strparser.h>
15 : :
16 : : #define MAX_MSG_FRAGS MAX_SKB_FRAGS
17 : : #define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
18 : :
19 : : enum __sk_action {
20 : : __SK_DROP = 0,
21 : : __SK_PASS,
22 : : __SK_REDIRECT,
23 : : __SK_NONE,
24 : : };
25 : :
26 : : struct sk_msg_sg {
27 : : u32 start;
28 : : u32 curr;
29 : : u32 end;
30 : : u32 size;
31 : : u32 copybreak;
32 : : bool copy[MAX_MSG_FRAGS];
33 : : /* The extra two elements:
34 : : * 1) used for chaining the front and sections when the list becomes
35 : : * partitioned (e.g. end < start). The crypto APIs require the
36 : : * chaining;
37 : : * 2) to chain tailer SG entries after the message.
38 : : */
39 : : struct scatterlist data[MAX_MSG_FRAGS + 2];
40 : : };
41 : :
42 : : /* UAPI in filter.c depends on struct sk_msg_sg being first element. */
43 : : struct sk_msg {
44 : : struct sk_msg_sg sg;
45 : : void *data;
46 : : void *data_end;
47 : : u32 apply_bytes;
48 : : u32 cork_bytes;
49 : : u32 flags;
50 : : struct sk_buff *skb;
51 : : struct sock *sk_redir;
52 : : struct sock *sk;
53 : : struct list_head list;
54 : : };
55 : :
56 : : struct sk_psock_progs {
57 : : struct bpf_prog *msg_parser;
58 : : struct bpf_prog *skb_parser;
59 : : struct bpf_prog *skb_verdict;
60 : : };
61 : :
62 : : enum sk_psock_state_bits {
63 : : SK_PSOCK_TX_ENABLED,
64 : : };
65 : :
66 : : struct sk_psock_link {
67 : : struct list_head list;
68 : : struct bpf_map *map;
69 : : void *link_raw;
70 : : };
71 : :
72 : : struct sk_psock_parser {
73 : : struct strparser strp;
74 : : bool enabled;
75 : : void (*saved_data_ready)(struct sock *sk);
76 : : };
77 : :
78 : : struct sk_psock_work_state {
79 : : struct sk_buff *skb;
80 : : u32 len;
81 : : u32 off;
82 : : };
83 : :
84 : : struct sk_psock {
85 : : struct sock *sk;
86 : : struct sock *sk_redir;
87 : : u32 apply_bytes;
88 : : u32 cork_bytes;
89 : : u32 eval;
90 : : struct sk_msg *cork;
91 : : struct sk_psock_progs progs;
92 : : struct sk_psock_parser parser;
93 : : struct sk_buff_head ingress_skb;
94 : : struct list_head ingress_msg;
95 : : unsigned long state;
96 : : struct list_head link;
97 : : spinlock_t link_lock;
98 : : refcount_t refcnt;
99 : : void (*saved_unhash)(struct sock *sk);
100 : : void (*saved_close)(struct sock *sk, long timeout);
101 : : void (*saved_write_space)(struct sock *sk);
102 : : struct proto *sk_proto;
103 : : struct sk_psock_work_state work_state;
104 : : struct work_struct work;
105 : : union {
106 : : struct rcu_head rcu;
107 : : struct work_struct gc;
108 : : };
109 : : };
110 : :
111 : : int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
112 : : int elem_first_coalesce);
113 : : int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
114 : : u32 off, u32 len);
115 : : void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
116 : : int sk_msg_free(struct sock *sk, struct sk_msg *msg);
117 : : int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
118 : : void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
119 : : void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
120 : : u32 bytes);
121 : :
122 : : void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
123 : : void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
124 : :
125 : : int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
126 : : struct sk_msg *msg, u32 bytes);
127 : : int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
128 : : struct sk_msg *msg, u32 bytes);
129 : :
130 : : static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
131 : : {
132 : : WARN_ON(i == msg->sg.end && bytes);
133 : : }
134 : :
135 : : static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
136 : : {
137 : : if (psock->apply_bytes) {
138 : : if (psock->apply_bytes < bytes)
139 : : psock->apply_bytes = 0;
140 : : else
141 : : psock->apply_bytes -= bytes;
142 : : }
143 : : }
144 : :
145 : : static inline u32 sk_msg_iter_dist(u32 start, u32 end)
146 : : {
147 : 0 : return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
148 : : }
149 : :
150 : : #define sk_msg_iter_var_prev(var) \
151 : : do { \
152 : : if (var == 0) \
153 : : var = NR_MSG_FRAG_IDS - 1; \
154 : : else \
155 : : var--; \
156 : : } while (0)
157 : :
158 : : #define sk_msg_iter_var_next(var) \
159 : : do { \
160 : : var++; \
161 : : if (var == NR_MSG_FRAG_IDS) \
162 : : var = 0; \
163 : : } while (0)
164 : :
165 : : #define sk_msg_iter_prev(msg, which) \
166 : : sk_msg_iter_var_prev(msg->sg.which)
167 : :
168 : : #define sk_msg_iter_next(msg, which) \
169 : : sk_msg_iter_var_next(msg->sg.which)
170 : :
171 : : static inline void sk_msg_clear_meta(struct sk_msg *msg)
172 : : {
173 : : memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
174 : : }
175 : :
176 : : static inline void sk_msg_init(struct sk_msg *msg)
177 : : {
178 : : BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
179 : : memset(msg, 0, sizeof(*msg));
180 : : sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
181 : : }
182 : :
183 : : static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
184 : : int which, u32 size)
185 : : {
186 : : dst->sg.data[which] = src->sg.data[which];
187 : : dst->sg.data[which].length = size;
188 : : dst->sg.size += size;
189 : : src->sg.size -= size;
190 : : src->sg.data[which].length -= size;
191 : : src->sg.data[which].offset += size;
192 : : }
193 : :
194 : : static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
195 : : {
196 : : memcpy(dst, src, sizeof(*src));
197 : : sk_msg_init(src);
198 : : }
199 : :
200 : : static inline bool sk_msg_full(const struct sk_msg *msg)
201 : : {
202 : : return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
203 : : }
204 : :
205 : 0 : static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
206 : : {
207 : 0 : return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
208 : : }
209 : :
210 : 0 : static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
211 : : {
212 : 0 : return &msg->sg.data[which];
213 : : }
214 : :
215 : 0 : static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
216 : : {
217 : 0 : return msg->sg.data[which];
218 : : }
219 : :
220 : : static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
221 : : {
222 : : return sg_page(sk_msg_elem(msg, which));
223 : : }
224 : :
225 : : static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
226 : : {
227 : : return msg->flags & BPF_F_INGRESS;
228 : : }
229 : :
230 : 0 : static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
231 : : {
232 : 0 : struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
233 : :
234 : 0 : if (msg->sg.copy[msg->sg.start]) {
235 : 0 : msg->data = NULL;
236 : 0 : msg->data_end = NULL;
237 : : } else {
238 : 0 : msg->data = sg_virt(sge);
239 : 0 : msg->data_end = msg->data + sge->length;
240 : : }
241 : 0 : }
242 : :
243 : : static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
244 : : u32 len, u32 offset)
245 : : {
246 : : struct scatterlist *sge;
247 : :
248 : : get_page(page);
249 : : sge = sk_msg_elem(msg, msg->sg.end);
250 : : sg_set_page(sge, page, len, offset);
251 : : sg_unmark_end(sge);
252 : :
253 : : msg->sg.copy[msg->sg.end] = true;
254 : : msg->sg.size += len;
255 : : sk_msg_iter_next(msg, end);
256 : : }
257 : :
258 : : static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
259 : : {
260 : : do {
261 : : msg->sg.copy[i] = copy_state;
262 : : sk_msg_iter_var_next(i);
263 : : if (i == msg->sg.end)
264 : : break;
265 : : } while (1);
266 : : }
267 : :
268 : : static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
269 : : {
270 : : sk_msg_sg_copy(msg, start, true);
271 : : }
272 : :
273 : : static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
274 : : {
275 : : sk_msg_sg_copy(msg, start, false);
276 : : }
277 : :
278 : : static inline struct sk_psock *sk_psock(const struct sock *sk)
279 : : {
280 : : return rcu_dereference_sk_user_data(sk);
281 : : }
282 : :
283 : : static inline void sk_psock_queue_msg(struct sk_psock *psock,
284 : : struct sk_msg *msg)
285 : : {
286 : : list_add_tail(&msg->list, &psock->ingress_msg);
287 : : }
288 : :
289 : : static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
290 : : {
291 : : return psock ? list_empty(&psock->ingress_msg) : true;
292 : : }
293 : :
294 : : static inline void sk_psock_report_error(struct sk_psock *psock, int err)
295 : : {
296 : : struct sock *sk = psock->sk;
297 : :
298 : : sk->sk_err = err;
299 : : sk->sk_error_report(sk);
300 : : }
301 : :
302 : : struct sk_psock *sk_psock_init(struct sock *sk, int node);
303 : :
304 : : int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
305 : : void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
306 : : void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
307 : :
308 : : int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
309 : : struct sk_msg *msg);
310 : :
311 : : static inline struct sk_psock_link *sk_psock_init_link(void)
312 : : {
313 : : return kzalloc(sizeof(struct sk_psock_link),
314 : : GFP_ATOMIC | __GFP_NOWARN);
315 : : }
316 : :
317 : : static inline void sk_psock_free_link(struct sk_psock_link *link)
318 : : {
319 : : kfree(link);
320 : : }
321 : :
322 : : struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
323 : : #if defined(CONFIG_BPF_STREAM_PARSER)
324 : : void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
325 : : #else
326 : : static inline void sk_psock_unlink(struct sock *sk,
327 : : struct sk_psock_link *link)
328 : : {
329 : : }
330 : : #endif
331 : :
332 : : void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
333 : :
334 : : static inline void sk_psock_cork_free(struct sk_psock *psock)
335 : : {
336 : : if (psock->cork) {
337 : : sk_msg_free(psock->sk, psock->cork);
338 : : kfree(psock->cork);
339 : : psock->cork = NULL;
340 : : }
341 : : }
342 : :
343 : : static inline void sk_psock_update_proto(struct sock *sk,
344 : : struct sk_psock *psock,
345 : : struct proto *ops)
346 : : {
347 : : psock->saved_unhash = sk->sk_prot->unhash;
348 : : psock->saved_close = sk->sk_prot->close;
349 : : psock->saved_write_space = sk->sk_write_space;
350 : :
351 : : psock->sk_proto = sk->sk_prot;
352 : : sk->sk_prot = ops;
353 : : }
354 : :
355 : : static inline void sk_psock_restore_proto(struct sock *sk,
356 : : struct sk_psock *psock)
357 : : {
358 : : sk->sk_prot->unhash = psock->saved_unhash;
359 : :
360 : : if (psock->sk_proto) {
361 : : struct inet_connection_sock *icsk = inet_csk(sk);
362 : : bool has_ulp = !!icsk->icsk_ulp_data;
363 : :
364 : : if (has_ulp) {
365 : : tcp_update_ulp(sk, psock->sk_proto,
366 : : psock->saved_write_space);
367 : : } else {
368 : : sk->sk_prot = psock->sk_proto;
369 : : sk->sk_write_space = psock->saved_write_space;
370 : : }
371 : : psock->sk_proto = NULL;
372 : : } else {
373 : : sk->sk_write_space = psock->saved_write_space;
374 : : }
375 : : }
376 : :
377 : : static inline void sk_psock_set_state(struct sk_psock *psock,
378 : : enum sk_psock_state_bits bit)
379 : : {
380 : : set_bit(bit, &psock->state);
381 : : }
382 : :
383 : : static inline void sk_psock_clear_state(struct sk_psock *psock,
384 : : enum sk_psock_state_bits bit)
385 : : {
386 : : clear_bit(bit, &psock->state);
387 : : }
388 : :
389 : : static inline bool sk_psock_test_state(const struct sk_psock *psock,
390 : : enum sk_psock_state_bits bit)
391 : : {
392 : : return test_bit(bit, &psock->state);
393 : : }
394 : :
395 : : static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
396 : : {
397 : : struct sk_psock *psock;
398 : :
399 : : rcu_read_lock();
400 : : psock = sk_psock(sk);
401 : : if (psock) {
402 : : if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
403 : : psock = ERR_PTR(-EBUSY);
404 : : goto out;
405 : : }
406 : :
407 : : if (!refcount_inc_not_zero(&psock->refcnt))
408 : : psock = ERR_PTR(-EBUSY);
409 : : }
410 : : out:
411 : : rcu_read_unlock();
412 : : return psock;
413 : : }
414 : :
415 : : static inline struct sk_psock *sk_psock_get(struct sock *sk)
416 : : {
417 : : struct sk_psock *psock;
418 : :
419 : : rcu_read_lock();
420 : : psock = sk_psock(sk);
421 : : if (psock && !refcount_inc_not_zero(&psock->refcnt))
422 : : psock = NULL;
423 : : rcu_read_unlock();
424 : : return psock;
425 : : }
426 : :
427 : : void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
428 : : void sk_psock_destroy(struct rcu_head *rcu);
429 : : void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
430 : :
431 : : static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
432 : : {
433 : : if (refcount_dec_and_test(&psock->refcnt))
434 : : sk_psock_drop(sk, psock);
435 : : }
436 : :
437 : : static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
438 : : {
439 : : if (psock->parser.enabled)
440 : : psock->parser.saved_data_ready(sk);
441 : : else
442 : : sk->sk_data_ready(sk);
443 : : }
444 : :
445 : : static inline void psock_set_prog(struct bpf_prog **pprog,
446 : : struct bpf_prog *prog)
447 : : {
448 : : prog = xchg(pprog, prog);
449 : : if (prog)
450 : : bpf_prog_put(prog);
451 : : }
452 : :
453 : : static inline int psock_replace_prog(struct bpf_prog **pprog,
454 : : struct bpf_prog *prog,
455 : : struct bpf_prog *old)
456 : : {
457 : : if (cmpxchg(pprog, old, prog) != old)
458 : : return -ENOENT;
459 : :
460 : : if (old)
461 : : bpf_prog_put(old);
462 : :
463 : : return 0;
464 : : }
465 : :
466 : : static inline void psock_progs_drop(struct sk_psock_progs *progs)
467 : : {
468 : : psock_set_prog(&progs->msg_parser, NULL);
469 : : psock_set_prog(&progs->skb_parser, NULL);
470 : : psock_set_prog(&progs->skb_verdict, NULL);
471 : : }
472 : :
473 : : int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
474 : :
475 : : static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
476 : : {
477 : : if (!psock)
478 : : return false;
479 : : return psock->parser.enabled;
480 : : }
481 : : #endif /* _LINUX_SKMSG_H */
|