package com.tong.servlet; import com.tong.annaotation.*; import com.tong.controller.TongController; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.File; import java.io.IOException; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URL; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; public class DispatcherServlet extends HttpServlet { List classNames = new ArrayList(); Map beans = new HashMap(); Map handlerMap = new HashMap(); // private static final long serialVersionUID = 1L; public void init(ServletConfig config) throws ServletException { //把所有的bean扫描 ----扫描所有的class文件 scanPackage("com.tong"); doInstance(); //根据全类名创建bean doIoc(); //根据bean进行依赖注入 buildUrlMapping(); //建立映射关系 } private void scanPackage(String basePackage) { URL url = this.getClass().getClassLoader().getResource("/"+basePackage.replaceAll("\\.", "/")); String fileStr = url.getFile(); File file = new File(fileStr); String[] filesStr = file.list(); for (String path : filesStr) { File filePath = new File(fileStr+path); if (filePath.isDirectory()) { scanPackage(basePackage+"."+path); }else { //加入list classNames.add(basePackage+"."+filePath.getName()); } } } //根据扫描的list全类名,进行实例化 private void doInstance() { if (classNames.size() <= 0) { System.out.println("包扫描失败....."); return; } //遍历list的所有Class类 for (String className : classNames) { String cn = className.replace(".class", ""); try { Class clazz = Class.forName(cn); if (clazz.isAnnotationPresent(EnjoyController.class)) { Object instance = clazz.newInstance();//创建控制类 EnjoyRequestMapping requestMapping = clazz.getAnnotation(EnjoyRequestMapping.class); String rmvalue = requestMapping.value(); beans.put(rmvalue, instance); }else if (clazz.isAnnotationPresent(EnjoyService.class)) { EnjoyService service = clazz.getAnnotation(EnjoyService.class); Object instance = clazz.newInstance(); beans.put(service.value(), instance); }else { continue; } } catch (ClassNotFoundException e) { e.printStackTrace(); } catch (InstantiationException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } } } //把service注入到controller public void doIoc() { if (beans.entrySet().size() <= 0) { System.out.println("没有一个实例化类"); } //把map里所有的实例化遍历出来 for (Map.Entry entry: beans.entrySet()) { Object instance = entry.getValue(); Class clazz = instance.getClass(); if (clazz.isAnnotationPresent(EnjoyController.class)) { Field[] fields = clazz.getDeclaredFields(); for (Field field : fields) { if (field.isAnnotationPresent(EnjoyAutowired.class)) { EnjoyAutowired auto = field.getAnnotation(EnjoyAutowired.class); String key = auto.value(); field.setAccessible(true); try { field.set(instance, beans.get(key)); } catch (IllegalAccessException e) { e.printStackTrace(); } }else { continue; } } }else { continue; } } } private void buildUrlMapping() { if (beans.entrySet().size() <= 0) { System.out.println("没有实例化____"); return; } for (Map.Entry entry: beans.entrySet()) { Object instance = entry.getValue(); Class clazz = instance.getClass(); if (clazz.isAnnotationPresent(EnjoyController.class)) { EnjoyRequestMapping requestMapping = clazz.getAnnotation(EnjoyRequestMapping.class); String classPath = requestMapping.value(); Method[] methods = clazz.getMethods(); for (Method method : methods) { if (method.isAnnotationPresent(EnjoyRequestMapping.class)) { EnjoyRequestMapping methodMapping = method.getAnnotation(EnjoyRequestMapping.class); String methodPath = methodMapping.value(); handlerMap.put(classPath+methodPath, method); }else { continue; } } }else { continue; } } } @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { this.doPost(req, resp); } @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { // 获取请求路径 String uri = req.getRequestURI(); String context = req.getContextPath(); String path = uri.replace(context, ""); Method method = (Method) handlerMap.get(path); TongController instance = (TongController) beans.get("/"+path.split("/")[1]); Object[] args = hand(req, resp, method); try { method.invoke(instance, args); } catch (IllegalAccessException e) { e.printStackTrace(); } catch (InvocationTargetException e) { e.printStackTrace(); } } private static Object[] hand(HttpServletRequest request, HttpServletResponse response, Method method) { //拿到待执行的方法有哪些参数 Class[] paramClazzs = method.getParameterTypes(); //根据参数的个数,new 一个参数的数组,将方法的所有参数赋值到args Object[] args = new Object[paramClazzs.length]; int args_i = 0; int index = 0; for (Class paramClazz : paramClazzs) { if (ServletRequest.class.isAssignableFrom(paramClazz)) { args[args_i++] = request; } if (ServletResponse.class.isAssignableFrom(paramClazz)) { args[args_i++] = response; } Annotation[] paramAns = method.getParameterAnnotations()[index]; if (paramAns.length > 0) { for (Annotation paramAn : paramAns) { if (EnjoyRequestParam.class.isAssignableFrom(paramAn.getClass())) { EnjoyRequestParam rp = (EnjoyRequestParam) paramAn; args[args_i++] = request.getParameter(rp.value()); } } } index++; } return args; } }