/*
 * Decompiled with CFR 0.152.
 */
package org.languagetool.languagemodel.bert;

import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NegotiationType;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
import org.jetbrains.annotations.Nullable;
import org.languagetool.languagemodel.bert.grpc.BertLmGrpc;
import org.languagetool.languagemodel.bert.grpc.BertLmProto;

/**
 * Client for a remote BERT language-model service reached over gRPC.
 *
 * <p>One {@link ManagedChannel} is opened in the constructor. The blocking stub uses that same
 * channel for every call. Call {@link #shutdown()} to release it when the model is no longer
 * needed.
 */
public class RemoteLanguageModel {
    private final BertLmGrpc.BertLmBlockingStub model;
    private final ManagedChannel channel;

    /**
     * Opens a channel to {@code host:port} and creates a blocking stub on it.
     *
     * @param host server host name
     * @param port server port
     * @param useSSL if true, negotiate TLS; otherwise use a plaintext connection
     * @param clientPrivateKey path to the client private key file. Used only when {@code useSSL}
     *     is true and {@code clientCertificate} is also non-null.
     * @param clientCertificate path to the client certificate file. Used only when {@code useSSL}
     *     is true and {@code clientPrivateKey} is also non-null.
     * @param rootCertificate path to a trusted root certificate file. Used only when
     *     {@code useSSL} is true; if null, the builder's default trust configuration is kept.
     * @throws SSLException if the TLS context cannot be built
     */
    public RemoteLanguageModel(String host, int port, boolean useSSL, @Nullable String clientPrivateKey, @Nullable String clientCertificate, @Nullable String rootCertificate) throws SSLException {
        // Build exactly one channel. The stub must run on the same channel that shutdown()
        // closes; previously two channels were created and the stub's one was never released.
        this.channel = getChannel(host, port, useSSL, clientPrivateKey, clientCertificate, rootCertificate);
        this.model = BertLmGrpc.newBlockingStub(this.channel);
    }

    /**
     * Builds a Netty-based channel to {@code host:port}.
     *
     * <p>With {@code useSSL}, TLS is configured as follows:
     * <ul>
     *   <li>{@code rootCertificate} is installed as the trust manager when given.</li>
     *   <li>A client key manager is configured only when both the certificate and the private key
     *       are given.</li>
     * </ul>
     * Without {@code useSSL}, the channel is plaintext.
     *
     * @throws SSLException if the TLS context cannot be built
     */
    private ManagedChannel getChannel(String host, int port, boolean useSSL, @Nullable String clientPrivateKey, @Nullable String clientCertificate, @Nullable String rootCertificate) throws SSLException {
        NettyChannelBuilder channelBuilder = NettyChannelBuilder.forAddress(host, port);
        if (useSSL) {
            SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
            if (rootCertificate != null) {
                sslContextBuilder.trustManager(new File(rootCertificate));
            }
            // Client authentication needs both halves of the key pair; configure it only then.
            if (clientCertificate != null && clientPrivateKey != null) {
                sslContextBuilder.keyManager(new File(clientCertificate), new File(clientPrivateKey));
            }
            channelBuilder = channelBuilder.negotiationType(NegotiationType.TLS).sslContext(sslContextBuilder.build());
        } else {
            channelBuilder = channelBuilder.usePlaintext();
        }
        return channelBuilder.build();
    }

    /**
     * Shuts down the underlying channel with {@code shutdownNow()}.
     *
     * <p>This does not wait for termination.
     */
    public void shutdown() {
        if (this.channel != null) {
            this.channel.shutdownNow();
        }
    }

    /**
     * Scores several requests in one RPC.
     *
     * <p>For each response, returns the score list of its first score entry. Presumably the
     * responses are in the same order as {@code requests}; verify this against the server
     * contract.
     *
     * @param requests requests to score; each is converted with {@link Request#convert()}
     * @return one score list per response returned by the server
     */
    public List<List<Double>> batchScore(List<Request> requests) {
        BertLmProto.BatchScoreRequest batch = BertLmProto.BatchScoreRequest.newBuilder()
                .addAllRequests(requests.stream().map(Request::convert).collect(Collectors.toList()))
                .build();
        return this.model.batchScore(batch).getResponsesList().stream()
                .map(r -> r.getScoresList().get(0).getScoreList())
                .collect(Collectors.toList());
    }

    /**
     * Scores a single request.
     *
     * @param req the request; it is sent with a single mask (see {@link Request#convert()})
     * @return the score list of the first score entry in the response. Presumably this holds one
     *     value per candidate; confirm against the server.
     */
    public List<Double> score(Request req) {
        return this.model.score(req.convert()).getScoresList().get(0).getScoreList();
    }

    /** A text with one span (start/end offsets) and the candidate strings to score for that span. */
    public static class Request {
        public String text;
        // NOTE(review): whether `end` is inclusive or exclusive is defined by the server — confirm.
        public int start;
        public int end;
        public List<String> candidates;

        public Request(String text, int start, int end, List<String> candidates) {
            this.text = text;
            this.start = start;
            this.end = end;
            this.candidates = candidates;
        }

        /**
         * Converts this request to its protobuf form: the text plus exactly one mask carrying
         * {@code start}, {@code end} and {@code candidates}.
         */
        public BertLmProto.ScoreRequest convert() {
            List<BertLmProto.Mask> masks = Arrays.asList(BertLmProto.Mask.newBuilder()
                    .setStart(this.start)
                    .setEnd(this.end)
                    .addAllCandidates(this.candidates)
                    .build());
            return BertLmProto.ScoreRequest.newBuilder().setText(this.text).addAllMask(masks).build();
        }
    }
}

