package com.shifthackz.aisdv1.feature.diffusion.ai.unet;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.NNAPIFlags;
import android.graphics.Bitmap;
import android.util.Pair;
import androidx.core.app.NotificationCompat;
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor;
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract;
import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.ArrayExtensionsKt;
import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.TensorExtensionsKt;
import com.shifthackz.aisdv1.feature.diffusion.ai.scheduler.EulerAncestralDiscreteLocalDiffusionScheduler;
import com.shifthackz.aisdv1.feature.diffusion.ai.vae.VaeDecoder;
import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag;
import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionTensor;
import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider;
import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider;
import com.shifthackz.aisdv1.storage.db.persistent.contract.GenerationResultContract;
import java.nio.IntBuffer;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import kotlin.Metadata;
import kotlin.collections.MapsKt;
import kotlin.jvm.internal.Intrinsics;

/* compiled from: UNet.kt */
@Metadata(d1 = {"\u0000\u0086\u0001\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\b\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010$\n\u0002\u0010\u000e\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\t\n\u0000\n\u0002\u0010\u0007\n\u0002\b\u0005\n\u0002\u0010\u0006\n\u0002\b\u0003\n\u0002\u0010\u0011\n\u0002\u0010\u0014\n\u0002\b\u0005\b\u0000\u0018\u00002\u00020\u0001:\u00017B\u001d\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007¢\u0006\u0002\u0010\bJ\u0006\u0010\u0016\u001a\u00020\u0017J,\u0010\u0018\u001a\u000e\u0012\u0004\u0012\u00020\u001a\u0012\u0004\u0012\u00020\u001b0\u00192\u0006\u0010\u001c\u001a\u00020\u001b2\u0006\u0010\u001d\u001a\u00020\u001b2\u0006\u0010\u001e\u001a\u00020\u001bH\u0002J\u0012\u0010\u001f\u001a\u00020 
2\n\u0010!\u001a\u0006\u0012\u0002\b\u00030\"J4\u0010#\u001a\u0006\u0012\u0002\b\u00030\"2\u0006\u0010$\u001a\u00020\u00102\u0006\u0010\u000f\u001a\u00020\u00102\u0006\u0010\u0015\u001a\u00020\u00102\u0006\u0010%\u001a\u00020&2\u0006\u0010'\u001a\u00020(H\u0002J>\u0010)\u001a\u00020\u00172\u0006\u0010*\u001a\u00020&2\u0006\u0010+\u001a\u00020\u00102\u0006\u0010,\u001a\u00020\u001b2\u0006\u0010-\u001a\u00020.2\u0006\u0010$\u001a\u00020\u00102\u0006\u0010\u0015\u001a\u00020\u00102\u0006\u0010\u000f\u001a\u00020\u0010J\u0006\u0010/\u001a\u00020\u0017JI\u00100\u001a\u00020\u00172\u0018\u00101\u001a\u0014\u0012\u0010\u0012\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u0002030202022\u0018\u00104\u001a\u0014\u0012\u0010\u0012\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u0002030202022\u0006\u0010-\u001a\u00020.H\u0002¢\u0006\u0002\u00105J\u0010\u00106\u001a\u00020\u00172\b\u0010\t\u001a\u0004\u0018\u00010\nR\u0010\u0010\t\u001a\u0004\u0018\u00010\nX\u0082\u000e¢\u0006\u0002\n\u0000R\u0014\u0010\u000b\u001a\u00020\f8BX\u0082\u0004¢\u0006\u0006\u001a\u0004\b\r\u0010\u000eR\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n\u0000R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n\u0000R\u000e\u0010\u000f\u001a\u00020\u0010X\u0082\u000e¢\u0006\u0002\n\u0000R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n\u0000R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n\u0000R\u0010\u0010\u0013\u001a\u0004\u0018\u00010\u0014X\u0082\u000e¢\u0006\u0002\n\u0000R\u000e\u0010\u0015\u001a\u00020\u0010X\u0082\u000e¢\u0006\u0002\n\u0000¨\u00068"}, d2 = {"Lcom/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet;", "", "deviceNNAPIFlagProvider", "Lcom/shifthackz/aisdv1/feature/diffusion/environment/DeviceNNAPIFlagProvider;", "ortEnvironmentProvider", "Lcom/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProvider;", "fileProviderDescriptor", "Lcom/shifthackz/aisdv1/core/common/file/FileProviderDescriptor;", 
"(Lcom/shifthackz/aisdv1/feature/diffusion/environment/DeviceNNAPIFlagProvider;Lcom/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProvider;Lcom/shifthackz/aisdv1/core/common/file/FileProviderDescriptor;)V", "callback", "Lcom/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet$Callback;", "decoder", "Lcom/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder;", "getDecoder", "()Lcom/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder;", GenerationResultContract.HEIGHT, "", "random", "Ljava/util/Random;", "session", "Lai/onnxruntime/OrtSession;", GenerationResultContract.WIDTH, "close", "", "createUNetModelInput", "", "", "Lai/onnxruntime/OnnxTensor;", "encoderHiddenStates", LocalDiffusionContract.KEY_SAMPLE, "timeStep", "decode", "Landroid/graphics/Bitmap;", "latents", "Lcom/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionTensor;", "generateLatentSample", "batchSize", GenerationResultContract.SEED, "", "initNoiseSigma", "", "inference", "seedNum", "numInferenceSteps", "textEmbeddings", "guidanceScale", "", "initialize", "performGuidance", "noisePrediction", "", "", "noisePredictionText", "([[[[F[[[[FD)V", "setCallback", "Callback", "diffusion_release"}, k = 1, mv = {1, 8, 0}, xi = 48)
/* loaded from: classes2.dex */
/**
 * Runs the UNet denoising loop of the on-device diffusion pipeline through ONNX Runtime and
 * hands the resulting latents to a {@link VaeDecoder} to obtain a {@link Bitmap}.
 *
 * <p>Typical use, as far as this class shows: {@link #initialize()} to create the session,
 * {@link #setCallback(Callback)} to observe progress, then {@link #inference}. {@code inference}
 * closes the session itself when the loop finishes normally.
 *
 * <p>NOTE(review): nothing in this class synchronizes access to {@code session}, {@code width},
 * {@code height} or {@code callback}; it looks intended for single-threaded use — confirm.
 */
public final class UNet {
    /** Channel count of the latent tensor built by {@link #generateLatentSample}. */
    private static final int LATENT_CHANNELS = 4;

    /** Divisor applied to the requested image height/width to get the latent height/width. */
    private static final int LATENT_DOWNSCALE = 8;

    /** Full turn in radians, used by the Box-Muller transform. */
    private static final double TWO_PI = 2.0d * Math.PI;

    /**
     * Lower clamp for the first Box-Muller uniform. {@link Random#nextDouble()} may return exactly
     * 0.0, and {@code log(0)} is negative infinity; 2^-53 is the smallest non-zero value
     * {@code nextDouble()} can produce, so clamping to it keeps the sample finite without
     * changing any other draw.
     */
    private static final double MIN_UNIFORM = 0x1.0p-53;

    /**
     * Factor applied to the latents before VAE decoding.
     * NOTE(review): presumably 1 / 0.18215 (the Stable Diffusion VAE scaling factor) — confirm.
     */
    private static final float VAE_LATENT_SCALE = 5.4899807f;

    private Callback callback;
    private final DeviceNNAPIFlagProvider deviceNNAPIFlagProvider;
    private final FileProviderDescriptor fileProviderDescriptor;
    private int height;
    private final OrtEnvironmentProvider ortEnvironmentProvider;
    private final Random random;
    private OrtSession session;
    private int width;

    /* compiled from: UNet.kt */
    @Metadata(d1 = {"\u0000\u001e\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010\b\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0004\bf\u0018\u00002\u00020\u0001J\u001a\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00052\b\u0010\u0006\u001a\u0004\u0018\u00010\u0007H&J\u0018\u0010\b\u001a\u00020\u00032\u0006\u0010\t\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020\u0005H&¨\u0006\u000b"}, d2 = {"Lcom/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet$Callback;", "", "onBuildImage", "", NotificationCompat.CATEGORY_STATUS, "", "bitmap", "Landroid/graphics/Bitmap;", "onStep", "maxStep", "step", "diffusion_release"}, k = 1, mv = {1, 8, 0}, xi = 48)
    /* loaded from: classes2.dex */
    /** Progress listener for {@link UNet#inference}. */
    public interface Callback {
        /**
         * Called once after the denoising loop with the decoded image.
         *
         * @param status status code; this class always passes {@code 0}
         * @param bitmap decoded image (declared nullable in the Kotlin metadata)
         */
        void onBuildImage(int status, Bitmap bitmap);

        /**
         * Called before each denoising step and once more after the last one.
         *
         * @param maxStep total number of scheduler time steps
         * @param step    zero-based index of the step about to run, or {@code maxStep} when done
         */
        void onStep(int maxStep, int step);
    }

    /**
     * Stores the collaborators and sets the default output size to 384 x 384.
     * No ONNX session is created here; see {@link #initialize()}.
     */
    public UNet(DeviceNNAPIFlagProvider deviceNNAPIFlagProvider, OrtEnvironmentProvider ortEnvironmentProvider, FileProviderDescriptor fileProviderDescriptor) {
        Intrinsics.checkNotNullParameter(deviceNNAPIFlagProvider, "deviceNNAPIFlagProvider");
        Intrinsics.checkNotNullParameter(ortEnvironmentProvider, "ortEnvironmentProvider");
        Intrinsics.checkNotNullParameter(fileProviderDescriptor, "fileProviderDescriptor");
        this.deviceNNAPIFlagProvider = deviceNNAPIFlagProvider;
        this.ortEnvironmentProvider = ortEnvironmentProvider;
        this.fileProviderDescriptor = fileProviderDescriptor;
        this.random = new Random();
        this.width = 384;
        this.height = 384;
    }

    /** Builds the named-input map passed to the UNet session for one step. */
    private final Map<String, OnnxTensor> createUNetModelInput(OnnxTensor encoderHiddenStates, OnnxTensor sample, OnnxTensor timeStep) {
        Map<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put(LocalDiffusionContract.KEY_ENCODER_HIDDEN_STATES, encoderHiddenStates);
        inputs.put(LocalDiffusionContract.KEY_SAMPLE, sample);
        inputs.put(LocalDiffusionContract.KEY_TIME_STEP, timeStep);
        return inputs;
    }

    /**
     * Creates the initial latent tensor of shape
     * {@code [batchSize, 4, height / 8, width / 8]}, filled with Box-Muller Gaussian samples
     * scaled by {@code initNoiseSigma}.
     *
     * <p>The generator is seeded with {@code seed}; two uniforms are drawn per element in
     * batch, channel, row, column order, so the result is reproducible for a given seed.
     *
     * @param batchSize      size of the leading dimension
     * @param height         requested image height in pixels
     * @param width          requested image width in pixels
     * @param seed           seed for the local {@link Random}
     * @param initNoiseSigma multiplier applied to every sample
     * @return the ONNX tensor together with its backing array and shape
     */
    private final LocalDiffusionTensor<?> generateLatentSample(int batchSize, int height, int width, long seed, float initNoiseSigma) {
        Random seeded = new Random(seed);
        int latentHeight = height / LATENT_DOWNSCALE;
        int latentWidth = width / LATENT_DOWNSCALE;
        float[][][][] latents = new float[batchSize][LATENT_CHANNELS][latentHeight][latentWidth];
        for (int batch = 0; batch < batchSize; batch++) {
            for (int channel = 0; channel < LATENT_CHANNELS; channel++) {
                for (int row = 0; row < latentHeight; row++) {
                    for (int col = 0; col < latentWidth; col++) {
                        // Draw order (radius uniform first, angle uniform second) is kept so
                        // existing seeds keep producing the same latents.
                        double radiusUniform = Math.max(seeded.nextDouble(), MIN_UNIFORM);
                        double angleUniform = seeded.nextDouble();
                        double gaussian = Math.sqrt((-2.0f) * Math.log(radiusUniform)) * Math.cos(angleUniform * TWO_PI);
                        latents[batch][channel][row][col] = (float) (gaussian * initNoiseSigma);
                    }
                }
            }
        }
        OnnxTensor createTensor = OnnxTensor.createTensor(this.ortEnvironmentProvider.getEnvironment(), latents);
        Intrinsics.checkNotNullExpressionValue(createTensor, "createTensor(ortEnvironm…ider.get(), latentsArray)");
        return new LocalDiffusionTensor<>(createTensor, latents, new long[]{batchSize, LATENT_CHANNELS, latentHeight, latentWidth});
    }

    /**
     * Returns a newly constructed {@link VaeDecoder} on every call.
     * NOTE(review): callers that invoke this twice therefore work with two separate decoders.
     */
    private final VaeDecoder getDecoder() {
        return new VaeDecoder(this.ortEnvironmentProvider, this.fileProviderDescriptor, this.deviceNNAPIFlagProvider.get());
    }

    /**
     * Applies guidance in place on {@code noisePrediction}:
     * {@code p = p + guidanceScale * (text - p)} for every element.
     *
     * @param noisePrediction     prediction that is overwritten with the guided result
     * @param noisePredictionText prediction read at the same indices
     * @param guidanceScale       guidance weight, narrowed to {@code float} before use
     */
    private final void performGuidance(float[][][][] noisePrediction, float[][][][] noisePredictionText, double guidanceScale) {
        long[] sizes = ArrayExtensionsKt.getSizes(noisePrediction);
        float scale = (float) guidanceScale;
        for (int b = 0; b < sizes[0]; b++) {
            for (int c = 0; c < sizes[1]; c++) {
                for (int y = 0; y < sizes[2]; y++) {
                    for (int x = 0; x < sizes[3]; x++) {
                        float[] row = noisePrediction[b][c][y];
                        float base = row[x];
                        row[x] = base + (scale * (noisePredictionText[b][c][y][x] - base));
                    }
                }
            }
        }
    }

    /**
     * Closes the UNet session (if any), calls {@code close()} on a decoder obtained from
     * {@link #getDecoder()}, and clears the session reference so {@link #initialize()} can
     * create a new one.
     */
    public final void close() {
        OrtSession ortSession = this.session;
        if (ortSession != null) {
            ortSession.close();
        }
        // NOTE(review): getDecoder() builds a fresh VaeDecoder, so this closes a brand-new
        // instance rather than one used by decode() — kept as-is; confirm intent.
        getDecoder().close();
        this.session = null;
    }

    /**
     * Scales the given latents, runs them through the VAE decoder and converts the output to
     * a bitmap of the size stored by the last {@link #inference} call (384 x 384 by default).
     *
     * @param latents latent tensor to decode; must not be {@code null}
     * @return the decoded image
     */
    public final Bitmap decode(LocalDiffusionTensor<?> latents) {
        Intrinsics.checkNotNullParameter(latents, "latents");
        float[] latentValues = latents.getTensor().getFloatBuffer().array();
        Intrinsics.checkNotNullExpressionValue(latentValues, "latents.tensor.floatBuffer.array()");
        LocalDiffusionTensor<?> scaledLatents = TensorExtensionsKt.multipleTensorsByFloat(latentValues, VAE_LATENT_SCALE, latents.getShape());
        Map<String, OnnxTensor> decoderInput = new HashMap<>();
        decoderInput.put(LocalDiffusionContract.KEY_LATENT_SAMPLE, scaledLatents.getTensor());
        Object decoded = getDecoder().decode(MapsKt.toMap(decoderInput));
        // NOTE(review): a second decoder instance is created for the conversion, as in the
        // original; reusing the first would need a look at VaeDecoder's lifecycle.
        VaeDecoder imageConverter = getDecoder();
        Intrinsics.checkNotNull(decoded, "null cannot be cast to non-null type kotlin.Array<kotlin.Array<kotlin.Array<kotlin.FloatArray>>>");
        return imageConverter.convertToImage((float[][][][]) decoded, this.width, this.height);
    }

    /**
     * Runs the full denoising loop and reports the decoded image through the callback.
     *
     * <p>Per step: the current latents are duplicated to a leading dimension of 2, scaled by
     * the scheduler, run through the UNet session, split into two predictions, combined with
     * {@link #performGuidance}, and fed to the scheduler's {@code step} to get the next
     * latents. After the loop the session is closed via {@link #close()} and, if a callback
     * is set, it receives a final {@code onStep} and the decoded bitmap.
     *
     * <p>{@link #initialize()} must have been called first; otherwise the null check on the
     * session fails on the first step.
     *
     * @param seedNum           seed for the initial latents; values {@code <= 0} select a random seed
     * @param numInferenceSteps number of steps requested from the scheduler
     * @param textEmbeddings    encoder hidden states passed to every UNet run; not closed here
     * @param guidanceScale     guidance weight
     * @param batchSize         leading dimension of the initial latent tensor
     * @param width             output width in pixels, remembered for {@link #decode}
     * @param height            output height in pixels, remembered for {@link #decode}
     */
    public final void inference(long seedNum, int numInferenceSteps, OnnxTensor textEmbeddings, double guidanceScale, int batchSize, int width, int height) {
        Intrinsics.checkNotNullParameter(textEmbeddings, "textEmbeddings");
        this.width = width;
        this.height = height;
        EulerAncestralDiscreteLocalDiffusionScheduler scheduler = new EulerAncestralDiscreteLocalDiffusionScheduler(null, 1, null);
        int[] timeSteps = scheduler.setTimeSteps(numInferenceSteps);
        long seed = seedNum <= 0 ? this.random.nextLong() : seedNum;
        LocalDiffusionTensor<?> latents = generateLatentSample(batchSize, height, width, seed, (float) scheduler.getInitNoiseSigma());
        long latentHeight = height / LATENT_DOWNSCALE;
        long latentWidth = width / LATENT_DOWNSCALE;
        long[] duplicatedShape = {2, LATENT_CHANNELS, latentHeight, latentWidth};
        for (int step = 0; step < timeSteps.length; step++) {
            float[] latentValues = latents.getTensor().getFloatBuffer().array();
            Intrinsics.checkNotNullExpressionValue(latentValues, "latents.tensor.floatBuffer.array()");
            LocalDiffusionTensor<?> modelInput = scheduler.scaleModelInput(TensorExtensionsKt.duplicate(latentValues, duplicatedShape), step);
            Callback stepListener = this.callback;
            if (stepListener != null) {
                stepListener.onStep(timeSteps.length, step);
            }
            // NOTE(review): the scaled sample tensor and the previous step's latents are not
            // closed here because the scheduler may still reference them — confirm ownership.
            OnnxTensor sampleTensor = modelInput.getTensor();
            OrtEnvironment environment = this.ortEnvironmentProvider.getEnvironment();
            IntBuffer timeStepBuffer = IntBuffer.wrap(new int[]{timeSteps[step]});
            Object rawPrediction;
            // The time-step tensor is only needed for this run; release its native memory
            // even when the run or the null checks below throw.
            try (OnnxTensor timeStepTensor = OnnxTensor.createTensor(environment, timeStepBuffer, new long[]{1})) {
                Intrinsics.checkNotNullExpressionValue(timeStepTensor, "createTensor(\n          …ayOf(1)\n                )");
                Map<String, OnnxTensor> inputs = createUNetModelInput(textEmbeddings, sampleTensor, timeStepTensor);
                OrtSession ortSession = this.session;
                Intrinsics.checkNotNull(ortSession);
                try (OrtSession.Result result = ortSession.run(inputs)) {
                    rawPrediction = result.get(0).getValue();
                    Intrinsics.checkNotNull(rawPrediction, "null cannot be cast to non-null type kotlin.Array<kotlin.Array<kotlin.Array<kotlin.FloatArray>>>");
                }
            }
            Pair<float[][][][], float[][][][]> halves = TensorExtensionsKt.splitTensor((float[][][][]) rawPrediction, new long[]{1, LATENT_CHANNELS, latentHeight, latentWidth});
            float[][][][] noisePrediction = halves.first;
            float[][][][] noisePredictionText = halves.second;
            Intrinsics.checkNotNullExpressionValue(noisePrediction, "noisePrediction");
            Intrinsics.checkNotNullExpressionValue(noisePredictionText, "noisePredictionText");
            performGuidance(noisePrediction, noisePredictionText, guidanceScale);
            OnnxTensor guidedTensor = OnnxTensor.createTensor(this.ortEnvironmentProvider.getEnvironment(), noisePrediction);
            Intrinsics.checkNotNullExpressionValue(guidedTensor, "createTensor(\n          …on,\n                    )");
            latents = scheduler.step(new LocalDiffusionTensor<>(guidedTensor, noisePrediction, ArrayExtensionsKt.getSizes(noisePrediction)), step, latents);
        }
        close();
        Callback doneListener = this.callback;
        if (doneListener != null) {
            doneListener.onStep(timeSteps.length, timeSteps.length);
            doneListener.onBuildImage(0, decode(latents));
        }
    }

    /**
     * Creates the UNet session from {@code <localModelDir>/unet/model.ort} if none exists yet;
     * does nothing when a session is already present. NNAPI (with its CPU fallback disabled)
     * is added when the device flag equals {@link LocalDiffusionFlag#NN_API}.
     */
    public final void initialize() {
        if (this.session != null) {
            return;
        }
        // SessionOptions owns native memory; it is only needed while the session is created.
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
            sessionOptions.addConfigEntry(LocalDiffusionContract.ORT_KEY_MODEL_FORMAT, LocalDiffusionContract.ORT);
            if (this.deviceNNAPIFlagProvider.get() == LocalDiffusionFlag.NN_API.getValue()) {
                sessionOptions.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED));
            }
            this.session = this.ortEnvironmentProvider.getEnvironment().createSession(this.fileProviderDescriptor.getLocalModelDirPath() + "/unet/model.ort", sessionOptions);
        }
    }

    /**
     * Sets (or clears, with {@code null}) the progress listener used by {@link #inference}.
     *
     * @param callback listener to notify, or {@code null} for none
     */
    public final void setCallback(Callback callback) {
        this.callback = callback;
    }
}
