/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.sql.connect.ml;

import java.io.Serializable;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.spark.connect.proto.Expression;
import org.apache.spark.connect.proto.Fetch;
import org.apache.spark.connect.proto.MlCommand;
import org.apache.spark.connect.proto.MlCommandResult;
import org.apache.spark.connect.proto.MlOperator;
import org.apache.spark.connect.proto.MlParams;
import org.apache.spark.connect.proto.MlRelation;
import org.apache.spark.connect.proto.ObjectRef;
import org.apache.spark.internal.LogEntry;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.clustering.PowerIterationClustering;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.ParamMap$;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.Summary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter$;
import org.apache.spark.sql.connect.ml.AttributeHelper;
import org.apache.spark.sql.connect.ml.AttributeHelper$;
import org.apache.spark.sql.connect.ml.MLCache;
import org.apache.spark.sql.connect.ml.MLCacheInvalidException;
import org.apache.spark.sql.connect.ml.MLUtils$;
import org.apache.spark.sql.connect.ml.MlUnsupportedException;
import org.apache.spark.sql.connect.ml.ModelAttributeHelper;
import org.apache.spark.sql.connect.ml.ModelAttributeHelper$;
import org.apache.spark.sql.connect.ml.Serializer$;
import org.apache.spark.sql.connect.service.SessionHolder;
import org.apache.spark.util.Utils$;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnce;
import scala.collection.MapOps;
import scala.collection.immutable.Seq;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.runtime.BooleanRef;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/**
 * Decompiled singleton ({@code MODULE$}) backing the Spark Connect ML handler.
 *
 * <p>Dispatches ML protocol messages:
 * <ul>
 *   <li>{@link MlCommand} (fit, fetch, delete, write, read, evaluate) via {@code handleMlCommand};</li>
 *   <li>{@link MlRelation} (transform, fetch) via {@code transformMLRelation}.</li>
 * </ul>
 * Objects handed back to the client (models, summaries) are registered in the session's
 * {@link MLCache} and referred to by the string id that {@code MLCache.register} returns.
 *
 * <p>Also exposes {@code safeMLClassLoader()}, a class-name resolver. It only allows classes from
 * an allowlist of ML operators while a {@link SessionHolder} is bound to the current thread.
 */
public final class MLHandler$
implements Logging {
    /** The singleton instance of this handler. */
    public static final MLHandler$ MODULE$ = new MLHandler$();
    // Session whose ML command is being handled on this thread; null means "no session bound".
    private static final ThreadLocal<SessionHolder> currentSessionHolder;
    // Class name -> Class for every ML operator that may be loaded by name while a session is bound.
    private static final scala.collection.immutable.Map<String, Class<?>> allowlistedMLClasses;
    // Class-name resolver; its rules are described where it is built in the static initializer.
    private static final Function1<String, Class<?>> safeMLClassLoader;
    // Backing field for the Logger used by the Logging trait (see the getter/setter below).
    private static transient Logger org$apache$spark$internal$Logging$$log_;

    // Static initialization: initialize the Logging trait first, then build the thread-local
    // session slot, the class allowlist and the allowlist-enforcing class resolver.
    static {
        Logging.$init$((Logging)MODULE$);
        // Every thread starts with no session bound (initial value null).
        currentSessionHolder = new ThreadLocal<SessionHolder>(){

            public SessionHolder initialValue() {
                return null;
            }
        };
        scala.collection.immutable.Map<String, Class<?>> transformerClasses = MLUtils$.MODULE$.loadOperators(Transformer.class);
        scala.collection.immutable.Map<String, Class<?>> estimatorClasses = MLUtils$.MODULE$.loadOperators(Estimator.class);
        scala.collection.immutable.Map<String, Class<?>> evaluatorClasses = MLUtils$.MODULE$.loadOperators(Evaluator.class);
        // Allowlist = all Transformer, Estimator and Evaluator classes found by MLUtils.loadOperators,
        // plus PowerIterationClustering, which is added explicitly.
        // NOTE(review): presumably PowerIterationClustering is listed by hand because loadOperators
        // does not discover it under any of the three operator types; confirm.
        allowlistedMLClasses = (scala.collection.immutable.Map)((MapOps)((MapOps)transformerClasses.$plus$plus(estimatorClasses)).$plus$plus(evaluatorClasses)).$plus$plus((IterableOnce)Predef$.MODULE$.Map().apply((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"org.apache.spark.ml.clustering.PowerIterationClustering"), PowerIterationClustering.class)})));
        // Resolver rules:
        //  - session bound on this thread: the name is rewritten with MLUtils.replaceOperator, then
        //    looked up in the allowlist. A miss raises MlUnsupportedException naming the class that
        //    was originally requested.
        //  - no session bound: falls back to Utils.classForName with its default arguments.
        //    NOTE(review): no allowlist check is applied on this path.
        safeMLClassLoader = (Function1 & Serializable)className -> {
            SessionHolder sessionHolder = MODULE$.currentSessionHolder().get();
            if (sessionHolder != null) {
                Class clazz;
                String name = MLUtils$.MODULE$.replaceOperator(sessionHolder, (String)className);
                try {
                    // Map.apply throws NoSuchElementException for a name missing from the allowlist.
                    clazz = (Class)MODULE$.allowlistedMLClasses().apply((Object)name);
                }
                catch (NoSuchElementException noSuchElementException) {
                    throw new MlUnsupportedException("The class " + className + " to be loaded is not in the allowlist.");
                }
                return clazz;
            }
            return Utils$.MODULE$.classForName(className, Utils$.MODULE$.classForName$default$2(), Utils$.MODULE$.classForName$default$3());
        };
    }

    // ---- Forwarders generated for the org.apache.spark.internal.Logging trait ----
    // Each method below delegates to the trait's corresponding static "$"-suffixed implementation.

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public Logging.LogStringContext LogStringContext(StringContext sc) {
        return Logging.LogStringContext$((Logging)this, (StringContext)sc);
    }

    public void withLogContext(Map<String, String> context, Function0<BoxedUnit> body) {
        Logging.withLogContext$((Logging)this, context, body);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logInfo(LogEntry entry) {
        Logging.logInfo$((Logging)this, (LogEntry)entry);
    }

    public void logInfo(LogEntry entry, Throwable throwable) {
        Logging.logInfo$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logDebug(LogEntry entry) {
        Logging.logDebug$((Logging)this, (LogEntry)entry);
    }

    public void logDebug(LogEntry entry, Throwable throwable) {
        Logging.logDebug$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logTrace(LogEntry entry) {
        Logging.logTrace$((Logging)this, (LogEntry)entry);
    }

    public void logTrace(LogEntry entry, Throwable throwable) {
        Logging.logTrace$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logWarning(LogEntry entry) {
        Logging.logWarning$((Logging)this, (LogEntry)entry);
    }

    public void logWarning(LogEntry entry, Throwable throwable) {
        Logging.logWarning$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logError(LogEntry entry) {
        Logging.logError$((Logging)this, (LogEntry)entry);
    }

    public void logError(LogEntry entry, Throwable throwable) {
        Logging.logError$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    // Getter/setter pair the Logging trait uses to read and store its Logger.
    public Logger org$apache$spark$internal$Logging$$log_() {
        return org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        org$apache$spark$internal$Logging$$log_ = x$1;
    }

    // ---- Accessors for the static state ----

    /** Thread-local slot holding the session whose ML command is being handled on this thread. */
    public ThreadLocal<SessionHolder> currentSessionHolder() {
        return currentSessionHolder;
    }

    private scala.collection.immutable.Map<String, Class<?>> allowlistedMLClasses() {
        return allowlistedMLClasses;
    }

    /**
     * Returns the class-name resolver. While a session is bound via {@code currentSessionHolder()},
     * it only resolves allowlisted ML operator classes. With no session bound, it falls back to
     * {@code Utils.classForName}.
     */
    public Function1<String, Class<?>> safeMLClassLoader() {
        return safeMLClassLoader;
    }

    /**
     * Executes one ML command for a session and builds its result message.
     *
     * <p>Binds {@code sessionHolder} into {@code currentSessionHolder()} before dispatching on the
     * command case. NOTE(review): the binding is never cleared here, so it stays on the thread
     * after this method returns.
     *
     * <p>(CFR decompiler notes: force condition propagation enabled; jumps lifted to return sites.)
     *
     * @param sessionHolder session providing the {@link MLCache} and the context for parsing relations
     * @param mlCommand     the ML command to execute
     * @return the command result. Depending on the command it carries a cached object id, a summary
     *         id, a literal param, operator info, or nothing.
     * @throws MlUnsupportedException for an unsupported command case, write type, or operator type,
     *         for an unsupported read type, or when the object to write is not MLWritable
     * @throws MLCacheInvalidException when a write by object reference finds no cached object
     */
    public MlCommandResult handleMlCommand(SessionHolder sessionHolder, MlCommand mlCommand) {
        MLCache mlCache = sessionHolder.mlCache();
        this.currentSessionHolder().set(sessionHolder);
        MlCommand.CommandCase commandCase = mlCommand.getCommandCase();
        // FIT: build the estimator from the proto plus params, fit it on the dataset,
        // cache the resulting model and return its id.
        if (MlCommand.CommandCase.FIT.equals(commandCase)) {
            MlCommand.Fit fitCmd = mlCommand.getFit();
            MlOperator estimatorProto = fitCmd.getEstimator();
            MlOperator.OperatorType operatorType = estimatorProto.getType();
            MlOperator.OperatorType operatorType2 = MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR;
            // Null-safe "operatorType == OPERATOR_TYPE_ESTIMATOR" check (Scala Predef.assert).
            Predef$.MODULE$.assert(!(operatorType != null ? !operatorType.equals(operatorType2) : operatorType2 != null));
            Dataset<Row> dataset = MLUtils$.MODULE$.parseRelationProto(fitCmd.getDataset(), sessionHolder);
            Estimator<?> estimator = MLUtils$.MODULE$.getEstimator(sessionHolder, estimatorProto, (Option<MlParams>)new Some((Object)fitCmd.getParams()));
            Model model = estimator.fit(dataset);
            String id = mlCache.register(model);
            return MlCommandResult.newBuilder().setOperatorInfo(MlCommandResult.MlOperatorInfo.newBuilder().setObjRef(ObjectRef.newBuilder().setId(id))).build();
        }
        // FETCH: apply the requested method chain to a cached object, then shape the result
        // according to its runtime type.
        if (MlCommand.CommandCase.FETCH.equals(commandCase)) {
            Object object;
            AttributeHelper helper = AttributeHelper$.MODULE$.apply(sessionHolder, mlCommand.getFetch().getObjRef().getId(), (Fetch.Method[])CollectionConverters$.MODULE$.CollectionHasAsScala(mlCommand.getFetch().getMethodsList()).asScala().toArray(ClassTag$.MODULE$.apply(Fetch.Method.class)));
            Object attrResult = helper.getAttribute();
            Object object2 = attrResult;
            // Summary -> cache it and return its id as the summary field.
            if (object2 instanceof Summary) {
                Summary summary = (Summary)object2;
                String id = mlCache.register(summary);
                return MlCommandResult.newBuilder().setSummary(id).build();
            }
            // Model -> cache it and return its object reference.
            if (object2 instanceof Model) {
                Model model = (Model)object2;
                String id = mlCache.register(model);
                return MlCommandResult.newBuilder().setOperatorInfo(MlCommandResult.MlOperatorInfo.newBuilder().setObjRef(ObjectRef.newBuilder().setId(id))).build();
            }
            // Non-empty one-dimensional array whose elements are all Models -> cache each one and
            // return the ids joined with "," as a single object-reference id.
            if (ScalaRunTime$.MODULE$.isArray(object2, 1) && ArrayOps$.MODULE$.nonEmpty$extension(Predef$.MODULE$.genericArrayOps(object = object2)) && ArrayOps$.MODULE$.forall$extension(Predef$.MODULE$.genericArrayOps(object), (Function1 & Serializable)x$2 -> BoxesRunTime.boxToBoolean((boolean)MLHandler$.$anonfun$handleMlCommand$1(x$2)))) {
                String[] ids = (String[])ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.genericArrayOps(object), (Function1 & Serializable)m4 -> mlCache.register((Model)m4), ClassTag$.MODULE$.apply(String.class));
                return MlCommandResult.newBuilder().setOperatorInfo(MlCommandResult.MlOperatorInfo.newBuilder().setObjRef(ObjectRef.newBuilder().setId(Predef$.MODULE$.wrapRefArray((Object[])ids).mkString(",")))).build();
            }
            // Anything else is returned by value as a serialized literal param.
            Expression.Literal param = Serializer$.MODULE$.serializeParam(attrResult);
            return MlCommandResult.newBuilder().setParam(param).build();
        }
        // DELETE: remove every referenced id that contains no "." from the cache. Ids containing
        // "." are skipped. Returns a boolean literal: true if at least one remove was attempted.
        if (MlCommand.CommandCase.DELETE.equals(commandCase)) {
            BooleanRef result = BooleanRef.create((boolean)false);
            ArrayOps$.MODULE$.foreach$extension(Predef$.MODULE$.refArrayOps((Object[])CollectionConverters$.MODULE$.CollectionHasAsScala(mlCommand.getDelete().getObjRefsList()).asScala().toArray(ClassTag$.MODULE$.apply(ObjectRef.class))), (Function1 & Serializable)objId -> {
                MLHandler$.$anonfun$handleMlCommand$3(mlCache, result, objId);
                return BoxedUnit.UNIT;
            });
            return MlCommandResult.newBuilder().setParam(LiteralValueProtoConverter$.MODULE$.toLiteralProto(BoxesRunTime.boxToBoolean((boolean)result.elem))).build();
        }
        // WRITE: persist either a cached model (OBJ_REF) or an operator freshly built from the
        // proto (OPERATOR).
        if (MlCommand.CommandCase.WRITE.equals(commandCase)) {
            MlCommand.Write.TypeCase typeCase = mlCommand.getWrite().getTypeCase();
            if (MlCommand.Write.TypeCase.OBJ_REF.equals(typeCase)) {
                String objId2 = mlCommand.getWrite().getObjRef().getId();
                // NOTE(review): if the cached entry is not a Model, this cast throws before the
                // null check below.
                Model model = (Model)mlCache.get(objId2);
                if (model == null) {
                    throw new MLCacheInvalidException("model " + objId2);
                }
                // Write params are applied to a copy, so the cached model's params are left untouched.
                Model copiedModel = model.copy(ParamMap$.MODULE$.empty());
                MLUtils$.MODULE$.setInstanceParams((Params)copiedModel, mlCommand.getWrite().getParams());
                Model model2 = copiedModel;
                if (!(model2 instanceof MLWritable)) {
                    throw new MlUnsupportedException(model2 + " is not writable");
                }
                Model model3 = model2;
                MLUtils$.MODULE$.write((MLWritable)model3, mlCommand.getWrite());
                return MlCommandResult.newBuilder().build();
            } else {
                if (!MlCommand.Write.TypeCase.OPERATOR.equals(typeCase)) throw new MlUnsupportedException(typeCase + " write not supported");
                MlCommand.Write writer = mlCommand.getWrite();
                MlOperator.OperatorType operatorType = writer.getOperator().getType();
                String operatorName = writer.getOperator().getName();
                Some params = new Some((Object)writer.getParams());
                MlOperator.OperatorType operatorType3 = operatorType;
                // Build the operator of the requested kind, check it is MLWritable, then write it.
                if (MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR.equals(operatorType3)) {
                    Estimator<?> estimator = MLUtils$.MODULE$.getEstimator(sessionHolder, writer.getOperator(), (Option<MlParams>)params);
                    Estimator<?> estimator2 = estimator;
                    if (!(estimator2 instanceof MLWritable)) {
                        throw new MlUnsupportedException("Estimator " + estimator2 + " is not writable");
                    }
                    Estimator<?> estimator3 = estimator2;
                    MLUtils$.MODULE$.write((MLWritable)estimator3, mlCommand.getWrite());
                    return MlCommandResult.newBuilder().build();
                } else if (MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR.equals(operatorType3)) {
                    Evaluator evaluator = MLUtils$.MODULE$.getEvaluator(sessionHolder, writer.getOperator(), (Option<MlParams>)params);
                    Evaluator evaluator2 = evaluator;
                    if (!(evaluator2 instanceof MLWritable)) {
                        throw new MlUnsupportedException("Evaluator " + evaluator2 + " is not writable");
                    }
                    Evaluator evaluator3 = evaluator2;
                    MLUtils$.MODULE$.write((MLWritable)evaluator3, mlCommand.getWrite());
                    return MlCommandResult.newBuilder().build();
                } else {
                    if (!MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER.equals(operatorType3)) throw new MlUnsupportedException("Operator " + operatorName + " is not supported");
                    Transformer transformer = MLUtils$.MODULE$.getTransformer(sessionHolder, writer.getOperator(), (Option<MlParams>)params);
                    Transformer transformer2 = transformer;
                    if (!(transformer2 instanceof MLWritable)) {
                        throw new MlUnsupportedException("Transformer " + transformer2 + " is not writable");
                    }
                    Transformer transformer3 = transformer2;
                    MLUtils$.MODULE$.write((MLWritable)transformer3, mlCommand.getWrite());
                }
            }
            // Reached only by the transformer branch above; it builds the same empty result.
            return MlCommandResult.newBuilder().build();
        }
        // READ: load an operator from `path`. Models are cached and returned with an object
        // reference. Estimators, evaluators and transformers are not cached; only their name,
        // uid and params are returned.
        if (MlCommand.CommandCase.READ.equals(commandCase)) {
            // NOTE(review): the declared type Evaluator looks like a decompiler artifact. The
            // branches below also assign loadEstimator/loadTransformer results here, and the value
            // is only used via uid() and as Params.
            Evaluator evaluator;
            MlOperator operator = mlCommand.getRead().getOperator();
            String name = operator.getName();
            String path = mlCommand.getRead().getPath();
            MlOperator.OperatorType operatorType = operator.getType();
            MlOperator.OperatorType operatorType4 = MlOperator.OperatorType.OPERATOR_TYPE_MODEL;
            // Each test below is a null-safe equality check on the operator type.
            if (!(operatorType != null ? !operatorType.equals(operatorType4) : operatorType4 != null)) {
                Transformer model = MLUtils$.MODULE$.loadTransformer(sessionHolder, name, path);
                String id = mlCache.register(model);
                return MlCommandResult.newBuilder().setOperatorInfo(MlCommandResult.MlOperatorInfo.newBuilder().setObjRef(ObjectRef.newBuilder().setId(id)).setUid(model.uid()).setParams(Serializer$.MODULE$.serializeParams((Params)model))).build();
            }
            MlOperator.OperatorType operatorType5 = operator.getType();
            MlOperator.OperatorType operatorType6 = MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR;
            if (!(operatorType5 != null ? !operatorType5.equals(operatorType6) : operatorType6 != null)) {
                evaluator = MLUtils$.MODULE$.loadEstimator(sessionHolder, name, path);
            } else {
                MlOperator.OperatorType operatorType7 = operator.getType();
                MlOperator.OperatorType operatorType8 = MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR;
                if (!(operatorType7 != null ? !operatorType7.equals(operatorType8) : operatorType8 != null)) {
                    evaluator = MLUtils$.MODULE$.loadEvaluator(sessionHolder, name, path);
                } else {
                    MlOperator.OperatorType operatorType9 = operator.getType();
                    MlOperator.OperatorType operatorType10 = MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER;
                    if (operatorType9 != null ? !operatorType9.equals(operatorType10) : operatorType10 != null) throw new MlUnsupportedException(operator.getType() + " read not supported");
                    evaluator = MLUtils$.MODULE$.loadTransformer(sessionHolder, name, path);
                }
            }
            Evaluator mlOperator = evaluator;
            return MlCommandResult.newBuilder().setOperatorInfo(MlCommandResult.MlOperatorInfo.newBuilder().setName(name).setUid(mlOperator.uid()).setParams(Serializer$.MODULE$.serializeParams((Params)mlOperator))).build();
        }
        // EVALUATE is the last supported case; any other command case is rejected here.
        if (!MlCommand.CommandCase.EVALUATE.equals(commandCase)) throw new MlUnsupportedException(commandCase + " not supported");
        MlCommand.Evaluate evalCmd = mlCommand.getEvaluate();
        MlOperator evalProto = evalCmd.getEvaluator();
        MlOperator.OperatorType operatorType = evalProto.getType();
        MlOperator.OperatorType operatorType11 = MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR;
        // Null-safe "operatorType == OPERATOR_TYPE_EVALUATOR" check (Scala Predef.assert).
        Predef$.MODULE$.assert(!(operatorType != null ? !operatorType.equals(operatorType11) : operatorType11 != null));
        Dataset<Row> dataset = MLUtils$.MODULE$.parseRelationProto(evalCmd.getDataset(), sessionHolder);
        Evaluator evaluator = MLUtils$.MODULE$.getEvaluator(sessionHolder, evalProto, (Option<MlParams>)new Some((Object)evalCmd.getParams()));
        double metric = evaluator.evaluate(dataset);
        // The metric is returned as a double literal param.
        return MlCommandResult.newBuilder().setParam(LiteralValueProtoConverter$.MODULE$.toLiteralProto(BoxesRunTime.boxToDouble((double)metric))).build();
    }

    /**
     * Turns an ML relation into a DataFrame.
     *
     * <p>TRANSFORM either builds a transformer from the proto and applies it to the parsed input
     * relation, or delegates to {@link ModelAttributeHelper} for a cached object reference.
     * FETCH returns the result of an attribute/method chain on a cached object as a Dataset.
     *
     * <p>NOTE(review): this method does not bind {@code currentSessionHolder()} itself. Any
     * class-loading it triggers sees whatever value was last set on this thread; confirm that is
     * intended.
     *
     * @param relation      the ML relation to evaluate
     * @param sessionHolder session used to parse input relations and to resolve cached objects
     * @return the resulting dataset
     * @throws IllegalArgumentException for an unsupported transform operator case
     * @throws MlUnsupportedException for an unsupported ML relation type
     */
    public Dataset<Row> transformMLRelation(MlRelation relation, SessionHolder sessionHolder) {
        MlRelation.MlTypeCase mlTypeCase = relation.getMlTypeCase();
        if (MlRelation.MlTypeCase.TRANSFORM.equals(mlTypeCase)) {
            MlRelation.Transform.OperatorCase operatorCase = relation.getTransform().getOperatorCase();
            if (MlRelation.Transform.OperatorCase.TRANSFORMER.equals(operatorCase)) {
                MlRelation.Transform transformProto = relation.getTransform();
                MlOperator.OperatorType operatorType = transformProto.getTransformer().getType();
                MlOperator.OperatorType operatorType2 = MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER;
                // Null-safe "operatorType == OPERATOR_TYPE_TRANSFORMER" check (Scala Predef.assert).
                Predef$.MODULE$.assert(!(operatorType != null ? !operatorType.equals(operatorType2) : operatorType2 != null));
                Dataset<Row> dataset = MLUtils$.MODULE$.parseRelationProto(transformProto.getInput(), sessionHolder);
                Transformer transformer = MLUtils$.MODULE$.getTransformer(sessionHolder, transformProto);
                return transformer.transform(dataset);
            }
            if (MlRelation.Transform.OperatorCase.OBJ_REF.equals(operatorCase)) {
                // Transform with a cached model; no fetch methods are applied first (empty array).
                ModelAttributeHelper helper = ModelAttributeHelper$.MODULE$.apply(sessionHolder, relation.getTransform().getObjRef().getId(), (Fetch.Method[])Array$.MODULE$.empty(ClassTag$.MODULE$.apply(Fetch.Method.class)));
                return helper.transform(relation.getTransform());
            }
            // NOTE(review): every other unsupported path in this class throws MlUnsupportedException.
            throw new IllegalArgumentException(operatorCase + " not supported");
        }
        if (MlRelation.MlTypeCase.FETCH.equals(mlTypeCase)) {
            AttributeHelper helper = AttributeHelper$.MODULE$.apply(sessionHolder, relation.getFetch().getObjRef().getId(), (Fetch.Method[])CollectionConverters$.MODULE$.CollectionHasAsScala(relation.getFetch().getMethodsList()).asScala().toArray(ClassTag$.MODULE$.apply(Fetch.Method.class)));
            // The fetched attribute is assumed to be a Dataset; anything else fails this cast.
            return (Dataset)helper.getAttribute();
        }
        throw new MlUnsupportedException(mlTypeCase + " not supported");
    }

    /** Lambda body for the FETCH array check in handleMlCommand: true iff the element is a Model. */
    public static final /* synthetic */ boolean $anonfun$handleMlCommand$1(Object x$2) {
        return x$2 instanceof Model;
    }

    /**
     * Per-reference lambda body for DELETE in handleMlCommand. If the id contains no ".", it removes
     * the id from the cache and sets {@code result$1} to true. Ids containing "." are left alone.
     */
    public static final /* synthetic */ void $anonfun$handleMlCommand$3(MLCache mlCache$1, BooleanRef result$1, ObjectRef objId) {
        if (!objId.getId().contains(".")) {
            mlCache$1.remove(objId.getId());
            result$1.elem = true;
            return;
        }
    }

    // Singleton: the only instance is MODULE$.
    private MLHandler$() {
    }
}

