aes-neonbs-glue.c

/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/cbc.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
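
/*
 * Note: the NEON code consumes the round keys in a bit-sliced representation,
 * which takes considerably more space than the regular AES key schedule;
 * hence the size of the rk[] buffer below.
 */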
struct aesbs_ctx {
        int rounds;
        u8 rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
        struct aesbs_ctx key;
        struct crypto_cipher *enc_tfm;
};

struct aesbs_xts_ctx {
        struct aesbs_ctx key;
        struct crypto_cipher *tweak_tfm;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = crypto_aes_expand_key(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}
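
/*
 * Walk the request and, for all but the final step, round the block count
 * down to a multiple of walk.stride / AES_BLOCK_SIZE (8 blocks, given the
 * .walksize used below) so the bit-sliced NEON code operates on full
 * 8-block bundles wherever possible.
 */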
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = crypto_aes_expand_key(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return crypto_cipher_setkey(ctx->enc_tfm, in_key, key_len);
}
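
/*
 * CBC encryption cannot be parallelised (each block depends on the previous
 * ciphertext block), so it is handled one block at a time using the plain
 * AES cipher allocated in cbc_init(). Decryption has no such dependency and
 * uses the bit-sliced NEON implementation below.
 */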
static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_cipher_encrypt_one(ctx->enc_tfm, dst, src);
}

static int cbc_encrypt(struct skcipher_request *req)
{
        return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
}

static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int cbc_init(struct crypto_tfm *tfm)
{
        struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

        ctx->enc_tfm = crypto_alloc_cipher("aes", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->enc_tfm);
}

static void cbc_exit(struct crypto_tfm *tfm)
{
        struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

        crypto_free_cipher(ctx->enc_tfm);
}
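
/*
 * CTR mode: full blocks are handled by the NEON code. If the request ends in
 * a partial block, the keystream for that block is returned in buf[] via the
 * 'final' argument and XORed into the remaining bytes here.
 */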
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        final = NULL;
                }

                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        crypto_xor_cpy(dst, src, final,
                                       walk.total % AES_BLOCK_SIZE);

                        err = skcipher_walk_done(&walk, 0);
                        break;
                }
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
        if (err)
                return err;

        return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_tfm *tfm)
{
        struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

        ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_tfm *tfm)
{
        struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

        crypto_free_cipher(ctx->tweak_tfm);
}
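
/*
 * The initial XTS tweak is generated by encrypting the IV with the separate
 * tweak cipher; the NEON code then carries the tweak forward in walk.iv as
 * it processes the data blocks.
 */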
static int __xts_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

        kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
                   ctx->key.rounds, blocks, walk.iv);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, aesbs_xts_decrypt);
}

static struct skcipher_alg aes_algs[] = { {
        .base.cra_name = "__ecb(aes)",
        .base.cra_driver_name = "__ecb-aes-neonbs",
        .base.cra_priority = 250,
        .base.cra_blocksize = AES_BLOCK_SIZE,
        .base.cra_ctxsize = sizeof(struct aesbs_ctx),
        .base.cra_module = THIS_MODULE,
        .base.cra_flags = CRYPTO_ALG_INTERNAL,

        .min_keysize = AES_MIN_KEY_SIZE,
        .max_keysize = AES_MAX_KEY_SIZE,
        .walksize = 8 * AES_BLOCK_SIZE,
        .setkey = aesbs_setkey,
        .encrypt = ecb_encrypt,
        .decrypt = ecb_decrypt,
}, {
        .base.cra_name = "__cbc(aes)",
        .base.cra_driver_name = "__cbc-aes-neonbs",
        .base.cra_priority = 250,
        .base.cra_blocksize = AES_BLOCK_SIZE,
        .base.cra_ctxsize = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module = THIS_MODULE,
        .base.cra_flags = CRYPTO_ALG_INTERNAL,
        .base.cra_init = cbc_init,
        .base.cra_exit = cbc_exit,

        .min_keysize = AES_MIN_KEY_SIZE,
        .max_keysize = AES_MAX_KEY_SIZE,
        .walksize = 8 * AES_BLOCK_SIZE,
        .ivsize = AES_BLOCK_SIZE,
        .setkey = aesbs_cbc_setkey,
        .encrypt = cbc_encrypt,
        .decrypt = cbc_decrypt,
}, {
        .base.cra_name = "__ctr(aes)",
        .base.cra_driver_name = "__ctr-aes-neonbs",
        .base.cra_priority = 250,
        .base.cra_blocksize = 1,
        .base.cra_ctxsize = sizeof(struct aesbs_ctx),
        .base.cra_module = THIS_MODULE,
        .base.cra_flags = CRYPTO_ALG_INTERNAL,

        .min_keysize = AES_MIN_KEY_SIZE,
        .max_keysize = AES_MAX_KEY_SIZE,
        .chunksize = AES_BLOCK_SIZE,
        .walksize = 8 * AES_BLOCK_SIZE,
        .ivsize = AES_BLOCK_SIZE,
        .setkey = aesbs_setkey,
        .encrypt = ctr_encrypt,
        .decrypt = ctr_encrypt,
}, {
        .base.cra_name = "__xts(aes)",
        .base.cra_driver_name = "__xts-aes-neonbs",
        .base.cra_priority = 250,
        .base.cra_blocksize = AES_BLOCK_SIZE,
        .base.cra_ctxsize = sizeof(struct aesbs_xts_ctx),
        .base.cra_module = THIS_MODULE,
        .base.cra_flags = CRYPTO_ALG_INTERNAL,
        .base.cra_init = xts_init,
        .base.cra_exit = xts_exit,

        .min_keysize = 2 * AES_MIN_KEY_SIZE,
        .max_keysize = 2 * AES_MAX_KEY_SIZE,
        .walksize = 8 * AES_BLOCK_SIZE,
        .ivsize = AES_BLOCK_SIZE,
        .setkey = aesbs_xts_setkey,
        .encrypt = xts_encrypt,
        .decrypt = xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
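
/*
 * The algorithms above are registered as CRYPTO_ALG_INTERNAL, so they are
 * not directly reachable by users. aes_init() wraps each of them in a simd
 * skcipher helper (dropping the "__" prefix), so the bit-sliced NEON code
 * only runs in contexts where SIMD use is permitted.
 */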
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        if (!(elf_hwcap & HWCAP_NEON))
                return -ENODEV;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }
        return 0;

unregister_simds:
        aes_exit();
        return err;
}

late_initcall(aes_init);
module_exit(aes_exit);