aes-neonbs-glue.c

/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/cbc.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

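/*
 * The routines below are implemented in the accompanying NEON assembly
 * (aes-neonbs-core.S) and process multiple blocks per call.
 */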
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

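/*
 * CBC encryption is strictly sequential, so it is handled by an ordinary
 * AES cipher transform; only decryption uses the bit-sliced NEON code.
 */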
struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*enc_tfm;
};

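/*
 * XTS keeps a plain AES cipher around to encrypt the IV into the
 * initial tweak.
 */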
struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*tweak_tfm;
};

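/*
 * Expand the AES key and convert the round keys to the bit-sliced layout
 * expected by the NEON routines.
 */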
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

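/*
 * Common ECB walk loop: process as many full blocks per NEON call as the
 * walk provides, rounded down to a multiple of the walk stride (eight
 * blocks) unless this is the final chunk.
 */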
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);

		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

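/*
 * The CBC setkey programs both the bit-sliced key used for decryption and
 * the fallback cipher used for encryption.
 */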
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return crypto_cipher_setkey(ctx->enc_tfm, in_key, key_len);
}

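/* Encryption runs one block at a time through the generic CBC helper. */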
static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_cipher_encrypt_one(ctx->enc_tfm, dst, src);
}

static int cbc_encrypt(struct skcipher_request *req)
{
	return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
}

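/*
 * Decryption of consecutive CBC blocks is independent, so the bit-sliced
 * NEON routine can process up to eight blocks per call.
 */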
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);

		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int cbc_init(struct crypto_tfm *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

	ctx->enc_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->enc_tfm))
		return PTR_ERR(ctx->enc_tfm);

	return 0;
}

static void cbc_exit(struct crypto_tfm *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

	crypto_free_cipher(ctx->enc_tfm);
}

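/*
 * CTR mode: when the request ends in a partial block, the assembly writes
 * the final keystream block into 'buf', and the leftover bytes are XORed
 * with it here.
 */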
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			if (dst != src)
				memcpy(dst, src, walk.total % AES_BLOCK_SIZE);
			crypto_xor(dst, final, walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

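/*
 * XTS takes a double-length key: the first half keys the bit-sliced data
 * cipher, the second half the tweak cipher.
 */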
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_tfm *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		return PTR_ERR(ctx->tweak_tfm);

	return 0;
}

static void xts_exit(struct crypto_tfm *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
}

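/*
 * Encrypt the IV with the tweak cipher to form the initial tweak, then let
 * the NEON code carry the tweak forward across the walk.
 */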
static int __xts_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv);

		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_decrypt);
}

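/*
 * All four algorithms are registered as internal (CRYPTO_ALG_INTERNAL) and
 * are exposed to users via the simd wrappers created in aes_init().
 */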
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
	.base.cra_init		= cbc_init,
	.base.cra_exit		= cbc_exit,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
	.base.cra_init		= xts_init,
	.base.cra_exit		= xts_exit,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

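/*
 * Register the internal algorithms and wrap each one in a simd skcipher,
 * which defers to cryptd in contexts where the NEON unit may not be used.
 */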
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);