/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.mrmr;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.features.weighting.AbstractWeighting;
import com.rapidminer.operator.mrmr.MRMRFunctions;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeInt;
import java.util.List;

public class CorrelationBasedWeakAssociations
extends AbstractWeighting {
    public static final String PARAMETER_ROUNDS = "rounds";

    public CorrelationBasedWeakAssociations(OperatorDescription description) {
        super(description);
    }

    public AttributeWeights calculateWeights(ExampleSet exampleSet) throws OperatorException {
        AttributeWeights result = new AttributeWeights(exampleSet);
        Attributes attributes = exampleSet.getAttributes();
        int p = attributes.size();
        Attribute label = attributes.getLabel();
        double[] irv = new double[p];
        double[][] crv = new double[p][p];
        for (Attribute a : attributes) {
            double temp = MRMRFunctions.GetSimilarity(exampleSet, a, label);
            int idx = a.getTableIndex();
            irv[idx] = temp * temp;
            for (Attribute b : attributes) {
                crv[idx][b.getTableIndex()] = MRMRFunctions.GetSimilarity(exampleSet, a, b);
            }
        }
        int rounds = this.getParameterAsInt(PARAMETER_ROUNDS);
        for (int i = 0; i < rounds; ++i) {
            crv = this.matMult(crv, crv);
        }
        irv = this.matMult(crv, irv);
        for (Attribute att : attributes) {
            result.setWeight(att.getName(), irv[att.getTableIndex()]);
        }
        return result;
    }

    private double[][] matMult(double[][] A, double[][] B) {
        int wA = A[0].length;
        int wB = B[0].length;
        int hA = A.length;
        int hB = B.length;
        double[][] result = new double[hA][wB];
        if (wA != hB) {
            this.log("Matrix multiplication failed because of dimensionality mismatch.");
        } else {
            for (int i = 0; i < hA; ++i) {
                for (int j = 0; j < wB; ++j) {
                    for (int k = 0; k < wA; ++k) {
                        double[] dArray = result[i];
                        int n = j;
                        dArray[n] = dArray[n] + A[i][k] * B[k][j];
                    }
                }
            }
        }
        return result;
    }

    private double[] matMult(double[][] A, double[] B) {
        int wA = A[0].length;
        int hA = A.length;
        int hB = B.length;
        double[] result = new double[hA];
        if (wA != hB) {
            this.log("Matrix multiplication failed because of dimensionality mismatch.");
        } else {
            for (int i = 0; i < hA; ++i) {
                for (int k = 0; k < wA; ++k) {
                    int n = i;
                    result[n] = result[n] + A[i][k] * B[k];
                }
            }
        }
        return result;
    }

    public List<ParameterType> getParameterTypes() {
        List types = super.getParameterTypes();
        types.add(new ParameterTypeInt(PARAMETER_ROUNDS, "Number of rounds to calculat CRV*CRV", 0, Integer.MAX_VALUE));
        return types;
    }

    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case BINOMINAL_LABEL: 
            case POLYNOMINAL_LABEL: 
            case NUMERICAL_LABEL: 
            case BINOMINAL_ATTRIBUTES: 
            case POLYNOMINAL_ATTRIBUTES: 
            case NUMERICAL_ATTRIBUTES: {
                return true;
            }
        }
        return false;
    }
}

