test_sock_addr.c 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. // SPDX-License-Identifier: GPL-2.0
  2. // Copyright (c) 2018 Facebook
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. #include <unistd.h>
  6. #include <arpa/inet.h>
  7. #include <sys/types.h>
  8. #include <sys/socket.h>
  9. #include <linux/filter.h>
  10. #include <bpf/bpf.h>
  11. #include <bpf/libbpf.h>
  12. #include "cgroup_helpers.h"
  13. #include "bpf_rlimit.h"
/* Cgroup that run_test() creates and joins; the programs are attached to it. */
#define CG_PATH "/foo"
/* Object files loaded by connect4_prog_load() and connect6_prog_load(). */
#define CONNECT4_PROG_PATH "./connect4_prog.o"
#define CONNECT6_PROG_PATH "./connect6_prog.o"
/*
 * IPv4 endpoint: SERV4_IP/SERV4_PORT is what the test requests in bind() and
 * connect(); SERV4_REWRITE_IP/SERV4_REWRITE_PORT is what the bind4 program
 * writes into the context instead.
 */
#define SERV4_IP "192.168.1.254"
#define SERV4_REWRITE_IP "127.0.0.1"
#define SERV4_PORT 4040
#define SERV4_REWRITE_PORT 4444
/* IPv6 counterparts of the four IPv4 values above. */
#define SERV6_IP "face:b00c:1234:5678::abcd"
#define SERV6_REWRITE_IP "::1"
#define SERV6_PORT 6060
#define SERV6_REWRITE_PORT 6666
/* Size of the text buffer handed to inet_ntop() in print_ip_port(). */
#define INET_NTOP_BUF 40
/*
 * Program loader: takes the expected attach type and a comment string and
 * returns a program fd, or -1 on failure (callers compare against -1).
 */
typedef int (*load_fn)(enum bpf_attach_type, const char *comment);
/* Socket address query with the getsockname()/getpeername() signature. */
typedef int (*info_fn)(int, struct sockaddr *, socklen_t *);
/*
 * One program under test.
 *
 * NOTE: run_test() fills these with positional initializers, so the field
 * order is part of the contract inside this file.
 */
struct program {
	/* Attach type used for the load and attach that must succeed. */
	enum bpf_attach_type type;
	/* Function that builds/loads the program and returns its fd. */
	load_fn loadfn;
	/* Program fd; -1 while not loaded (see close_progs_fds()). */
	int fd;
	/* Short name used in progress and error messages. */
	const char *name;
	/* Attach type for which both load and attach must be rejected. */
	enum bpf_attach_type invalid_type;
};
/* Verifier log buffer passed to bpf_load_program_xattr() by load_insns(). */
char bpf_log_buf[BPF_LOG_BUF_SIZE];
  36. static int mk_sockaddr(int domain, const char *ip, unsigned short port,
  37. struct sockaddr *addr, socklen_t addr_len)
  38. {
  39. struct sockaddr_in6 *addr6;
  40. struct sockaddr_in *addr4;
  41. if (domain != AF_INET && domain != AF_INET6) {
  42. log_err("Unsupported address family");
  43. return -1;
  44. }
  45. memset(addr, 0, addr_len);
  46. if (domain == AF_INET) {
  47. if (addr_len < sizeof(struct sockaddr_in))
  48. return -1;
  49. addr4 = (struct sockaddr_in *)addr;
  50. addr4->sin_family = domain;
  51. addr4->sin_port = htons(port);
  52. if (inet_pton(domain, ip, (void *)&addr4->sin_addr) != 1) {
  53. log_err("Invalid IPv4: %s", ip);
  54. return -1;
  55. }
  56. } else if (domain == AF_INET6) {
  57. if (addr_len < sizeof(struct sockaddr_in6))
  58. return -1;
  59. addr6 = (struct sockaddr_in6 *)addr;
  60. addr6->sin6_family = domain;
  61. addr6->sin6_port = htons(port);
  62. if (inet_pton(domain, ip, (void *)&addr6->sin6_addr) != 1) {
  63. log_err("Invalid IPv6: %s", ip);
  64. return -1;
  65. }
  66. }
  67. return 0;
  68. }
  69. static int load_insns(enum bpf_attach_type attach_type,
  70. const struct bpf_insn *insns, size_t insns_cnt,
  71. const char *comment)
  72. {
  73. struct bpf_load_program_attr load_attr;
  74. int ret;
  75. memset(&load_attr, 0, sizeof(struct bpf_load_program_attr));
  76. load_attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
  77. load_attr.expected_attach_type = attach_type;
  78. load_attr.insns = insns;
  79. load_attr.insns_cnt = insns_cnt;
  80. load_attr.license = "GPL";
  81. ret = bpf_load_program_xattr(&load_attr, bpf_log_buf, BPF_LOG_BUF_SIZE);
  82. if (ret < 0 && comment) {
  83. log_err(">>> Loading %s program error.\n"
  84. ">>> Output from verifier:\n%s\n-------\n",
  85. comment, bpf_log_buf);
  86. }
  87. return ret;
  88. }
  89. /* [1] These testing programs try to read different context fields, including
  90. * narrow loads of different sizes from user_ip4 and user_ip6, and write to
  91. * those allowed to be overridden.
  92. *
  93. * [2] BPF_LD_IMM64 & BPF_JMP_REG are used below whenever there is a need to
  94. * compare a register with unsigned 32bit integer. BPF_JMP_IMM can't be used
  95. * in such cases since it accepts only _signed_ 32bit integer as IMM
 * argument. Also note that BPF_LD_IMM64 expands to 2 instructions, which
 * matters when counting jump offsets.
  98. */
/*
 * Build and load the IPv4 bind program.
 *
 * The generated program always returns 1. Before returning, if the context
 * describes an AF_INET socket of type SOCK_DGRAM or SOCK_STREAM whose
 * user_ip4 equals SERV4_IP (checked with a 1-byte, a 2-byte and a 4-byte
 * load), it overwrites user_ip4 and user_port with SERV4_REWRITE_IP and
 * SERV4_REWRITE_PORT.
 *
 * attach_type: expected attach type, passed through to load_insns().
 * comment:     passed through to load_insns(); a load failure is logged only
 *              when this is non-NULL.
 *
 * Returns the result of load_insns(), or -1 if one of the addresses cannot
 * be parsed.
 */
static int bind4_prog_load(enum bpf_attach_type attach_type,
			   const char *comment)
{
	/* Three views of the same 4 address bytes, one per load width below. */
	union {
		uint8_t u4_addr8[4];
		uint16_t u4_addr16[2];
		uint32_t u4_addr32;
	} ip4;
	struct sockaddr_in addr4_rw;
	if (inet_pton(AF_INET, SERV4_IP, (void *)&ip4) != 1) {
		log_err("Invalid IPv4: %s", SERV4_IP);
		return -1;
	}
	/* addr4_rw holds the replacement address; its port is already htons()'d. */
	if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT,
			(struct sockaddr *)&addr4_rw, sizeof(addr4_rw)) == -1)
		return -1;
	/* See [1]. */
	/*
	 * Every failing check jumps forward to the final "return 1" pair, so
	 * the offsets must be recounted whenever an instruction is added or
	 * removed (BPF_LD_IMM64 counts as two).
	 */
	struct bpf_insn insns[] = {
		/* r6 = ctx; r7 is the scratch register for loaded fields. */
		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
		/* if (sk.family == AF_INET && */
		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, family)),
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 16),
		/* (sk.type == SOCK_DGRAM || sk.type == SOCK_STREAM) && */
		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, type)),
		/* Not DGRAM: go test STREAM; DGRAM: skip the STREAM test. */
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 1),
		BPF_JMP_A(1),
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_STREAM, 12),
		/* 1st_byte_of_user_ip4 == expected && */
		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, user_ip4)),
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[0], 10),
		/* 1st_half_of_user_ip4 == expected && */
		BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, user_ip4)),
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[0], 8),
		/* whole_user_ip4 == expected) { */
		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, user_ip4)),
		BPF_LD_IMM64(BPF_REG_8, ip4.u4_addr32), /* See [2]. */
		BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 4),
		/* user_ip4 = addr4_rw.sin_addr */
		BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_addr.s_addr),
		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
			    offsetof(struct bpf_sock_addr, user_ip4)),
		/* user_port = addr4_rw.sin_port */
		BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_port),
		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
			    offsetof(struct bpf_sock_addr, user_port)),
		/* } */
		/* return 1 */
		BPF_MOV64_IMM(BPF_REG_0, 1),
		BPF_EXIT_INSN(),
	};
	return load_insns(attach_type, insns,
			  sizeof(insns) / sizeof(struct bpf_insn), comment);
}
/*
 * Build and load the IPv6 bind program.
 *
 * The generated program always returns 1. Before returning, if the context
 * describes an AF_INET6 socket whose user_ip6 matches SERV6_IP in the three
 * places probed below (one byte, one half-word, the last word), it
 * overwrites all four words of user_ip6 and user_port with SERV6_REWRITE_IP
 * and SERV6_REWRITE_PORT. Unlike the IPv4 variant, the socket type is not
 * checked.
 *
 * attach_type: expected attach type, passed through to load_insns().
 * comment:     passed through to load_insns(); a load failure is logged only
 *              when this is non-NULL.
 *
 * Returns the result of load_insns(), or -1 if one of the addresses cannot
 * be parsed.
 */
static int bind6_prog_load(enum bpf_attach_type attach_type,
			   const char *comment)
{
	struct sockaddr_in6 addr6_rw;
	struct in6_addr ip6;
	if (inet_pton(AF_INET6, SERV6_IP, (void *)&ip6) != 1) {
		log_err("Invalid IPv6: %s", SERV6_IP);
		return -1;
	}
	/* addr6_rw holds the replacement address; its port is already htons()'d. */
	if (mk_sockaddr(AF_INET6, SERV6_REWRITE_IP, SERV6_REWRITE_PORT,
			(struct sockaddr *)&addr6_rw, sizeof(addr6_rw)) == -1)
		return -1;
	/* See [1]. */
	/*
	 * Every failing check jumps forward to the final "return 1" pair, so
	 * the offsets must be recounted whenever an instruction is added or
	 * removed (BPF_LD_IMM64 counts as two, STORE_IPV6_WORD expands to two).
	 */
	struct bpf_insn insns[] = {
		/* r6 = ctx; r7 is the scratch register for loaded fields. */
		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
		/* if (sk.family == AF_INET6 && */
		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, family)),
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18),
		/* 5th_byte_of_user_ip6 == expected && */
		BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, user_ip6[1])),
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr[4], 16),
		/* 3rd_half_of_user_ip6 == expected && */
		BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, user_ip6[1])),
		BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr16[2], 14),
		/* last_word_of_user_ip6 == expected) { */
		BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
			    offsetof(struct bpf_sock_addr, user_ip6[3])),
		BPF_LD_IMM64(BPF_REG_8, ip6.s6_addr32[3]), /* See [2]. */
		BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 10),
/* Emits two instructions: load word N of the new address, store it to ctx. */
#define STORE_IPV6_WORD(N) \
		BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_addr.s6_addr32[N]), \
		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, \
			    offsetof(struct bpf_sock_addr, user_ip6[N]))
		/* user_ip6 = addr6_rw.sin6_addr */
		STORE_IPV6_WORD(0),
		STORE_IPV6_WORD(1),
		STORE_IPV6_WORD(2),
		STORE_IPV6_WORD(3),
		/* user_port = addr6_rw.sin6_port */
		BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_port),
		BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
			    offsetof(struct bpf_sock_addr, user_port)),
		/* } */
		/* return 1 */
		BPF_MOV64_IMM(BPF_REG_0, 1),
		BPF_EXIT_INSN(),
	};
	return load_insns(attach_type, insns,
			  sizeof(insns) / sizeof(struct bpf_insn), comment);
}
  210. static int connect_prog_load_path(const char *path,
  211. enum bpf_attach_type attach_type,
  212. const char *comment)
  213. {
  214. struct bpf_prog_load_attr attr;
  215. struct bpf_object *obj;
  216. int prog_fd;
  217. memset(&attr, 0, sizeof(struct bpf_prog_load_attr));
  218. attr.file = path;
  219. attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
  220. attr.expected_attach_type = attach_type;
  221. if (bpf_prog_load_xattr(&attr, &obj, &prog_fd)) {
  222. if (comment)
  223. log_err(">>> Loading %s program at %s error.\n",
  224. comment, path);
  225. return -1;
  226. }
  227. return prog_fd;
  228. }
  229. static int connect4_prog_load(enum bpf_attach_type attach_type,
  230. const char *comment)
  231. {
  232. return connect_prog_load_path(CONNECT4_PROG_PATH, attach_type, comment);
  233. }
  234. static int connect6_prog_load(enum bpf_attach_type attach_type,
  235. const char *comment)
  236. {
  237. return connect_prog_load_path(CONNECT6_PROG_PATH, attach_type, comment);
  238. }
  239. static void print_ip_port(int sockfd, info_fn fn, const char *fmt)
  240. {
  241. char addr_buf[INET_NTOP_BUF];
  242. struct sockaddr_storage addr;
  243. struct sockaddr_in6 *addr6;
  244. struct sockaddr_in *addr4;
  245. socklen_t addr_len;
  246. unsigned short port;
  247. void *nip;
  248. addr_len = sizeof(struct sockaddr_storage);
  249. memset(&addr, 0, addr_len);
  250. if (fn(sockfd, (struct sockaddr *)&addr, (socklen_t *)&addr_len) == 0) {
  251. if (addr.ss_family == AF_INET) {
  252. addr4 = (struct sockaddr_in *)&addr;
  253. nip = (void *)&addr4->sin_addr;
  254. port = ntohs(addr4->sin_port);
  255. } else if (addr.ss_family == AF_INET6) {
  256. addr6 = (struct sockaddr_in6 *)&addr;
  257. nip = (void *)&addr6->sin6_addr;
  258. port = ntohs(addr6->sin6_port);
  259. } else {
  260. return;
  261. }
  262. const char *addr_str =
  263. inet_ntop(addr.ss_family, nip, addr_buf, INET_NTOP_BUF);
  264. printf(fmt, addr_str ? addr_str : "??", port);
  265. }
  266. }
  267. static void print_local_ip_port(int sockfd, const char *fmt)
  268. {
  269. print_ip_port(sockfd, getsockname, fmt);
  270. }
  271. static void print_remote_ip_port(int sockfd, const char *fmt)
  272. {
  273. print_ip_port(sockfd, getpeername, fmt);
  274. }
  275. static int start_server(int type, const struct sockaddr_storage *addr,
  276. socklen_t addr_len)
  277. {
  278. int fd;
  279. fd = socket(addr->ss_family, type, 0);
  280. if (fd == -1) {
  281. log_err("Failed to create server socket");
  282. goto out;
  283. }
  284. if (bind(fd, (const struct sockaddr *)addr, addr_len) == -1) {
  285. log_err("Failed to bind server socket");
  286. goto close_out;
  287. }
  288. if (type == SOCK_STREAM) {
  289. if (listen(fd, 128) == -1) {
  290. log_err("Failed to listen on server socket");
  291. goto close_out;
  292. }
  293. }
  294. print_local_ip_port(fd, "\t Actual: bind(%s, %d)\n");
  295. goto out;
  296. close_out:
  297. close(fd);
  298. fd = -1;
  299. out:
  300. return fd;
  301. }
  302. static int connect_to_server(int type, const struct sockaddr_storage *addr,
  303. socklen_t addr_len)
  304. {
  305. int domain;
  306. int fd;
  307. domain = addr->ss_family;
  308. if (domain != AF_INET && domain != AF_INET6) {
  309. log_err("Unsupported address family");
  310. return -1;
  311. }
  312. fd = socket(domain, type, 0);
  313. if (fd == -1) {
  314. log_err("Failed to creating client socket");
  315. return -1;
  316. }
  317. if (connect(fd, (const struct sockaddr *)addr, addr_len) == -1) {
  318. log_err("Fail to connect to server");
  319. goto err;
  320. }
  321. print_remote_ip_port(fd, "\t Actual: connect(%s, %d)");
  322. print_local_ip_port(fd, " from (%s, %d)\n");
  323. return 0;
  324. err:
  325. close(fd);
  326. return -1;
  327. }
  328. static void print_test_case_num(int domain, int type)
  329. {
  330. static int test_num;
  331. printf("Test case #%d (%s/%s):\n", ++test_num,
  332. (domain == AF_INET ? "IPv4" :
  333. domain == AF_INET6 ? "IPv6" :
  334. "unknown_domain"),
  335. (type == SOCK_STREAM ? "TCP" :
  336. type == SOCK_DGRAM ? "UDP" :
  337. "unknown_type"));
  338. }
  339. static int run_test_case(int domain, int type, const char *ip,
  340. unsigned short port)
  341. {
  342. struct sockaddr_storage addr;
  343. socklen_t addr_len = sizeof(addr);
  344. int servfd = -1;
  345. int err = 0;
  346. print_test_case_num(domain, type);
  347. if (mk_sockaddr(domain, ip, port, (struct sockaddr *)&addr,
  348. addr_len) == -1)
  349. return -1;
  350. printf("\tRequested: bind(%s, %d) ..\n", ip, port);
  351. servfd = start_server(type, &addr, addr_len);
  352. if (servfd == -1)
  353. goto err;
  354. printf("\tRequested: connect(%s, %d) from (*, *) ..\n", ip, port);
  355. if (connect_to_server(type, &addr, addr_len))
  356. goto err;
  357. goto out;
  358. err:
  359. err = -1;
  360. out:
  361. close(servfd);
  362. return err;
  363. }
  364. static void close_progs_fds(struct program *progs, size_t prog_cnt)
  365. {
  366. size_t i;
  367. for (i = 0; i < prog_cnt; ++i) {
  368. close(progs[i].fd);
  369. progs[i].fd = -1;
  370. }
  371. }
  372. static int load_and_attach_progs(int cgfd, struct program *progs,
  373. size_t prog_cnt)
  374. {
  375. size_t i;
  376. for (i = 0; i < prog_cnt; ++i) {
  377. printf("Load %s with invalid type (can pollute stderr) ",
  378. progs[i].name);
  379. fflush(stdout);
  380. progs[i].fd = progs[i].loadfn(progs[i].invalid_type, NULL);
  381. if (progs[i].fd != -1) {
  382. log_err("Load with invalid type accepted for %s",
  383. progs[i].name);
  384. goto err;
  385. }
  386. printf("... REJECTED\n");
  387. printf("Load %s with valid type", progs[i].name);
  388. progs[i].fd = progs[i].loadfn(progs[i].type, progs[i].name);
  389. if (progs[i].fd == -1) {
  390. log_err("Failed to load program %s", progs[i].name);
  391. goto err;
  392. }
  393. printf(" ... OK\n");
  394. printf("Attach %s with invalid type", progs[i].name);
  395. if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].invalid_type,
  396. BPF_F_ALLOW_OVERRIDE) != -1) {
  397. log_err("Attach with invalid type accepted for %s",
  398. progs[i].name);
  399. goto err;
  400. }
  401. printf(" ... REJECTED\n");
  402. printf("Attach %s with valid type", progs[i].name);
  403. if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].type,
  404. BPF_F_ALLOW_OVERRIDE) == -1) {
  405. log_err("Failed to attach program %s", progs[i].name);
  406. goto err;
  407. }
  408. printf(" ... OK\n");
  409. }
  410. return 0;
  411. err:
  412. close_progs_fds(progs, prog_cnt);
  413. return -1;
  414. }
  415. static int run_domain_test(int domain, int cgfd, struct program *progs,
  416. size_t prog_cnt, const char *ip, unsigned short port)
  417. {
  418. int err = 0;
  419. if (load_and_attach_progs(cgfd, progs, prog_cnt) == -1)
  420. goto err;
  421. if (run_test_case(domain, SOCK_STREAM, ip, port) == -1)
  422. goto err;
  423. if (run_test_case(domain, SOCK_DGRAM, ip, port) == -1)
  424. goto err;
  425. goto out;
  426. err:
  427. err = -1;
  428. out:
  429. close_progs_fds(progs, prog_cnt);
  430. return err;
  431. }
  432. static int run_test(void)
  433. {
  434. size_t inet6_prog_cnt;
  435. size_t inet_prog_cnt;
  436. int cgfd = -1;
  437. int err = 0;
  438. struct program inet6_progs[] = {
  439. {BPF_CGROUP_INET6_BIND, bind6_prog_load, -1, "bind6",
  440. BPF_CGROUP_INET4_BIND},
  441. {BPF_CGROUP_INET6_CONNECT, connect6_prog_load, -1, "connect6",
  442. BPF_CGROUP_INET4_CONNECT},
  443. };
  444. inet6_prog_cnt = sizeof(inet6_progs) / sizeof(struct program);
  445. struct program inet_progs[] = {
  446. {BPF_CGROUP_INET4_BIND, bind4_prog_load, -1, "bind4",
  447. BPF_CGROUP_INET6_BIND},
  448. {BPF_CGROUP_INET4_CONNECT, connect4_prog_load, -1, "connect4",
  449. BPF_CGROUP_INET6_CONNECT},
  450. };
  451. inet_prog_cnt = sizeof(inet_progs) / sizeof(struct program);
  452. if (setup_cgroup_environment())
  453. goto err;
  454. cgfd = create_and_get_cgroup(CG_PATH);
  455. if (!cgfd)
  456. goto err;
  457. if (join_cgroup(CG_PATH))
  458. goto err;
  459. if (run_domain_test(AF_INET, cgfd, inet_progs, inet_prog_cnt, SERV4_IP,
  460. SERV4_PORT) == -1)
  461. goto err;
  462. if (run_domain_test(AF_INET6, cgfd, inet6_progs, inet6_prog_cnt,
  463. SERV6_IP, SERV6_PORT) == -1)
  464. goto err;
  465. goto out;
  466. err:
  467. err = -1;
  468. out:
  469. close(cgfd);
  470. cleanup_cgroup_environment();
  471. printf(err ? "### FAIL\n" : "### SUCCESS\n");
  472. return err;
  473. }
  474. int main(int argc, char **argv)
  475. {
  476. if (argc < 2) {
  477. fprintf(stderr,
  478. "%s has to be run via %s.sh. Skip direct run.\n",
  479. argv[0], argv[0]);
  480. exit(0);
  481. }
  482. return run_test();
  483. }