#include <cassert>
#include <cmath>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
using namespace std;

#include <NTL/ZZ.h>
using namespace NTL;

#include <rsa.h>


// Encrypts one block: c = m^e mod n, using the members e and n.
//   m  in:  the plaintext as text; NTL's conv() parses it into a ZZ
//           (it expects a decimal integer).
//   c  out: the ciphertext, written as the decimal text of the result.
// NOTE(review): NTL's modular routines assume the base lies in 0..n-1;
// confirm that callers never pass an m that is >= n.
void rsaE::operator()( string & c, string const & m ) const
{
  ZZ plain;
  ZZ cipher;

  conv(plain, m.c_str());
  PowerMod(cipher, plain, e, n);

  stringstream text;
  text << cipher;
  c = text.str();
}



void rsaD::init()
{
  n = p*q;

  ZZ n2 = (p-1)*(q-1);

  InvMod(d,e,n2);

//  cout << "d=" << d << endl;
}


// Decrypts one block: m = c^d mod n, using the members d and n.
//   c  in:  the ciphertext as text; NTL's conv() parses it into a ZZ
//           (it expects a decimal integer).
//   m  out: the recovered value, written as decimal text.
void rsaD::operator()( string & m, string const & c ) const
{
  ZZ cipher;
  ZZ plain;

  conv(cipher, c.c_str());
  PowerMod(plain, cipher, d, n);

  stringstream text;
  text << plain;
  m = text.str();
}


// Chooses the exponent e: keeps drawing random candidates until one is
// coprime to n2.  n2 must already be set when this is called.
void rsaGenKey::finde()
{
  do
  {
    RandomBits(e,nbits);

    // presumably forces the candidate to be exactly nbits long - verify
    // against the lpg implementation
    lpg.ensurenbits(e,nbits);
  }
  while (GCD(e,n2) != 1);
}

// Builds a key from caller-supplied starting values.
//   _nbits  size parameter passed on to the prime search and to finde().
//   _p, _q  text parsed by NTL's conv() into the members p and q.
// Each value is then handed to lpg.findprime_sequentially(), which updates it
// in place (presumably to a prime found by searching from that value - verify
// against lpg), and genkey() finishes the job.
rsaGenKey::rsaGenKey( uintc _nbits, string const & _p, string const & _q )
  : nbits(_nbits)
{
  conv(p, _p.c_str());
  conv(q, _q.c_str());

  lpg.findprime_sequentially(p, nbits);
  lpg.findprime_sequentially(q, nbits);

  genkey();
}

// Finishes key generation once p and q are set:
//   1. n = p*q and n2 = (p-1)*(q-1)
//   2. finde() picks e, blocksizecalc() sets blocksize
//   3. the key material is written to two files in the working directory:
//        encrypt.txt - n, e and blocksize
//        decrypt.txt - p, q and e
// If either file cannot be opened or written, a message is printed on stderr;
// the in-memory key is unaffected.
void rsaGenKey::genkey()
{
  n= p*q;
  n2 = (p-1)*(q-1);

  finde();

  blocksizecalc();

  //print();

  {
  ofstream f1("encrypt.txt");
  f1 << "encrypt=true" << endl;
  f1 << "n=" << n << endl;
  f1 << "e=" << e << endl;
  f1 << "blocksize=" << blocksize << endl;
  // endl has flushed every line, so a failed open or a failed write is
  // visible in the stream state by now.
  if (!f1)
    std::cerr << "rsaGenKey::genkey: could not write encrypt.txt" << std::endl;
  }

  {
  // NOTE: this file holds the primes p and q in plain text.
  ofstream f2("decrypt.txt");
  f2 << "decrypt=true" << endl;
  f2 << "p=" << p << endl;
  f2 << "q=" << q << endl;
  f2 << "e=" << e << endl;
  if (!f2)
    std::cerr << "rsaGenKey::genkey: could not write decrypt.txt" << std::endl;
  }
}


// Builds a key from two randomly chosen primes.
//   _nbits  size parameter passed to lpg.findprime_randomly() for p and q;
//           must be greater than 10 (checked with assert).
// genkey() then derives the remaining members and writes the key files.
rsaGenKey::rsaGenKey(uintc _nbits)
  : nbits(_nbits)
{
  assert(nbits>10);  // Really want large primes!

  lpg.findprime_randomly(p,nbits);

  // RSA needs two distinct primes: if q came out equal to p, n would be p^2
  // and (p-1)*(q-1) would not be its totient, so draw q again.
  do
  {
    lpg.findprime_randomly(q,nbits);
  }
  while (q == p);

  genkey();
}

// Debug dump: prints p, q, e, n, n2, nbits and blocksize to stdout, one per
// line, each formatted by the SHOW macro.
void rsaGenKey::print() const
{
  cout << SHOW(p) << endl
       << SHOW(q) << endl
       << SHOW(e) << endl
       << SHOW(n) << endl
       << SHOW(n2) << endl
       << SHOW(nbits) << endl
       << SHOW(blocksize) << endl;
}

void rsaGenKey::blocksizecalc()
{
  double x = nbits;
  double y;
//  y = 2.0*x*std::log10(2.0);
  y = x*std::log10(2.0)-1.0;
  y = std::floor(y);
  assert(y>1.0);
  blocksize = (uint)y;
}



