package cn.axj.crypt.handler;


import cn.axj.crypt.annotation.Decryption;
import cn.axj.crypt.config.CryptProperties;
import cn.axj.crypt.constant.AlgorithmConstant;
import cn.axj.crypt.exception.BindingException;
import cn.axj.crypt.exception.DecryptErrorException;
import cn.axj.crypt.exception.EncryptionStringNameIsNullException;
import cn.axj.crypt.exception.ParameterBodyIsNullException;
import cn.axj.crypt.secret.BaseAlgorithmHttpInputMessage;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.MethodParameter;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.stereotype.Controller;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.servlet.mvc.method.annotation.RequestBodyAdvice;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * Request-body advice that decrypts one encrypted field of a JSON request body before Spring's
 * message converter binds it to a {@code @RequestBody} parameter.
 *
 * <p>Only handler methods annotated with {@link Decryption} whose parameter carries
 * {@link RequestBody} are intercepted, and only while
 * {@link CryptProperties#isParameterDecryption()} returns true. The algorithm is the registered
 * {@link BaseAlgorithmHttpInputMessage} whose {@code getBinding()} equals {@link Decryption#bind()}.
 *
 * @author aoxiaojun
 * @date 2020/12/10 17:21
 **/
@ControllerAdvice(annotations = Controller.class)
@Slf4j
public class ParameterDecryptionAdvice implements RequestBodyAdvice {

    /** Crypt configuration: the feature switch and the charset used to decode/encode bodies. */
    @Resource
    private CryptProperties cryptProperties;

    /** Request handle used to obtain the session id for {@code AlgorithmConstant.cacheRequest}. */
    private final HttpServletRequest request;

    /** Registered decryption algorithms; {@link #init} guarantees their bindings are unique. */
    private final List<BaseAlgorithmHttpInputMessage> decryptionList = new ArrayList<>(8);

    /**
     * Creates the advice and registers the available decryption algorithms.
     *
     * @param list    all algorithm implementations; their bindings must be unique
     * @param request the HTTP request handle (presumably Spring's proxy for the current request)
     * @throws BindingException if two implementations report the same binding
     */
    public ParameterDecryptionAdvice(List<BaseAlgorithmHttpInputMessage> list, HttpServletRequest request) throws BindingException {
        init(list);
        this.request = request;
    }

    /**
     * Verifies that every algorithm binding is unique, then registers all algorithms.
     *
     * @param list algorithm implementations to register
     * @throws BindingException if two implementations report the same binding
     */
    private void init(List<BaseAlgorithmHttpInputMessage> list) throws BindingException {
        // A set gives an O(n) duplicate check (and tolerates a null binding) instead of comparing
        // every pair with temp.equals(...).
        Set<String> seenBindings = new HashSet<>();
        for (BaseAlgorithmHttpInputMessage algorithm : list) {
            String binding = algorithm.getBinding();
            if (!seenBindings.add(binding)) {
                throw new BindingException("The algorithm implements binding value is repeat!please check it exists two value of " + binding);
            }
        }
        this.decryptionList.addAll(list);
        log.info("解密类完成初始化加载........");
    }

    /**
     * Decides whether the body for this parameter must be decrypted.
     *
     * @return {@code true} only when parameter decryption is enabled, the handler method carries
     *     {@link Decryption}, and the parameter carries {@link RequestBody}
     */
    @Override
    public boolean supports(MethodParameter methodParameter, Type targetType, Class<? extends HttpMessageConverter<?>> converterType) {
        if (!cryptProperties.isParameterDecryption()) {
            return false;
        }
        Decryption decryption = methodParameter.getMethodAnnotation(Decryption.class);
        RequestBody requestBody = methodParameter.getParameterAnnotation(RequestBody.class);
        // Runs for every request body: keep it at debug with placeholders so it is cheap and quiet.
        log.debug("是否拥有解密注解:{}。是否拥有requestBody：{}", decryption, requestBody);
        boolean flag = Objects.nonNull(decryption) && Objects.nonNull(requestBody);
        if (flag) {
            log.debug("需要进行解密操作........");
        } else {
            log.debug("不需要进行解密操作........");
        }
        return flag;
    }

    /**
     * Replaces the request body with the decrypted content of the configured JSON field.
     *
     * <p>The body is parsed as a JSON object. The field named by
     * {@link Decryption#encryptionStringName()} is decrypted with the algorithm bound to
     * {@link Decryption#bind()}, and the decrypted string becomes the new body.
     *
     * @return a message carrying the decrypted body; or, when the field is missing/empty and
     *     {@link Decryption#required()} is false, a message carrying the original body text
     * @throws EncryptionStringNameIsNullException if the annotation's field name is empty
     * @throws ParameterBodyIsNullException if the body stream is null, or the field is
     *     missing/empty while required
     * @throws BindingException if no registered algorithm matches {@code bind}
     * @throws DecryptErrorException if decryption fails or the result message cannot be created
     * @throws IOException if the body cannot be read or re-encoded
     */
    @Override
    public HttpInputMessage beforeBodyRead(HttpInputMessage inputMessage, MethodParameter parameter, Type targetType, Class<? extends HttpMessageConverter<?>> converterType) throws RuntimeException, IOException {
        Decryption decryption = parameter.getMethodAnnotation(Decryption.class);
        // supports() only returns true when this annotation is present.
        assert decryption != null;
        String simpleName = parameter.getContainingClass().getSimpleName();
        String methodName = Objects.requireNonNull(parameter.getMethod()).getName();

        String bind = decryption.bind();
        String encryptionStringName = decryption.encryptionStringName();
        boolean required = decryption.required();
        if (!StringUtils.hasLength(encryptionStringName)) {
            throw new EncryptionStringNameIsNullException("The annotation Decryption‘s encryptionStringName is empty,please check the method annotation,target method is "+simpleName+"."+methodName);
        }

        InputStream body = inputMessage.getBody();
        if (body == null) {
            throw new ParameterBodyIsNullException("request body is null");
        }

        BaseAlgorithmHttpInputMessage baseDecryption = findDecryption(bind);
        if (baseDecryption == null) {
            throw new BindingException("Unable to find a binding object of decryption,targetClass is "+simpleName+",method is "+methodName);
        }

        String encryptData = IOUtils.toString(body, cryptProperties.getCharset());
        // fastjson yields null for an empty/blank body; treat that as "field missing" rather than NPE.
        JSONObject jsonObject = JSONObject.parseObject(encryptData);
        String encryptString = jsonObject == null ? null : jsonObject.getString(encryptionStringName);
        if (!StringUtils.hasLength(encryptString)) {
            // The encrypted field is required but was not sent.
            if (required) {
                throw new ParameterBodyIsNullException("request body is null");
            }
            // The field is optional and was not sent. The original stream has already been consumed
            // above, so hand the converter a fresh message over the buffered body text.
            return rebuild(baseDecryption, inputMessage, encryptData, bind);
        }

        // The original comment here read "needs encryption".
        // NOTE(review): presumably this flags the session so the response side encrypts too; confirm.
        // getSession() creates a session if none exists, and concurrent requests in one session
        // share this key.
        AlgorithmConstant.cacheRequest.put(request.getSession().getId(), true);
        String decrypt;
        try {
            decrypt = baseDecryption.decrypt(encryptString);
        } catch (Exception e) {
            log.error("decrypt error, current algorithm is {}", bind, e);
            throw new DecryptErrorException("decrypt error for String " + encryptString+" ,current algorithm is "+ bind);
        }
        // Plaintext may be sensitive: only emit it at debug level.
        log.debug("解密后的字符串信息........{}", decrypt);
        return rebuild(baseDecryption, inputMessage, decrypt, bind);
    }

    /**
     * Returns the registered algorithm whose binding equals {@code bind}, or {@code null} if none.
     */
    private BaseAlgorithmHttpInputMessage findDecryption(String bind) {
        for (BaseAlgorithmHttpInputMessage candidate : decryptionList) {
            // bind comes from an annotation and is never null; getBinding() might be.
            if (bind.equals(candidate.getBinding())) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Creates a new message of the same runtime type as {@code prototype}, carrying the original
     * headers and {@code content} encoded with the configured charset as its body.
     *
     * @throws DecryptErrorException if the message type cannot be instantiated reflectively
     * @throws IOException if {@code content} cannot be encoded
     */
    private HttpInputMessage rebuild(BaseAlgorithmHttpInputMessage prototype, HttpInputMessage original, String content, String bind) throws IOException {
        BaseAlgorithmHttpInputMessage message;
        try {
            // Requires an accessible no-arg constructor on the algorithm implementation.
            message = prototype.getClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            log.error("cannot instantiate {} for algorithm {}", prototype.getClass().getName(), bind, e);
            throw new DecryptErrorException("cannot instantiate " + prototype.getClass().getName() + " ,current algorithm is " + bind);
        }
        message.setHeaders(original.getHeaders());
        message.setBody(IOUtils.toInputStream(content, cryptProperties.getCharset()));
        return message;
    }

    /** Returns the converted body unchanged; no post-processing is needed after decryption. */
    @Override
    public Object afterBodyRead(Object body, HttpInputMessage inputMessage, MethodParameter parameter, Type targetType, Class<? extends HttpMessageConverter<?>> converterType) {
        return body;
    }

    /** Returns the supplied value unchanged when the request body is empty. */
    @Override
    public Object handleEmptyBody(Object body, HttpInputMessage inputMessage, MethodParameter parameter, Type targetType, Class<? extends HttpMessageConverter<?>> converterType) {
        return body;
    }
}
