C++实现RC4流加密算法

Category: 代码片段 | Tags: c++, rc4 | Source: Markdown ----------> Back to Wiki

基于 C++11 标准实现的 RC4 流加密算法，使用了不少 C++11 的新语法。

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using namespace std;

/// RC4 stream cipher state: a 256-entry byte table (S-box) plus two running
/// 8-bit indices.
///
/// Typical use: call reset() with a key, then crypt() one or more times. The
/// indices persist between crypt() calls, so the keystream continues where the
/// previous call stopped until reset() is called again. Encryption and
/// decryption are the same XOR operation, so calling reset() with the same key
/// and running crypt() over the ciphertext yields the plaintext again.
///
/// Instances can be neither copied nor moved.
class RC4 {
    public:
        RC4() = default;
        ~RC4() = default;

        RC4(const RC4&) = delete;
        RC4(RC4&&) = delete;
        RC4& operator=(const RC4&) = delete;
        RC4& operator=(RC4&&) = delete;

        /// (Re)initialises the cipher state from a key.
        /// @param key  Key bytes.
        /// @param len  Number of leading bytes of @p key to use; expected to
        ///             be non-zero and no larger than key.size().
        void reset(const std::vector<std::uint8_t> &key, std::size_t len);

        /// XORs the first @p len bytes of @p in with the keystream and stores
        /// the result in @p out (which may be the same vector as @p in).
        /// @param in   Input bytes.
        /// @param out  Destination for the transformed bytes.
        /// @param len  Number of bytes to process.
        void crypt(const std::vector<std::uint8_t> &in, std::vector<std::uint8_t> &out, std::size_t len);

    private:
        // Value-initialised so that touching the object before reset() reads
        // defined (all-zero) state instead of indeterminate memory. That state
        // provides no encryption; reset() must still be called with a key.
        std::array<std::uint8_t, 256> sbox{};
        std::uint8_t idx1 = 0;
        std::uint8_t idx2 = 0;
};

/// Re-keys the cipher: rebuilds the S-box from @p key and rewinds both stream
/// indices to zero.
///
/// @param key  Key bytes.
/// @param len  Number of leading bytes of @p key to use; must satisfy
///             1 <= len <= key.size().
/// @throws std::invalid_argument if @p len is zero or exceeds key.size().
void RC4::reset(const std::vector<std::uint8_t> &key, std::size_t len) {
    // len == 0 would make `i % len` a division by zero, and a len larger than
    // the vector would index past its end.
    if (len == 0 || len > key.size())
        throw std::invalid_argument("RC4::reset: len must be in [1, key.size()]");

    // Start from the identity permutation 0, 1, ..., 255.
    std::iota(sbox.begin(), sbox.end(), static_cast<std::uint8_t>(0));
    idx1 = 0;
    idx2 = 0;

    // Key scheduling: j is 8 bits wide, so the running sum wraps modulo 256
    // and is always a valid S-box index.
    std::uint8_t j = 0;
    for (std::size_t i = 0; i < sbox.size(); ++i) {
        j = static_cast<std::uint8_t>(j + sbox[i] + key[i % len]);
        std::swap(sbox[i], sbox[j]);
    }
}

/// Encrypts or decrypts @p len bytes by XORing them with the next @p len
/// keystream bytes. The S-box and both indices are advanced, so consecutive
/// calls continue the same keystream.
///
/// @param in   Source bytes; at least @p len must be available.
/// @param out  Destination. May be the same vector as @p in for in-place
///             operation. If it holds fewer than @p len elements it is grown
///             to @p len; it is never shrunk.
/// @param len  Number of bytes to process.
/// @throws std::out_of_range if @p len exceeds in.size().
void RC4::crypt(const std::vector<std::uint8_t> &in, std::vector<std::uint8_t> &out, std::size_t len) {
    if (len > in.size())
        throw std::out_of_range("RC4::crypt: len exceeds input size");

    // Growing cannot invalidate `in` in the aliased (in-place) case: there
    // out.size() == in.size() >= len, so no resize happens.
    if (out.size() < len)
        out.resize(len);

    for (std::size_t i = 0; i < len; ++i) {
        // All index arithmetic is done in uint8_t, i.e. modulo 256.
        ++idx1;
        idx2 = static_cast<std::uint8_t>(idx2 + sbox[idx1]);
        std::swap(sbox[idx1], sbox[idx2]);

        const auto k = static_cast<std::uint8_t>(sbox[idx1] + sbox[idx2]);
        out[i] = static_cast<std::uint8_t>(in[i] ^ sbox[k]);
    }
}

/// Command-line demo: encrypts <message> with <key>, prints the ciphertext as
/// space-separated hex bytes, then re-keys and runs the cipher again to show
/// that the original message comes back.
///
/// @return 0 on success; 1 if the argument count is wrong or the key is empty.
int main(int argc, char **argv) {
    if (argc != 3) {
        std::cout << "Usage: " << argv[0] << " <key> <message>" << std::endl;
        return 1;
    }

    const std::string keyArg(argv[1]);
    const std::string msgArg(argv[2]);

    // reset() indexes the key with `i % len`, so a zero-length key must be
    // rejected here rather than passed through.
    if (keyArg.empty()) {
        std::cerr << "Error: <key> must not be empty" << std::endl;
        return 1;
    }

    const std::vector<std::uint8_t> key(keyArg.begin(), keyArg.end());
    std::vector<std::uint8_t> msg(msgArg.begin(), msgArg.end());

    // uint8_t is streamed as a character, so this prints the bytes as text.
    const auto printText = [](const std::vector<std::uint8_t> &bytes) {
        for (const std::uint8_t b : bytes)
            std::cout << b;
        std::cout << '\n';
    };

    RC4 rc4;

    std::cout << "message: ";
    printText(msg);

    rc4.reset(key, key.size());
    rc4.crypt(msg, msg, msg.size());
    std::cout << "encrypt: ";
    for (const std::uint8_t b : msg) {
        // Print high and low nibble separately so every byte shows as exactly
        // two hex digits (the shifts/masks promote to int, hence numeric output).
        std::cout << std::hex << (b >> 4) << (b & 0x0f) << ' ';
    }
    std::cout << '\n';

    // Same key, fresh state: the identical keystream undoes the XOR.
    rc4.reset(key, key.size());
    rc4.crypt(msg, msg, msg.size());
    std::cout << "decrypt: ";
    printText(msg);

    return 0;
}