/*******************************************************************************
  QuadraticWordEquationSolver.cc  solves a single quadratic word equation

     Algorithm from "Volker Diekert: Makanin's Algorithm, In M. Lothaire
                     (Ed.): Algebraic Combinatorics on Words, 2002."

 Copyright (C) 2006  Heiko Stamer <stamer@theory.informatik.uni-kassel.de>

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*******************************************************************************/

#include <iostream>
#include <vector>
#include <deque>
#include <map>
#include <string>
#include <algorithm>
#include <cassert>

// Alphabet of the equations: each character of Sigma is a constant,
// each character of Omega is a variable.
const std::string Sigma = "abcdef";
const std::string Omega = "xyzuvw";

// A word equation "lhs = rhs", stored as the pair (lhs, rhs).
typedef std::pair<std::string, std::string> Equation;

// Writes an equation as "<lhs> = <rhs>" and returns the stream for chaining.
std::ostream& operator <<
	(std::ostream &out, const Equation &e)
{
	return out << e.first << " = " << e.second;
}

// Compact textual key "<lhs>_<rhs>" of an equation; it serves as the DOT
// node identifier and as the label string of a substitution on a path.
std::string eqstr
	(const Equation &e)
{
	std::string key(e.first);
	key += '_';
	key += e.second;
	return key;
}

// Returns the total number of variable occurrences (characters of Omega)
// on both sides of e.
//
// The previous implementation started every search at index 1, so a variable
// at position 0 of a side was never counted (e.g. "x = a" gave 0).  dot()
// treats equations with a count of 0 as variable-free and hides them, so the
// miscount dropped real nodes and their edges from the output.
size_t count_variables
	(const Equation &e)
{
	size_t cnt = 0;
	
	for (size_t i = 0; i < Omega.length(); i++)
	{
		cnt += static_cast<size_t>(
			std::count(e.first.begin(), e.first.end(), Omega[i]));
		cnt += static_cast<size_t>(
			std::count(e.second.begin(), e.second.end(), Omega[i]));
	}
	
	return cnt;
}

// Removes the longest common prefix and then the longest common suffix of
// the two sides of e.  An equation whose sides are already identical is
// left untouched (so a solved equation keeps its form).
void cancel
	(Equation &e)
{
	if (e.first == e.second)
		return;
	
	std::string &lhs = e.first;
	std::string &rhs = e.second;
	
	// length of the common prefix
	size_t head = 0;
	while ((head < lhs.length()) && (head < rhs.length()) &&
		(lhs[head] == rhs[head]))
		++head;
	lhs.erase(0, head);
	rhs.erase(0, head);
	
	// length of the common suffix of what remains
	size_t tail = 0;
	while ((tail < lhs.length()) && (tail < rhs.length()) &&
		(lhs[lhs.length() - 1 - tail] == rhs[rhs.length() - 1 - tail]))
		++tail;
	lhs.erase(lhs.length() - tail);
	rhs.erase(rhs.length() - tail);
}

// Applies the substitution var -> lead var, where var is the first symbol of
// the left-hand side (must be a variable) and lead is the first symbol of the
// right-hand side (must be a constant or a variable).  Rewrites the leading
// occurrence of var and the first occurrence found from index 1 on each side,
// and records the substitution in edge as (var, lead var).  If the rule does
// not apply, e and edge are left unchanged.
void solve_first
	(Equation &e, Equation &edge)
{
	const char var = e.first[0];
	if (Omega.find(var) == Omega.npos)
		return;
	
	const char lead = e.second[0];
	const bool lead_is_symbol = (Sigma.find(lead) != Sigma.npos) ||
		(Omega.find(lead) != Omega.npos);
	if (!lead_is_symbol)
		return;
	
	// find the other occurrences before any text is rewritten
	const size_t at_lhs = e.first.find(var, 1);
	const size_t at_rhs = e.second.find(var, 1);
	
	std::string image;
	image += lead;
	image += var;
	
	edge.first = var;
	edge.second = image;
	if (at_lhs != std::string::npos)
		e.first.replace(at_lhs, 1, image);
	if (at_rhs != std::string::npos)
		e.second.replace(at_rhs, 1, image);
	// the leading occurrence is rewritten last so at_lhs stays valid
	e.first.replace(0, 1, image);
}

// Mirror image of solve_first: applies var -> lead var, where var is the first
// symbol of the right-hand side (must be a variable) and lead is the first
// symbol of the left-hand side (must be a constant or a variable).  Rewrites
// the leading occurrence of var and the first occurrence found from index 1
// on each side, and records the substitution in edge as (var, lead var).
// If the rule does not apply, e and edge are left unchanged.
void solve_second
	(Equation &e, Equation &edge)
{
	const char var = e.second[0];
	if (Omega.find(var) == Omega.npos)
		return;
	
	const char lead = e.first[0];
	const bool lead_is_symbol = (Sigma.find(lead) != Sigma.npos) ||
		(Omega.find(lead) != Omega.npos);
	if (!lead_is_symbol)
		return;
	
	// find the other occurrences before any text is rewritten
	const size_t at_lhs = e.first.find(var, 1);
	const size_t at_rhs = e.second.find(var, 1);
	
	std::string image;
	image += lead;
	image += var;
	
	edge.first = var;
	edge.second = image;
	if (at_lhs != std::string::npos)
		e.first.replace(at_lhs, 1, image);
	if (at_rhs != std::string::npos)
		e.second.replace(at_rhs, 1, image);
	// the leading occurrence is rewritten last so at_rhs stays valid
	e.second.replace(0, 1, image);
}

// Applies the substitution variable -> empty word: deletes every occurrence
// of the variable on both sides of e and records (variable, "") in edge.
void solve_epsilon
	(Equation &e, Equation &edge, const char &variable)
{
	const char gone = variable;
	e.first.erase(std::remove(e.first.begin(), e.first.end(), gone),
		e.first.end());
	e.second.erase(std::remove(e.second.begin(), e.second.end(), gone),
		e.second.end());
	edge.first = gone;
	edge.second = "";
}

// Registers the result e of rewriting o.
// - e is appended to V (and echoed to stderr) if V does not contain it yet.
// - If a rule was applied (edge.first non-empty), edge is appended to the
//   label list E[(o, e)] unless that list already contains it; each newly
//   added label is logged to stderr.
void update
	(Equation &o, Equation &e, Equation edge, std::vector<Equation> &V,
	 std::map<std::pair<Equation, Equation>, std::vector<Equation> > &E)
{
	if (std::find(V.begin(), V.end(), e) == V.end())
	{
		std::cerr << e << std::endl;
		V.push_back(e);
	}
	
	// an empty variable name means no substitution happened: nothing to record
	if (edge.first.empty())
		return;
	
	// operator[] creates an empty label list for a pair seen for the first time
	std::vector<Equation> &labels = E[std::make_pair(o, e)];
	if (std::find(labels.begin(), labels.end(), edge) != labels.end())
		return;
	labels.push_back(edge);
	std::cerr << "(" << e << ") [ " << edge << " ]" << std::endl;
}

void solve
	(std::vector<Equation> &V,
	 const std::map< std::pair<Equation, Equation>, std::vector<Equation> > &E,
	 std::vector< std::map<std::string, std::string> > &PHI)
{
	std::vector< std::vector< std::pair<Equation, std::string> > > S;
	std::vector<Equation> V2;
	
	S.push_back(std::vector< std::pair<Equation, std::string> >());
	S[0].push_back(std::pair<Equation, std::string>(V[0], ""));
	V2.push_back(V[0]);
	while (!S.empty())
	{
		std::vector< std::vector< std::pair<Equation, std::string> > > S2;
		
		for (std::vector< std::vector< std::pair<Equation, std::string> > 
			>::const_iterator si = S.begin(); si != S.end(); ++si)
		{
			Equation f = (si->back()).first;
			for(std::map<std::pair<Equation, Equation>, std::vector<Equation> 
				>::const_iterator ci = E.begin(); ci != E.end(); ++ci)
			{
				if (ci->first.first == f)
				{
					Equation t = ci->first.second;
					Equation edge = (ci->second)[0];
					std::pair<Equation, std::string> p =
						std::pair<Equation, std::string>(t, eqstr(edge));
					// solution found
					if (t.first == t.second)
					{
						std::deque<std::string> Q;
						std::cerr << "Solution found at " << t << std::endl;
						for (std::vector< std::pair<Equation, std::string> 
							>::const_iterator ti = si->begin(); ti != si->end(); ++ti)
						{
							std::cerr << ti->second << std::endl;
							Q.push_front(ti->second);
							if (std::find(V2.begin(), V2.end(), ti->first) == 
								V2.end())
									V2.push_back(ti->first);
						}
						if (std::find(V2.begin(), V2.end(), t) == V2.end())
							V2.push_back(t);
						std::cerr << eqstr(edge) << std::endl;
						Q.push_front(eqstr(edge));
						
						// compute phi
						std::map<std::string, std::string> phi;
						for (size_t i = 0; i < Omega.length(); i++)
						{
							std::string idx = "";
							idx += Omega[i];
							phi[idx] = "*";
						}
						for (std::deque<std::string>::const_iterator qi =
							Q.begin(); qi != Q.end(); ++qi)
						{
							std::string st = *qi;
							std::string st0 = "", st2 = "";
							st0 += st[0], st2 += st[2];
							if (st.length() == 2)
								phi[st0] = "";
							else if ((st.length() == 4) && (st[0] == st[3]))
							{
								// constant
								if (Omega.find(st2) == Omega.npos)
								{
									phi[st0] = st2 + phi[st0];
								} // variable
								else
								{
									phi[st0] = phi[st2] + phi[st0];
								}
							}
						}
						PHI.push_back(phi);
					} // avoid loops
					else if (std::find(si->begin(), si->end(), p) == si->end())
					{
						std::vector< std::pair<Equation, std::string> > s(*si);
						s.push_back(p);
						S2.push_back(s);
					}
				}
			}
		}
		S.swap(S2);
	}
	V.swap(V2);
}

void cleanup
	(std::vector<Equation> &V, 
	 std::map<std::pair<Equation, Equation>, std::vector<Equation> > &E)
{
	size_t v;
	
	do
	{
		std::vector<Equation> V2 = std::vector<Equation>(V);
		std::map<std::pair<Equation, Equation>, std::vector<Equation> > E2 =
			std::map<std::pair<Equation, Equation>, std::vector<Equation> >(E);
		
		v = V.size();
		for (std::vector<Equation>::const_iterator ci = V2.begin();
			ci != V2.end(); ++ci)
		{
			if (ci->first != ci->second)
			{
				bool rm = true;
				for (std::map<std::pair<Equation, Equation>, 
					std::vector<Equation> >::const_iterator mi = E2.begin();
					mi != E2.end(); ++mi)
				{
					if ((*ci == mi->first.first) && 
						(mi->first.first != mi->first.second))
					{
						rm = false;
						break;
					}
				}
				
				// remove node?
				if (rm)
				{
					std::cerr << "rm node " << *ci << std::endl;
					V.erase(std::find(V.begin(), V.end(), *ci));
					for (std::map<std::pair<Equation, Equation>, 
						std::vector<Equation> >::const_iterator mi = 
						E2.begin(); mi != E2.end(); ++mi)
					{
						if (*ci == mi->first.second)
						{
							std::cerr << "rm edge " << mi->first.first <<
								" -> " << mi->first.second << std::endl;
							E.erase(mi->first);
						}
					}
				}
			}
		}
		std::cerr << "done1 " << v << " vs " << V.size() << std::endl;
		
		// remove remaining not connected edges
		for (std::map<std::pair<Equation, Equation>, 
			std::vector<Equation> >::const_iterator mi = 
			E2.begin(); mi != E2.end(); ++mi)
		{
			if ((std::find(V.begin(), V.end(), mi->first.first) == V.end())
				|| (std::find(V.begin(), V.end(), mi->first.second) == V.end()))
			{
				std::cerr << "rm edge " << mi->first.first <<
					" -> " << mi->first.second << std::endl;
				E.erase(mi->first);
			}
		}
	}
	while (v > V.size());
}

void dot
	(const std::vector<Equation> &V,
	 const std::map< std::pair<Equation, Equation>, std::vector<Equation> > &E,
	 const std::vector< std::map<std::string, std::string> > &PHI)
{
	std::vector<Equation> R;
	
	std::cout << "digraph QuadraticWordEquationSolver {" << std::endl;
	std::cout << "  /************* SOLUTIONS ************** " << std::endl;
	for (std::vector< std::map<std::string, std::string> >::const_iterator
		pi = PHI.begin(); pi != PHI.end(); ++pi)
	{
		bool neq = true;
		for (std::vector< std::map<std::string, std::string> >::const_iterator
			pi2 = PHI.begin(); pi2 != pi; ++pi2)
		{
			if (*pi == *pi2)
				neq = false;
		}
		if (neq)
		{
			std::cout << "    ";
			for (std::map<std::string, std::string>::const_iterator
				mi = pi->begin(); mi != pi->end(); ++mi)
			{
				std::cout << "phi(" << mi->first << ") = " << mi->second << " ";
			}
			std::cout << std::endl;
		}
	}
	std::cout << "   **************************************/" << std::endl;
	for (std::vector<Equation>::const_iterator ci = V.begin();
		ci != V.end(); ++ci)
	{
		Equation e = *ci;
		if (e.first == e.second)
			std::cout << "    " << eqstr(e) << " [ label = \"" << e << 
				"\" style = bold shape = box ];" << std::endl;
		else if (count_variables(e) == 0)
		{
			R.push_back(e);
		}
		else
			std::cout << "    " << eqstr(e) << " [ label = \"" << e << 
				"\" shape = box ];" << std::endl;
	}
	for (std::map<std::pair<Equation, Equation>, 
		std::vector<Equation> >::const_iterator ci = E.begin(); 
		ci != E.end(); ++ci)
	{
		Equation f = ci->first.first, t = ci->first.second;
		
		if ((std::find(R.begin(), R.end(), f) == R.end()) &&
			(std::find(R.begin(), R.end(), t) == R.end()))
		{
			for (std::vector<Equation>::const_iterator ci2 = 
				ci->second.begin(); ci2 != ci->second.end(); ++ci2)
			{
				std::cout << "    " << eqstr(f) << " -> " << eqstr(t) << 
					" [ label = \"" << *ci2 << "\" ];" << std::endl;
			}
		}
	}
	std::cout << "}" << std::endl;
}

int main
	(int argc, char **argv)
{
	std::vector<Equation> V;
	std::map<std::pair<Equation, Equation>, std::vector<Equation> > E;
	size_t v;
	
	if (argc != 3)
	{
		std::cout << "Usage: " << argv[0] << " <LHS> <RHS>" << std::endl;
		return -1;
	}
	
	//V.push_back(Equation("abxcy", "ycxba"));
	//V.push_back(Equation("zabbz", "aabxaxabbyaby"));
	V.push_back(Equation(argv[1], argv[2]));
	do
	{
		std::vector<Equation> V2;
		
		// solve_first
		for (std::vector<Equation>::const_iterator ci = V.begin();
			ci != V.end(); ++ci)
		{
			Equation o = *ci, e = *ci;
			Equation edge;
			
			cancel(e);
			solve_first(e, edge);
			cancel(e);
			
			update(o, e, edge, V2, E);
		}
		
		// solve_second
		for (std::vector<Equation>::const_iterator ci = V.begin();
			ci != V.end(); ++ci)
		{
			Equation o = *ci, e = *ci;
			Equation edge;
			
			cancel(e);
			solve_second(e, edge);
			cancel(e);
			
			update(o, e, edge, V2, E);
		}
		
		// solve_epsilon
		for (std::vector<Equation>::const_iterator ci = V.begin();
			ci != V.end(); ++ci)
		{
			for (size_t i = 0; i < Omega.length(); i++)
			{
				if ((ci->first.find(Omega[i]) != ci->first.npos) ||
					(ci->second.find(Omega[i]) != ci->second.npos))
				{
					Equation o = *ci, e = *ci;
					Equation edge;
					
					cancel(e);
					solve_epsilon(e, edge, Omega[i]);
					cancel(e);
					
					update(o, e, edge, V2, E);
				}
			}
		}
		
		// insert new nodes
		v = V.size();
		for (std::vector<Equation>::const_iterator ci = V2.begin();
			ci != V2.end(); ++ci)
		{
			if (std::find(V.begin(), V.end(), *ci) == V.end())
				V.push_back(*ci);
		}
	}
	while (V.size() > v);
	std::cerr << "|V| = " << V.size() << ", |E| = " << E.size() << std::endl;
	
	// remove superfluous nodes and edges
	cleanup(V, E);
	std::cerr << "|V| = " << V.size() << ", |E| = " << E.size() << std::endl;
	
	// breath-first search for minimal solutions
	std::vector< std::map<std::string, std::string> > PHI;
	solve(V, E, PHI);
	std::cerr << "|V| = " << V.size() << ", |E| = " << E.size() << std::endl;
	cleanup(V, E);
	std::cerr << "|V| = " << V.size() << ", |E| = " << E.size() << std::endl;
	
	// ouputs the solutions and the graph (dot format)
	dot(V, E, PHI);
	
	return 0;
}

