/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops;

import java.util.ArrayList;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.Transform;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Hop for reorganization operations ({@link Types.ReOrgOp}): transpose ({@code TRANS}),
 * diagonal vector/matrix conversion ({@code DIAG}), reverse ({@code REV}),
 * reshape ({@code RESHAPE}) and sort ({@code SORT}).
 *
 * <p>Inputs consumed per operation:
 * <ul>
 *   <li>{@code TRANS}, {@code DIAG}, {@code REV}: one input, the data.</li>
 *   <li>{@code RESHAPE}: five inputs; index 0 is the data, index 1 the target number of
 *       rows and index 2 the target number of columns.</li>
 *   <li>{@code SORT}: four inputs; index 0 is the data, index 1 the sort key ("by") and
 *       index 3 the index-return flag. NOTE(review): index 2 is assumed to be the
 *       decreasing flag, following DML {@code order(target, by, decreasing, index.return)}.</li>
 * </ul>
 */
public class ReorgOp
extends MultiThreadedHop {
    /** When true, disables the Spark sort rewrite in {@link #constructLops()}. */
    public static boolean FORCE_DIST_SORT_INDEXES = false;

    /** Number of inputs of a RESHAPE hop. */
    private static final int RESHAPE_NUM_INPUTS = 5;
    /** Number of inputs of a SORT hop (data, by, decreasing, index.return). */
    private static final int SORT_NUM_INPUTS = 4;
    /** Input position of the sort key of a SORT hop. */
    private static final int SORT_IX_BY = 1;
    /** Input position of the index-return flag of a SORT hop. */
    private static final int SORT_IX_IXRET = 3;

    private Types.ReOrgOp _op;

    /** Bare constructor used only by {@link #clone()}. */
    private ReorgOp() {
    }

    /**
     * Creates a single-input reorg hop and registers it as a parent of {@code inp}.
     *
     * @param l   hop name
     * @param dt  output data type
     * @param vt  output value type
     * @param o   reorg operation
     * @param inp the single input hop
     */
    public ReorgOp(String l, Types.DataType dt, Types.ValueType vt, Types.ReOrgOp o, Hop inp) {
        super(l, dt, vt);
        this._op = o;
        this.getInput().add(0, inp);
        inp.getParent().add(this);
        this.refreshSizeInformation();
    }

    /**
     * Creates a multi-input reorg hop (e.g. RESHAPE, SORT) and registers it as a parent
     * of every input, preserving the given input order.
     *
     * @param l   hop name
     * @param dt  output data type
     * @param vt  output value type
     * @param o   reorg operation
     * @param inp the ordered input hops
     */
    public ReorgOp(String l, Types.DataType dt, Types.ValueType vt, Types.ReOrgOp o, ArrayList<Hop> inp) {
        super(l, dt, vt);
        this._op = o;
        for (int i = 0; i < inp.size(); ++i) {
            Hop in = inp.get(i);
            this.getInput().add(i, in);
            in.getParent().add(this);
        }
        this.refreshSizeInformation();
    }

    /**
     * Validates the number of inputs for the configured operation.
     *
     * @throws HopsException if the arity does not match or the operation is unsupported
     */
    @Override
    public void checkArity() {
        int sz = this._input.size();
        switch (this._op) {
            case TRANS:
            case DIAG:
            case REV: {
                HopsException.check(sz == 1, this, "should have arity 1 for op %s but has arity %d", new Object[]{this._op, sz});
                break;
            }
            case RESHAPE: {
                HopsException.check(sz == RESHAPE_NUM_INPUTS, this, "should have arity 5 for op %s but has arity %d", new Object[]{this._op, sz});
                break;
            }
            case SORT: {
                // SORT consumes exactly four inputs (see constructLops); it previously shared
                // RESHAPE's arity of 5, which can never hold for a well-formed sort.
                HopsException.check(sz == SORT_NUM_INPUTS, this, "should have arity 4 for op %s but has arity %d", new Object[]{this._op, sz});
                break;
            }
            default: {
                throw new HopsException("Unsupported lops construction for operation type '" + this._op + "'.");
            }
        }
    }

    /** @return the reorg operation of this hop */
    public Types.ReOrgOp getOp() {
        return this._op;
    }

    @Override
    public String getOpString() {
        return "r(" + this._op.toString() + ")";
    }

    /**
     * Reports whether this hop may run on the GPU. Only TRANS (except a 1x1 output or a
     * transpose directly over another transpose) and RESHAPE qualify.
     */
    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        switch (this._op) {
            case TRANS: {
                if (this.getDim1() == 1L && this.getDim2() == 1L) {
                    return false;
                }
                // t(t(X)) is collapsed during lop construction, so it is not offloaded.
                Hop in = this.getInput().get(0);
                return !(in instanceof ReorgOp) || ((ReorgOp)in).getOp() != Types.ReOrgOp.TRANS;
            }
            case RESHAPE: {
                return true;
            }
            case DIAG:
            case REV:
            case SORT: {
                return false;
            }
        }
        throw new RuntimeException("Unsupported operator:" + this._op.name());
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return this._op == Types.ReOrgOp.TRANS || this._op == Types.ReOrgOp.SORT;
    }

    /**
     * Builds (once) and caches the lop for this hop.
     *
     * @return the constructed or cached lop
     * @throws HopsException for unsupported operations
     */
    @Override
    public Lop constructLops() {
        if (this.getLops() != null) {
            return this.getLops();
        }
        Types.ExecType et = this.optFindExecType();
        switch (this._op) {
            case TRANS: {
                Lop lin = this.getInput().get(0).constructLops();
                if (lin instanceof Transform && ((Transform)lin).getOp() == Types.ReOrgOp.TRANS) {
                    // t(t(X)): reuse the inner transpose's input instead of adding a lop.
                    this.setLops(lin.getInputs().get(0));
                    break;
                }
                if (this.getDim1() == 1L && this.getDim2() == 1L) {
                    // Transpose of a 1x1 value is a no-op.
                    this.setLops(lin);
                    break;
                }
                int k = OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
                Transform transform1 = new Transform(lin, this._op, this.getDataType(), this.getValueType(), et, k);
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                this.setLops(transform1);
                break;
            }
            case DIAG:
            case REV: {
                Transform transform1 = new Transform(this.getInput().get(0).constructLops(), this._op, this.getDataType(), this.getValueType(), et);
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                this.setLops(transform1);
                break;
            }
            case RESHAPE: {
                Lop[] linputs = new Lop[RESHAPE_NUM_INPUTS];
                for (int i = 0; i < RESHAPE_NUM_INPUTS; ++i) {
                    linputs[i] = this.getInput().get(i).constructLops();
                }
                this._outputEmptyBlocks = et == Types.ExecType.SPARK && !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
                Transform transform1 = new Transform(linputs, this._op, this.getDataType(), this.getValueType(), this._outputEmptyBlocks, et);
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                this.setLops(transform1);
                break;
            }
            case SORT: {
                Lop[] linputs = new Lop[SORT_NUM_INPUTS];
                for (int i = 0; i < SORT_NUM_INPUTS; ++i) {
                    linputs[i] = this.getInput().get(i).constructLops();
                }
                // The sort key lives at input 1; input 2 is a scalar flag, so testing it
                // would let the Spark rewrite through even for multi-column (matrix) keys.
                Hop by = this.getInput().get(SORT_IX_BY);
                Transform transform1;
                if (et == Types.ExecType.SPARK) {
                    boolean sortRewrite = !FORCE_DIST_SORT_INDEXES && this.isSortSPRewriteApplicable() && by.getDataType().isScalar();
                    transform1 = new Transform(linputs, Types.ReOrgOp.SORT, this.getDataType(), this.getValueType(), et, sortRewrite, 1);
                } else {
                    int k = OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
                    transform1 = new Transform(linputs, Types.ReOrgOp.SORT, this.getDataType(), this.getValueType(), et, false, k);
                }
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                this.setLops(transform1);
                break;
            }
            default: {
                throw new HopsException("Unsupported lops construction for operation type '" + this._op + "'.");
            }
        }
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    /**
     * Uses the input's compressed size as output estimate for a transpose of a compressed
     * input; otherwise defers to the default estimate.
     */
    @Override
    public void computeMemEstimate(MemoTable memo) {
        if (this._op == Types.ReOrgOp.TRANS && this.getInput().get(0).isCompressedOutput()) {
            this._outputMemEstimate = this.getInput().get(0).getCompressedSize();
        } else {
            super.computeMemEstimate(memo);
        }
    }

    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * Intermediate memory for SORT: 4 bytes per row (presumably a temporary int index
     * array), unless the index-return flag is a false literal and the input is a single
     * column or empty. All other operations need no intermediate memory.
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        if (this._op != Types.ReOrgOp.SORT) {
            return 0.0;
        }
        Hop ixreturn = this.getInput().get(SORT_IX_IXRET);
        boolean noTempIndexes = ixreturn instanceof LiteralOp
            && !HopRewriteUtils.getBooleanValueSafe((LiteralOp)ixreturn)
            && (dim2 == 1L || nnz == 0L);
        return noTempIndexes ? 0.0 : dim1 * 4L;
    }

    /**
     * Infers output characteristics from the memoized statistics of the first input.
     *
     * @return the inferred characteristics, or {@code null} if they cannot be inferred
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        MatrixCharacteristics ret = null;
        Hop input = this.getInput().get(0);
        DataCharacteristics dc = memo.getAllInputStats(input);
        switch (this._op) {
            case TRANS: {
                if (!dc.dimsKnown()) break;
                ret = new MatrixCharacteristics(dc.getCols(), dc.getRows(), -1, dc.getNonZeros());
                break;
            }
            case REV: {
                if (!dc.dimsKnown()) break;
                ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
                break;
            }
            case DIAG: {
                // Mirrors refreshSizeInformation: the input's column count decides the
                // direction (single column -> k x k matrix, multiple columns -> k x 1 vector).
                if (!dc.dimsKnown()) break;
                long k = dc.getRows();
                if (dc.getCols() == 1L) {
                    ret = new MatrixCharacteristics(k, k, -1, dc.getNonZeros() >= 0L ? dc.getNonZeros() : k);
                } else if (dc.getCols() > 1L) {
                    ret = new MatrixCharacteristics(k, 1L, -1, dc.getNonZeros() >= 0L ? Math.min(k, dc.getNonZeros()) : k);
                }
                break;
            }
            case RESHAPE: {
                if (!dc.dimsKnown()) break;
                if (this.rowsKnown() && this.getDim1() != 0L) {
                    ret = new MatrixCharacteristics(this.getDim1(), dc.getRows() * dc.getCols() / this.getDim1(), -1, dc.getNonZeros());
                    break;
                }
                if (this.colsKnown() && this.getDim2() != 0L) {
                    ret = new MatrixCharacteristics(dc.getRows() * dc.getCols() / this.getDim2(), this.getDim2(), -1, dc.getNonZeros());
                    break;
                }
                if (!this.dimsKnown()) break;
                ret = new MatrixCharacteristics(this.getDim1(), this.getDim2(), -1, -1L);
                break;
            }
            case SORT: {
                Hop ixretInput = this.getInput().get(SORT_IX_IXRET);
                if (ixretInput instanceof LiteralOp) {
                    boolean ixret = HopRewriteUtils.getBooleanValueSafe((LiteralOp)ixretInput);
                    long cols = ixret ? 1L : dc.getCols();
                    long nnz = ixret ? dc.getRows() : dc.getNonZeros();
                    ret = new MatrixCharacteristics(dc.getRows(), cols, -1, nnz);
                } else {
                    // Index-return flag unknown at compile time: output width is unknown.
                    ret = new MatrixCharacteristics(dc.getRows(), -1L, -1, -1L);
                }
                break;
            }
            default:
                break;
        }
        return ret;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Chooses the execution type: a forced platform wins; otherwise the memory-based
     * decision, or (without memory-based optimization) CP for small or vector inputs
     * and SPARK otherwise.
     */
    @Override
    protected Types.ExecType optFindExecType(boolean transitive) {
        this.checkAndSetForcedPlatform();
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = this.findExecTypeByMemEstimate();
            } else {
                Hop in = this.getInput().get(0);
                this._etype = in.areDimsBelowThreshold() || in.isVector() ? Types.ExecType.CP : Types.ExecType.SPARK;
            }
            this.checkAndSetInvalidCPDimsAndSize();
        }
        this.setRequiresRecompileIfNecessary();
        return this._etype;
    }

    /** Recomputes output dimensions and nnz from the inputs for the configured operation. */
    @Override
    public void refreshSizeInformation() {
        Hop input1 = this.getInput().get(0);
        switch (this._op) {
            case TRANS: {
                this.setDim1(input1.getDim2());
                this.setDim2(input1.getDim1());
                this.setNnz(input1.getNnz());
                break;
            }
            case REV: {
                this.setDim1(input1.getDim1());
                this.setDim2(input1.getDim2());
                this.setNnz(input1.getNnz());
                break;
            }
            case DIAG: {
                // Single-column input -> k x k matrix; multi-column input -> k x 1 vector.
                long k = input1.getDim1();
                this.setDim1(k);
                if (input1.getDim2() == 1L) {
                    this.setDim2(k);
                    this.setNnz(input1.getNnz() >= 0L ? input1.getNnz() : k);
                }
                if (input1.getDim2() <= 1L) break;
                this.setDim2(1L);
                this.setNnz(input1.getNnz() >= 0L ? Math.min(k, input1.getNnz()) : k);
                break;
            }
            case RESHAPE: {
                if (this._dataType != Types.DataType.TENSOR) {
                    Hop input2 = this.getInput().get(1);
                    Hop input3 = this.getInput().get(2);
                    this.refreshRowsParameterInformation(input2);
                    this.refreshColsParameterInformation(input3);
                    this.setNnz(input1.getNnz());
                    // Derive a missing dimension from the input length when possible.
                    if (this.dimsKnown() || !input1.dimsKnown()) break;
                    if (this.rowsKnown() && this.getDim1() != 0L) {
                        this.setDim2(input1.getLength() / this.getDim1());
                        break;
                    }
                    if (!this.colsKnown() || this.getDim2() == 0L) break;
                    this.setDim1(input1.getLength() / this.getDim2());
                    break;
                }
                this.setNnz(input1.getNnz());
                break;
            }
            case SORT: {
                Hop input4 = this.getInput().get(SORT_IX_IXRET);
                this.setDim1(input1.getDim1());
                if (input4 instanceof LiteralOp) {
                    boolean ixret = HopRewriteUtils.getBooleanValueSafe((LiteralOp)input4);
                    this.setDim2(ixret ? 1L : input1.getDim2());
                    this.setNnz(ixret ? input1.getDim1() : input1.getNnz());
                    break;
                }
                this.setDim2(-1L);
                this.setNnz(-1L);
                break;
            }
            default:
                break;
        }
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        ReorgOp ret = new ReorgOp();
        ret.clone(this, false);
        ret._op = this._op;
        ret._maxNumThreads = this._maxNumThreads;
        return ret;
    }

    /**
     * Two reorg hops are equal if they share the operation, thread limit and the
     * identical (same reference) inputs in the same order.
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof ReorgOp)) {
            return false;
        }
        ReorgOp that2 = (ReorgOp)that;
        if (this._op != that2._op
            || this._maxNumThreads != that2._maxNumThreads
            || this.getInput().size() != that.getInput().size()) {
            return false;
        }
        for (int i = 0; i < this._input.size(); ++i) {
            if (this.getInput().get(i) != that2.getInput().get(i)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether a single column of the sort input (its exact size if dims are known,
     * else the input's output memory estimate) fits the Spark broadcast memory budget.
     */
    private boolean isSortSPRewriteApplicable() {
        Hop input = this.getInput().get(0);
        double size = input.dimsKnown()
            ? (double)OptimizerUtils.estimateSize(input.getDim1(), 1L)
            : input.getOutputMemEstimate();
        return OptimizerUtils.checkSparkBroadcastMemoryBudget(size);
    }
}

