/*
 * Decompiled with CFR 0.152.
 */
package io.github.lnyocly.ai4j.utils;

import com.alibaba.fastjson2.JSON;
import io.github.lnyocly.ai4j.annotation.FunctionCall;
import io.github.lnyocly.ai4j.annotation.FunctionParameter;
import io.github.lnyocly.ai4j.annotation.FunctionRequest;
import io.github.lnyocly.ai4j.platform.openai.tool.Tool;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.reflections.Configuration;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;
import org.reflections.scanners.Scanners;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that discover classes annotated with {@link FunctionCall},
 * describe them as {@link Tool} definitions and invoke them by name.
 *
 * <p>Discovery scans the classpath through {@link Reflections}. Resolved tools,
 * their implementing classes and their request classes are cached in the
 * public static maps below, keyed by function name.
 */
public class ToolUtil {
    private static final Logger log = LoggerFactory.getLogger(ToolUtil.class);

    /** Classpath scanner restricted to type-level annotations, rooted at the empty package prefix. */
    static Reflections reflections = new Reflections(
            new ConfigurationBuilder()
                    .setUrls(ClasspathHelper.forPackage(""))
                    .setScanners(Scanners.TypesAnnotated));

    /** Function name -> resolved tool definition. */
    public static Map<String, Tool> toolEntityMap = new ConcurrentHashMap<>();
    /** Function name -> class annotated with {@link FunctionCall}. */
    public static Map<String, Class<?>> toolClassMap = new ConcurrentHashMap<>();
    /** Function name -> nested class annotated with {@link FunctionRequest}. */
    public static Map<String, Class<?>> toolRequestMap = new ConcurrentHashMap<>();

    /**
     * Invokes the {@code apply} method of the function registered under
     * {@code functionName}, passing {@code argument} deserialized into the
     * function's request class, and returns the result serialized as JSON.
     *
     * <p>If the function has not been resolved yet, it is looked up on the
     * classpath first so callers are not required to call
     * {@link #getAllFunctionTools(List)} beforehand.
     *
     * @param functionName name declared in {@link FunctionCall#name()}
     * @param argument     JSON text of the request object
     * @return JSON text of the value returned by {@code apply}
     * @throws IllegalArgumentException if {@code functionName} is null or no
     *                                  matching function class is found
     * @throws IllegalStateException    if the function class has no request class
     * @throws RuntimeException         wrapping any failure during reflection,
     *                                  parsing or the call itself
     */
    public static String invoke(String functionName, String argument) {
        if (functionName == null) {
            throw new IllegalArgumentException("functionName must not be null");
        }
        long startMillis = System.currentTimeMillis();
        log.info("tool call function {}, argument {}", functionName, argument);

        Class<?> functionClass = toolClassMap.get(functionName);
        if (functionClass == null) {
            // Not resolved yet: scanning registers both the class and its request type.
            ToolUtil.getFunctionEntity(functionName);
            functionClass = toolClassMap.get(functionName);
        }
        if (functionClass == null) {
            throw new IllegalArgumentException("No @FunctionCall class found for function '" + functionName + "'");
        }
        Class<?> functionRequestClass = toolRequestMap.get(functionName);
        if (functionRequestClass == null) {
            throw new IllegalStateException("Function '" + functionName + "' ("
                    + functionClass.getName() + ") declares no @FunctionRequest nested class");
        }

        try {
            Method apply = functionClass.getMethod("apply", functionRequestClass);
            Object request = JSON.parseObject(argument, functionRequestClass);
            // Class.newInstance() is deprecated; go through the no-arg constructor instead.
            Object target = functionClass.getDeclaredConstructor().newInstance();
            Object result = apply.invoke(target, request);
            String response = JSON.toJSONString(result);
            log.info("response {}, cost {} ms", response, System.currentTimeMillis() - startMillis);
            return response;
        }
        catch (Exception e) {
            throw new RuntimeException("Failed to invoke tool function '" + functionName + "'", e);
        }
    }

    /**
     * Resolves the tool definitions for the given function names, using and
     * filling {@link #toolEntityMap} as a cache. Names that cannot be resolved
     * (and null names) are skipped.
     *
     * @param functionList function names to resolve; may be null
     * @return the resolved tools, or {@code null} when none were resolved
     */
    public static List<Tool> getAllFunctionTools(List<String> functionList) {
        if (functionList == null) {
            return null;
        }
        List<Tool> tools = new ArrayList<>();
        for (String functionName : functionList) {
            // ConcurrentHashMap rejects null keys, so a null name can never resolve.
            if (functionName == null) {
                continue;
            }
            Tool tool = toolEntityMap.get(functionName);
            if (tool == null) {
                tool = ToolUtil.getToolEntity(functionName);
                if (tool == null) {
                    continue;
                }
                toolEntityMap.put(functionName, tool);
            }
            tools.add(tool);
        }
        // Existing contract: null rather than an empty list when nothing resolved.
        return tools.isEmpty() ? null : tools;
    }

    /**
     * Builds a {@link Tool} of type {@code "function"} for the given name.
     *
     * @param functionName name declared in {@link FunctionCall#name()}
     * @return the tool, or {@code null} if no matching function class exists
     */
    public static Tool getToolEntity(String functionName) {
        Tool.Function functionEntity = ToolUtil.getFunctionEntity(functionName);
        if (functionEntity == null) {
            return null;
        }
        Tool tool = new Tool();
        tool.setType("function");
        tool.setFunction(functionEntity);
        return tool;
    }

    /**
     * Scans for a class whose {@link FunctionCall#name()} equals
     * {@code functionName} and describes it as a {@link Tool.Function}.
     * On a match the class is recorded in {@link #toolClassMap} and its request
     * class (if any) in {@link #toolRequestMap}.
     *
     * @param functionName name to look for
     * @return the function description, or {@code null} if nothing matches
     */
    public static Tool.Function getFunctionEntity(String functionName) {
        Set<Class<?>> functionSet = reflections.getTypesAnnotatedWith(FunctionCall.class);
        for (Class<?> functionClass : functionSet) {
            FunctionCall functionCall = functionClass.getAnnotation(FunctionCall.class);
            if (functionCall == null) {
                continue;
            }
            String currentFunctionName = functionCall.name();
            if (!currentFunctionName.equals(functionName)) {
                continue;
            }
            Tool.Function function = new Tool.Function();
            function.setName(currentFunctionName);
            function.setDescription(functionCall.description());
            ToolUtil.setFunctionParameters(function, functionClass);
            toolClassMap.put(functionName, functionClass);
            return function;
        }
        return null;
    }

    /**
     * Fills {@code function}'s parameter schema from the fields annotated with
     * {@link FunctionParameter} in every nested class of {@code functionClass}
     * that carries {@link FunctionRequest}. Each such nested class is also
     * stored in {@link #toolRequestMap} under the function's name.
     */
    private static void setFunctionParameters(Tool.Function function, Class<?> functionClass) {
        Map<String, Tool.Function.Property> parameters = new HashMap<>();
        List<String> requiredParameters = new ArrayList<>();
        for (Class<?> nested : functionClass.getDeclaredClasses()) {
            if (nested.getAnnotation(FunctionRequest.class) == null) {
                continue;
            }
            toolRequestMap.put(function.getName(), nested);
            for (Field field : nested.getDeclaredFields()) {
                FunctionParameter parameter = field.getAnnotation(FunctionParameter.class);
                if (parameter == null) {
                    continue;
                }
                Class<?> fieldType = field.getType();
                Tool.Function.Property property = new Tool.Function.Property();
                property.setType(ToolUtil.mapJavaTypeToJsonSchemaType(fieldType));
                property.setDescription(parameter.description());
                if (fieldType.isEnum()) {
                    property.setEnumValues(ToolUtil.getEnumValues(fieldType));
                }
                parameters.put(field.getName(), property);
                if (parameter.required()) {
                    requiredParameters.add(field.getName());
                }
            }
        }
        function.setParameters(new Tool.Function.Parameter("object", parameters, requiredParameters));
    }

    /**
     * Maps a Java field type to a JSON Schema type name.
     *
     * <p>Enums, {@code String} and {@code char}/{@code Character} map to
     * {@code "string"}; every primitive numeric type and every {@link Number}
     * subclass maps to {@code "number"}; booleans to {@code "boolean"}; arrays
     * and collections to {@code "array"}; everything else, maps included, to
     * {@code "object"}.
     */
    private static String mapJavaTypeToJsonSchemaType(Class<?> fieldType) {
        if (fieldType.isEnum() || fieldType.equals(String.class)
                || fieldType.equals(Character.TYPE) || fieldType.equals(Character.class)) {
            return "string";
        }
        if (ToolUtil.isNumericType(fieldType)) {
            return "number";
        }
        if (fieldType.equals(Boolean.TYPE) || fieldType.equals(Boolean.class)) {
            return "boolean";
        }
        if (fieldType.isArray() || Collection.class.isAssignableFrom(fieldType)) {
            return "array";
        }
        return "object";
    }

    /** Returns true for primitive numeric types and for any {@link Number} subclass. */
    private static boolean isNumericType(Class<?> type) {
        if (Number.class.isAssignableFrom(type)) {
            return true;
        }
        return type.equals(Integer.TYPE) || type.equals(Long.TYPE) || type.equals(Short.TYPE)
                || type.equals(Byte.TYPE) || type.equals(Float.TYPE) || type.equals(Double.TYPE);
    }

    /**
     * Lists the {@code toString()} form of every constant of {@code enumType}.
     *
     * @return the constant strings; empty if {@code enumType} is not an enum
     */
    private static List<String> getEnumValues(Class<?> enumType) {
        List<String> enumValues = new ArrayList<>();
        Object[] constants = enumType.getEnumConstants();
        if (constants == null) {
            // getEnumConstants() yields null for non-enum types.
            return enumValues;
        }
        for (Object enumConstant : constants) {
            enumValues.add(enumConstant.toString());
        }
        return enumValues;
    }
}

