/*****************************************************************************
 *
 * Copyright (c) 2008-14, Joachim Fellmuth, Holger Gross, Florian Greiner, 
 * Bettina Hünnemeyer, Paula Herber, Verena Klös, Timm Liebrenz, 
 * Tobias Pfeffer, Marcel Pockrandt, Rolf Schröder
 * Technische Universitaet Berlin, Software Engineering for Embedded
 * Systems Group, Ernst-Reuter-Platz 7, 10587 Berlin, Germany.
 * All rights reserved.
 * 
 * This file is part of STATE (SystemC to Timed Automata Transformation Engine).
 * 
 * STATE is free software: you can redistribute it and/or modify it under
 * the terms of the GNU General Public License as published by the Free
 * Software Foundation, either version 3 of the License, or (at your
 * option) any later version.
 * 
 * STATE is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with STATE.  If not, see <http://www.gnu.org/licenses/>.
 *
 *
 *  Please report any problems or bugs to: state@pes.tu-berlin.de
 *
 ****************************************************************************/

package de.tub.pes.sc2uppaal.optimization;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import de.tub.pes.sc2uppaal.engine.Engine;
import de.tub.pes.sc2uppaal.tamodel.Constants;
import de.tub.pes.sc2uppaal.tamodel.TAFunction;
import de.tub.pes.sc2uppaal.tamodel.TALocation;
import de.tub.pes.sc2uppaal.tamodel.TAModel;
import de.tub.pes.sc2uppaal.tamodel.TATemplate;
import de.tub.pes.sc2uppaal.tamodel.TATransition;
import de.tub.pes.sc2uppaal.tamodel.TAVariable;
import de.tub.pes.sc2uppaal.tamodel.expressions.TAFunctionCallExpression;
import de.tub.pes.sc2uppaal.tamodel.expressions.TARecvExpression;
import de.tub.pes.sc2uppaal.tamodel.expressions.TASendExpression;
import de.tub.pes.sc2uppaal.tamodel.expressions.TASendRecvExpression;
import de.tub.pes.syscir.sc_model.SCFunction;
import de.tub.pes.syscir.sc_model.SCParameter;
import de.tub.pes.syscir.sc_model.expressions.BinaryExpression;
import de.tub.pes.syscir.sc_model.expressions.Expression;
import de.tub.pes.syscir.sc_model.expressions.FunctionCallExpression;
import de.tub.pes.syscir.sc_model.expressions.SCVariableExpression;

/**
 * 
 * @author Pfeffer
 * 
 */
public class OmtOptimizer implements Optimizer {

	private static Logger logger = LogManager.getLogger(OmtOptimizer.class
			.getName());

	private static enum Modes {
		TEMPLATEREPLACE, CALLREPLACE
	};

	/**
	 * Mode used to optimize function templates.
	 */
	private static final Modes mode = Modes.CALLREPLACE;

	/**
	 * Function template optimization method, which replaces the contents of any
	 * function template with a simple two-locationed template, invoking the
	 * original function.
	 * 
	 * @param ta
	 * @param oscfs
	 */
	private void replaceTemplates(TAModel ta, List<SCFunction> oscfs) {
		// look for optimized templates
		for (SCFunction scf : oscfs) {
			TATemplate t = ta.getTemplate(scf.getSCClass().getName()
					+ Constants.PREFIX_DELIMITER + scf.getName());
			// reset locations and transitions
			TALocation initLoc = t.getInitLocation();
			t.setLocations(new LinkedList<TALocation>());
			t.setInitLocation(initLoc);
			TALocation calcLoc = t.createUrgentLocation();
			// create calculation transitions
			TATransition callTrans = new TATransition(initLoc, calcLoc,
					new TARecvExpression(null, t.getCtrlParameter().getName()));
			TATransition returnTrans = new TATransition(calcLoc, initLoc,
					new TASendExpression(null, t.getCtrlParameter().getName()));
			// get parameters for function calls
			LinkedList<Expression> paramExpressions = new LinkedList<Expression>();
			for (SCParameter param : scf.getParameters()) {
				param.getVar().setName(
						scf.getName() + Constants.PREFIX_DELIMITER
								+ Constants.LOCAL_FUNCTION_PARAM_KEYWORD
								+ Constants.PREFIX_DELIMITER
								+ param.getVar().getName());
				paramExpressions.add(new SCVariableExpression(null, param
						.getVar()));
			}
			// add function calls to return transition
			FunctionCallExpression fce = new FunctionCallExpression(null, scf,
					paramExpressions);
			if ("void".equals(scf.getReturnType())) {
				// void function
				returnTrans.addUpdateExpression(fce);
			} else {
				// result has to be stored in global variable
				TAFunctionCallExpression tafce = new TAFunctionCallExpression(
						null, scf, new LinkedList<Expression>(), null);
				BinaryExpression be = new BinaryExpression(null, tafce, "=",
						fce);
				returnTrans.addUpdateExpression(be);
			}
			// add calculation transitions
			t.addTransitions(callTrans, returnTrans);
		}
	}

	/**
	 * Function call optimization method, which replaces function calling
	 * transition with invocation of their original function.
	 * 
	 * @param ta
	 * @param oscfs
	 */
	private void replaceCalls(TAModel ta, List<SCFunction> oscfs) {
		// get all optimized template calling channels
		LinkedList<TAVariable> ctrlParams = new LinkedList<TAVariable>();
		for (SCFunction scf : oscfs) {
			TATemplate t = ta.getTemplate(scf.getSCClass().getName()
					+ Constants.PREFIX_DELIMITER + scf.getName());
			ctrlParams.add(t.getCtrlParameter());
		}
		// lookup calling templates
		for (TATemplate t : ta.getTemplates().values()) {
			HashMap<TATransition, TATransition> callReplacements = new HashMap<TATransition, TATransition>();
			List<TAVariable> calledCtrlChannels = t
					.getCalledCtrlChannelsFromParameters();
			for (TAVariable ctrlParam : ctrlParams) {
				if (calledCtrlChannels.contains(ctrlParam)) {
					// look for calling transitions
					for (TATransition trans : t.getTransitions()) {
						TASendRecvExpression sre = trans.getSync();
						if (sre != null && sre.contains(ctrlParam)
								&& sre.isSend()) {
							// check if sending transition is build correctly
							TALocation startLoc = trans.getStart();
							TALocation endLoc = trans.getEnd();
							List<TATransition> nextSteps = t
									.getOutgoingTransitions(endLoc);
							if (nextSteps.isEmpty() || nextSteps.size() > 1) {
								logger.info(
										"Found function calling transition with multiple results: {}. Continuing.",
										trans);
								continue;
							}
							TATransition nextStep = nextSteps.get(0);
							TATransition newTrans = new TATransition(startLoc,
									nextStep.getEnd());
							List<Expression> update = nextStep.getUpdate();
							if (update.isEmpty()) {
								logger.info(
										"Found function calling transtition with no update: {}. Continuing.",
										trans);
								continue;
							}
							Expression lastUpdate = update
									.get(update.size() - 1);
							// check if a function is called (void or other
							// case)
							if (lastUpdate instanceof FunctionCallExpression
									|| (lastUpdate instanceof BinaryExpression && ((BinaryExpression) lastUpdate)
											.getRight() instanceof FunctionCallExpression)) {
								newTrans.addUpdateExpression(lastUpdate);
							}
							callReplacements.put(trans, newTrans);
						}
					}
				}
			}
			// set new transitions, remove unnecessary locations
			for (Entry<TATransition, TATransition> replacement : callReplacements
					.entrySet()) {
				t.getLocations().remove(replacement.getKey().getEnd());
				t.addTransition(replacement.getValue());
			}
		}
	}

	/**
	 * Orders a list of TAFunctions by their dependencies.
	 * 
	 * @param tafs
	 * @return
	 */
	private List<TAFunction> orderByDependency(List<TAFunction> tafs) {
		assert tafs != null;
		// define dependency comparator
		Comparator<TAFunction> depsCompare = new Comparator<TAFunction>() {

			@Override
			public int compare(TAFunction o1, TAFunction o2) {
				// check for dependency
				if (o1.getCalledNames().contains(o2.getName())) {
					return 1;
				} else if (o2.getCalledNames().contains(o1.getName())) {
					return -1;
				}
				// else return zero
				return 0;
			}

		};
		Collections.sort(tafs, depsCompare);
		return tafs;
	}

	@Override
	public void run(TAModel ta) {
		if (!Engine.ALWAYS_USE_MEM_MODEL) {
			logger.error("This optimizer can be used currently only of the memory model is always on [see help for the corresponding flag].");
			return;
		}
		if (ta == null)
			throw new NullPointerException();

		// check all templates for possible function optimization
		ArrayList<SCFunction> scfs = new ArrayList<SCFunction>();
		for (TATemplate t : ta.getTemplates().values()) {
			SCFunction scf = t.getFunction();
			if (scf != null && !scf.getConsumesTime()
					&& !containsPEQorEventNotification(scf)) {
				scfs.add(scf);
			}
		}
		// check functions for calls to non-optimizable functions
		ArrayList<SCFunction> oscfs = new ArrayList<SCFunction>();
		for (SCFunction scf : scfs) {
			boolean optimizable = true;
			for (FunctionCallExpression fce : scf.getFunctionCalls()) {
				if (!scfs.contains(fce.getFunction())) {
					optimizable = false;
					break;
				}
			}
			if (optimizable) {
				oscfs.add(scf);
			}
		}
		// translate all optimizable functions to native uppaal code
		ArrayList<TAFunction> tafs = new ArrayList<TAFunction>();
		for (SCFunction scf : oscfs) {
			tafs.add(TAFunction.fromSCFunction(scf));
		}
		// append native functions to uppaal declaration
		ta.addFunctions(orderByDependency(tafs));

		// there are generally two ways to go from here. The first would be
		// template replacement. In this case, function templates should be
		// replaced by a simple call of their function, depending on its return
		// type. Every optimized template will hold two transitions afterwards.
		// The other way would be function call replacement. After the
		// translation of SCFunctions, all template calling transition will
		// already be replaced by function calls. All thats left to do is the
		// removal of the calling transitions.
		switch (mode) {
		case TEMPLATEREPLACE:
			logger.info("Using template replacement method.");
			replaceTemplates(ta, oscfs);
			break;
		case CALLREPLACE:
			logger.info("Using call replacement method.");
			replaceCalls(ta, oscfs);
			break;
		default:
			throw new AssertionError();
		}

		logger.info("Optimized {} functions.", oscfs.size());
	}

	private boolean containsPEQorEventNotification(SCFunction f) {
		return f.getEventNotifications().size() > 0;
	}

	private boolean isSocketFunctionCall(FunctionCallExpression fce) {
		SCFunction f = fce.getFunction();
		return isSocketFunction(f);
	}

	private boolean isSocketFunction(SCFunction f) {
		String name = f.getName();
		if (name.equals("nb_transport_fw") || name.equals("nb_transport_bw")
				|| name.equals("b_transport")) {
			return true;
		}
		return false;
	}
}
