/*
 * Decompiled with CFR 0.152.
 */
package org.lsmp.djep.djep;

import java.io.PrintStream;
import java.util.Enumeration;
import java.util.Hashtable;
import org.lsmp.djep.djep.DJep;
import org.lsmp.djep.djep.DVariable;
import org.lsmp.djep.djep.DiffRulesI;
import org.lsmp.djep.djep.PartialDerivative;
import org.lsmp.djep.xjep.DeepCopyVisitor;
import org.lsmp.djep.xjep.NodeFactory;
import org.lsmp.djep.xjep.TreeUtils;
import org.lsmp.djep.xjep.XVariable;
import org.nfunk.jep.ASTConstant;
import org.nfunk.jep.ASTFunNode;
import org.nfunk.jep.ASTVarNode;
import org.nfunk.jep.Node;
import org.nfunk.jep.ParseException;
import org.nfunk.jep.ParserVisitor;
import org.nfunk.jep.Variable;
import org.nfunk.jep.function.PostfixMathCommandI;

/**
 * Visitor that walks an expression tree and builds the tree of its derivative
 * with respect to a named variable.
 *
 * <p>Function nodes are differentiated either by the function object itself
 * (when it implements {@link DiffRulesI}) or by a rule registered under the
 * function's name via {@link #addDiffRule(DiffRulesI)}. Variable and constant
 * nodes are handled directly by this class.
 *
 * <p>The visitor keeps the {@link DJep} instance, node factory and tree
 * utilities of the current call in instance fields, which are overwritten by
 * every call to {@link #differentiate(Node, String, DJep)}.
 */
public class DifferentiationVisitor
extends DeepCopyVisitor {
    private static final boolean DEBUG = false;
    /** DJep instance passed to the most recent {@link #differentiate} call. */
    private DJep localDJep;
    /** DJep instance supplied at construction; its function table is used by {@link #printDiffRules}. */
    private DJep globalDJep;
    /** Node factory taken from {@link #localDJep}; used to build result nodes. */
    private NodeFactory nf;
    /** Tree utilities taken from {@link #localDJep}; supplies the ZERO and ONE constants. */
    private TreeUtils tu;
    /** Registered differentiation rules, keyed by the name each rule reports. */
    Hashtable diffRules = new Hashtable();

    /**
     * Creates a visitor bound to the given DJep instance.
     *
     * @param dJep instance whose function table is listed by {@link #printDiffRules}
     */
    public DifferentiationVisitor(DJep dJep) {
        this.globalDJep = dJep;
    }

    /**
     * Registers a differentiation rule under the name returned by
     * {@code diffRulesI.getName()}, replacing any rule previously stored
     * under that name.
     *
     * @param diffRulesI the rule to register
     */
    void addDiffRule(DiffRulesI diffRulesI) {
        this.diffRules.put(diffRulesI.getName(), diffRulesI);
    }

    /**
     * Looks up the rule registered under a function name.
     *
     * @param string the function name
     * @return the registered rule, or {@code null} if none is registered
     */
    DiffRulesI getDiffRule(String string) {
        return (DiffRulesI)this.diffRules.get(string);
    }

    /** Prints all differentiation rules to {@code System.out}. */
    public void printDiffRules() {
        this.printDiffRules(System.out);
    }

    /**
     * Prints one line per function in the global function table, showing its
     * differentiation rule or a note that none is specified, followed by one
     * line per registered rule whose name is not in the function table.
     *
     * @param printStream destination for the listing
     */
    public void printDiffRules(PrintStream printStream) {
        printStream.println("Standard Functions and their derivatives");
        Enumeration functionNames = this.globalDJep.getFunctionTable().keys();
        while (functionNames.hasMoreElements()) {
            String name = (String)functionNames.nextElement();
            // Typed local: the decompiler had merged this with an unrelated Object slot.
            PostfixMathCommandI function = (PostfixMathCommandI)this.globalDJep.getFunctionTable().get(name);
            DiffRulesI rule = (DiffRulesI)this.diffRules.get(name);
            if (rule == null) {
                printStream.print(name + " No diff rules specified (" + function.getNumberOfParameters() + " arguments).");
            } else {
                printStream.print(rule.toString());
            }
            printStream.println();
        }
        Enumeration ruleNames = this.diffRules.keys();
        while (ruleNames.hasMoreElements()) {
            String name = (String)ruleNames.nextElement();
            DiffRulesI rule = (DiffRulesI)this.diffRules.get(name);
            if (this.globalDJep.getFunctionTable().containsKey((Object)name)) {
                continue;
            }
            printStream.print(rule.toString());
            printStream.println("\tnot in JEP function list");
        }
    }

    /**
     * Differentiates an expression tree with respect to a variable name.
     *
     * <p>Arguments are validated before any visitor state is touched, so a
     * rejected call leaves the state of a previous call intact.
     *
     * @param node root of the expression to differentiate
     * @param string name of the variable to differentiate with respect to
     * @param dJep instance supplying the node factory and tree utilities, and
     *        passed on to the differentiation rules
     * @return the root node produced by visiting {@code node}
     * @throws ParseException if a visited node cannot be differentiated
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public Node differentiate(Node node, String string, DJep dJep) throws ParseException, IllegalArgumentException {
        if (node == null) {
            throw new IllegalArgumentException("node parameter is null");
        }
        if (string == null) {
            throw new IllegalArgumentException("var parameter is null");
        }
        if (dJep == null) {
            throw new IllegalArgumentException("djep parameter is null");
        }
        this.localDJep = dJep;
        this.nf = dJep.getNodeFactory();
        this.tu = dJep.getTreeUtils();
        return (Node)node.jjtAccept((ParserVisitor)this, (Object)string);
    }

    /**
     * Differentiates a function node.
     *
     * <p>The children are visited first; the original and visited children are
     * then handed to the function's own rule if the function object implements
     * {@link DiffRulesI}, otherwise to the rule registered under the
     * function's name.
     *
     * @param aSTFunNode the function node
     * @param object the name of the differentiation variable (a {@code String})
     * @return the node returned by the applicable rule
     * @throws ParseException if no rule applies, or a child or rule fails
     */
    public Object visit(ASTFunNode aSTFunNode, Object object) throws ParseException {
        String functionName = aSTFunNode.getName();
        Node[] children = TreeUtils.getChildrenAsArray((Node)aSTFunNode);
        Node[] visitedChildren = this.acceptChildrenAsArray((Node)aSTFunNode, object);
        // A rule carried by the function object takes precedence over the rule table.
        Object function = aSTFunNode.getPFMC();
        if (function instanceof DiffRulesI) {
            return ((DiffRulesI)function).differentiate(aSTFunNode, (String)object, children, visitedChildren, this.localDJep);
        }
        DiffRulesI rule = (DiffRulesI)this.diffRules.get(functionName);
        if (rule != null) {
            return rule.differentiate(aSTFunNode, (String)object, children, visitedChildren, this.localDJep);
        }
        throw new ParseException("Sorry I don't know how to differentiate " + aSTFunNode + "\n");
    }

    /**
     * Tells whether a variable is treated as constant during differentiation.
     *
     * @param xVariable the variable to test
     * @return {@code true} if the variable has no equation, or its equation is
     *         an {@link ASTConstant}
     */
    public boolean isConstantVar(XVariable xVariable) {
        if (!xVariable.hasEquation()) {
            return true;
        }
        Node equation = xVariable.getEquation();
        return equation instanceof ASTConstant;
    }

    /**
     * Differentiates a variable node.
     *
     * <p>For a {@link DVariable}: returns the constant ONE when its name equals
     * the differentiation variable, the constant ZERO when
     * {@link #isConstantVar} holds, and otherwise looks up its derivative.
     * For a {@link PartialDerivative}: returns ZERO when
     * {@link #isConstantVar} holds, otherwise looks up the further derivative
     * through its root variable. When the derivative's equation is a lone
     * variable or constant node, a node for that variable or value is
     * returned instead of a node referring to the derivative itself.
     *
     * @param aSTVarNode the variable node
     * @param object the name of the differentiation variable (a {@code String})
     * @return a newly built constant or variable node
     * @throws ParseException if the variable is neither a {@code DVariable}
     *         nor a {@code PartialDerivative}
     */
    public Object visit(ASTVarNode aSTVarNode, Object object) throws ParseException {
        String diffName = (String)object;
        // Held as the general type so that an unsupported variable reaches the
        // ParseException below instead of failing on an early cast.
        Variable variable = aSTVarNode.getVar();
        PartialDerivative derivative;
        if (variable instanceof DVariable) {
            DVariable dVariable = (DVariable)variable;
            if (diffName.equals(dVariable.getName())) {
                return this.nf.buildConstantNode(this.tu.getONE());
            }
            if (this.isConstantVar(dVariable)) {
                return this.nf.buildConstantNode(this.tu.getZERO());
            }
            derivative = dVariable.findDerivative(diffName, this.localDJep);
        } else if (variable instanceof PartialDerivative) {
            PartialDerivative partial = (PartialDerivative)variable;
            if (this.isConstantVar(partial)) {
                return this.nf.buildConstantNode(this.tu.getZERO());
            }
            DVariable root = partial.getRoot();
            derivative = root.findDerivative(partial, diffName, this.localDJep);
        } else {
            throw new ParseException("Encountered non differentiable variable");
        }
        Node equation = derivative.getEquation();
        if (equation instanceof ASTVarNode) {
            return this.nf.buildVariableNode(((ASTVarNode)equation).getVar());
        }
        if (equation instanceof ASTConstant) {
            return this.nf.buildConstantNode(((ASTConstant)equation).getValue());
        }
        return this.nf.buildVariableNode(derivative);
    }

    /**
     * Differentiates a constant node.
     *
     * @param aSTConstant the constant node (its value is not consulted)
     * @param object the name of the differentiation variable (unused)
     * @return a new constant node holding ZERO
     * @throws ParseException declared by the visitor interface
     */
    public Object visit(ASTConstant aSTConstant, Object object) throws ParseException {
        return this.nf.buildConstantNode(this.tu.getZERO());
    }
}

