java实现人脸识别 您所在的位置:网站首页 人脸识别java开源框架下载 java实现人脸识别

java实现人脸识别

2024-05-11 21:39| 来源: 网络整理| 查看: 265

人脸识别定义

人脸识别系统主要包括四个组成部分，分别为：人脸图像采集及检测、人脸图像预处理、人脸图像特征提取，以及特征匹配与识别。

本文图片及学习代码来源,更多可前往 实现人脸识别流程

image.png

代码结构

image.png

代码实现

一、构建springboot项目添加pom依赖

<properties>
    <!-- NOTE(review): the extracted page kept only the property values. The names jna.version
         and djl.version are confirmed by the ${...} references below; the remaining names are
         the usual Spring Boot ones and are a reconstruction. -->
    <jna.version>5.6.0</jna.version>
    <djl.version>0.11.0</djl.version>
    <java.version>1.8</java.version>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <maven.compiler.encoding>UTF-8</maven.compiler.encoding>
</properties>

<!-- NOTE(review): the code below also uses Spring Boot and Lombok; those dependencies are
     presumably declared elsewhere in the full pom and are not part of this excerpt. -->
<dependencies>
    <dependency>
        <groupId>gov.nist.math</groupId>
        <artifactId>jama</artifactId>
        <version>1.0.3</version>
    </dependency>
    <dependency>
        <groupId>net.java.dev.jna</groupId>
        <artifactId>jna</artifactId>
        <version>${jna.version}</version>
    </dependency>
    <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>4.12</version>
        <scope>test</scope>
    </dependency>
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>opencv-platform</artifactId>
        <version>4.5.1-1.5.5</version>
    </dependency>
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>javacv</artifactId>
        <version>1.5.5</version>
    </dependency>
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>javacv-platform</artifactId>
        <version>1.5.5</version>
    </dependency>
    <!-- NOTE(review): fastjson 1.2.62 is affected by published autoType deserialization CVEs;
         upgrade (1.2.83+ or fastjson2) before parsing any untrusted JSON. -->
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
        <version>1.2.62</version>
    </dependency>
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-model-zoo</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-auto</artifactId>
        <version>1.8.1</version>
    </dependency>
</dependencies>

二、创建转换器

package com.xgc.aideep.face.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Translator for an anchor-based face detector: converts an {@link Image} into the network
 * input and decodes the raw outputs (box regressions, class scores, landmark regressions) into
 * {@link DetectedObjects} whose boxes are {@link Landmark}s carrying five key points.
 *
 * <p>Returned boxes are in coordinates normalized by the input size; key points are scaled back
 * to pixel coordinates.
 *
 * <p>NOTE(review): {@code width}/{@code height} are per-call state written by
 * {@link #processInput} and read by {@link #processOutput}, so a single instance must not serve
 * concurrent predictions.
 *
 * @author gc.x
 * @date 2022-04
 */
public class FaceDetectionTranslator implements Translator<Image, DetectedObjects> {

    private final double confThresh;
    private final double nmsThresh;
    private final int topK;
    private final double[] variance;
    private final int[][] scales;
    private final int[] steps;
    private int width;
    private int height;

    /**
     * @param confThresh minimum face score for a candidate to be kept
     * @param nmsThresh IoU above which a candidate overlapping an already accepted box is dropped
     * @param variance {xy scale, wh scale} used to decode the box/landmark regressions
     * @param topK maximum number of highest-scoring candidates considered for suppression
     * @param scales anchor sizes in pixels, one array per feature level
     * @param steps anchor stride in pixels, one value per feature level
     */
    public FaceDetectionTranslator(
            double confThresh,
            double nmsThresh,
            double[] variance,
            int topK,
            int[][] scales,
            int[] steps) {
        this.confThresh = confThresh;
        this.nmsThresh = nmsThresh;
        this.variance = variance;
        this.topK = topK;
        this.scales = scales;
        this.steps = steps;
    }

    /** {@inheritDoc} */
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        width = input.getWidth();
        height = input.getHeight();
        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
        // HWC -> CHW, then flip the channel axis: RGB -> BGR
        array = array.transpose(2, 0, 1).flip(0);
        // The network takes float32 by default
        if (!array.getDataType().equals(DataType.FLOAT32)) {
            array = array.toType(DataType.FLOAT32, false);
        }
        // Subtract the per-channel mean (in B, G, R order after the flip above)
        NDArray mean =
                ctx.getNDManager().create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));
        array = array.sub(mean);
        return new NDList(array);
    }

    /** {@inheritDoc} */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDManager manager = ctx.getNDManager();
        double scaleXY = variance[0];
        double scaleWH = variance[1];

        // Drop score column 0 (presumably background) and keep, per anchor, the best class id
        // (row 0) and its score (row 1).
        NDArray prob = list.get(1).get(":, 1:");
        prob = NDArrays.stack(
                new NDList(
                        prob.argMax(1).toType(DataType.FLOAT32, false),
                        prob.max(new int[] {1})));

        // Decode box regressions against the anchors into (x, y, w, h)
        NDArray boxRecover = boxRecover(manager, width, height, scales, steps);
        NDArray boundingBoxes = list.get(0);
        NDArray bbWH = boundingBoxes.get(":, 2:").mul(scaleWH).exp().mul(boxRecover.get(":, 2:"));
        NDArray bbXY = boundingBoxes
                .get(":, :2")
                .mul(scaleXY)
                .mul(boxRecover.get(":, 2:"))
                .add(boxRecover.get(":, :2"))
                .sub(bbWH.mul(0.5f));
        boundingBoxes = NDArrays.concat(new NDList(bbXY, bbWH), 1);

        NDArray landms = decodeLandm(list.get(2), boxRecover, scaleXY);

        // Filter out candidates below the confidence threshold
        NDArray cutOff = prob.get(1).gt(confThresh);
        boundingBoxes = boundingBoxes.transpose().booleanMask(cutOff, 1).transpose();
        landms = landms.transpose().booleanMask(cutOff, 1).transpose();
        prob = prob.booleanMask(cutOff, 1);

        // argSort() is ascending, so the topK BEST candidates are the tail of the order.
        // Slicing the head (":" + topK) would keep the worst ones and silently drop the best
        // faces whenever more than topK candidates survive the threshold.
        long numCandidates = prob.get(1).size();
        long start = Math.max(0L, numCandidates - topK);
        long[] order = prob.get(1).argSort().get(start + ":").toLongArray();
        prob = prob.transpose();

        List<String> retNames = new ArrayList<>();
        List<Double> retProbs = new ArrayList<>();
        List<BoundingBox> retBB = new ArrayList<>();
        // Accepted boxes per class id, for greedy non-maximum suppression
        Map<Integer, List<BoundingBox>> recorder = new HashMap<>();

        // Walk from the highest score down
        for (int i = order.length - 1; i >= 0; i--) {
            long currMaxLoc = order[i];
            float[] classProb = prob.get(currMaxLoc).toFloatArray();
            int classId = (int) classProb[0];
            double probability = classProb[1];

            double[] boxArr = boundingBoxes.get(currMaxLoc).toDoubleArray();
            double[] landmsArr = landms.get(currMaxLoc).toDoubleArray();
            Rectangle rect = new Rectangle(boxArr[0], boxArr[1], boxArr[2], boxArr[3]);
            List<BoundingBox> boxes = recorder.computeIfAbsent(classId, k -> new ArrayList<>());

            boolean belowIoU = true;
            for (BoundingBox box : boxes) {
                if (box.getIoU(rect) > nmsThresh) {
                    belowIoU = false;
                    break;
                }
            }
            if (!belowIoU) {
                continue;
            }

            // 5 face landmarks, scaled from normalized to pixel coordinates
            List<Point> keyPoints = new ArrayList<>();
            for (int j = 0; j < 5; j++) {
                double x = landmsArr[j * 2];
                double y = landmsArr[j * 2 + 1];
                keyPoints.add(new Point(x * width, y * height));
            }
            Landmark landmark =
                    new Landmark(boxArr[0], boxArr[1], boxArr[2], boxArr[3], keyPoints);
            boxes.add(landmark);

            String className = "Face"; // classes.get(classId)
            retNames.add(className);
            retProbs.add(probability);
            retBB.add(landmark);
        }
        return new DetectedObjects(retNames, retProbs, retBB);
    }

    /**
     * Generates the anchor boxes for an input of the given size: one row per anchor in
     * (centerX, centerY, width, height) form, normalized by the image size and clipped to [0, 1].
     */
    private NDArray boxRecover(
            NDManager manager, int width, int height, int[][] scales, int[] steps) {
        // Feature-map grid size {rows, cols} for each stride
        int[][] gridSize = new int[steps.length][2];
        for (int i = 0; i < steps.length; i++) {
            int cols = (int) Math.ceil((float) width / steps[i]);
            int rows = (int) Math.ceil((float) height / steps[i]);
            gridSize[i] = new int[] {rows, cols};
        }

        List<double[]> defaultBoxes = new ArrayList<>();
        for (int idx = 0; idx < steps.length; idx++) {
            int[] scale = scales[idx];
            for (int h = 0; h < gridSize[idx][0]; h++) {
                for (int w = 0; w < gridSize[idx][1]; w++) {
                    for (int s : scale) {
                        double skx = s * 1.0 / width;
                        double sky = s * 1.0 / height;
                        double cx = (w + 0.5) * steps[idx] / width;
                        double cy = (h + 0.5) * steps[idx] / height;
                        defaultBoxes.add(new double[] {cx, cy, skx, sky});
                    }
                }
            }
        }

        double[][] boxes = defaultBoxes.toArray(new double[0][]);
        return manager.create(boxes).clip(0.0, 1.0);
    }

    /** Decodes the face landmark regressions (5 (x, y) points per anchor) against the anchors. */
    private NDArray decodeLandm(NDArray pre, NDArray priors, double scaleXY) {
        NDArray priorWH = priors.get(":, 2:");
        NDArray priorXY = priors.get(":, :2");
        NDList points = new NDList();
        for (int p = 0; p < 5; p++) {
            String columns = ":, " + (2 * p) + ":" + (2 * p + 2);
            points.add(pre.get(columns).mul(scaleXY).mul(priorWH).add(priorXY));
        }
        return NDArrays.concat(points, 1);
    }

    /** {@inheritDoc} */
    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}

package com.xgc.aideep.face.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

/**
 * Translator for the face feature (embedding) model: normalizes a face image and returns the
 * model output as a flat {@code float[]} feature vector.
 *
 * @author gc.x
 * @date 2022-04
 */
public final class FaceFeatureTranslator implements Translator<Image, float[]> {

    public FaceFeatureTranslator() {}

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
        // No resize here: the image is fed at the size the caller provides.
        // ToTensor scales to [0, 1]; Normalize with mean 127.5/255 and std 128/255 then maps
        // the pixel range to roughly [-1, 1].
        Pipeline pipeline = new Pipeline();
        pipeline.add(new ToTensor())
                .add(new Normalize(
                        new float[] {127.5f / 255.0f, 127.5f / 255.0f, 127.5f / 255.0f},
                        new float[] {128.0f / 255.0f, 128.0f / 255.0f, 128.0f / 255.0f}));
        return pipeline.transform(new NDList(array));
    }

    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        // The un-batched output is the 1-D embedding; flatten it directly.
        return list.singletonOrThrow().toFloatArray();
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}

三、加载模型

package com.xgc.aideep.face.model;

import ai.djl.MalformedModelException;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import com.xgc.aideep.face.translator.FaceDetectionTranslator;

import java.io.IOException;

/**
 * Holder for the face detection model ({@code models/ultranet.zip}, PyTorch engine).
 *
 * @author gc.x
 * @date 2022-04
 */
public final class FaceDetectionModel {

    private static final String MODEL_PATH = "models/ultranet.zip";

    private ZooModel<Image, DetectedObjects> model;

    /**
     * Loads the model.
     *
     * @param topK maximum number of candidates considered per image
     * @param confThresh minimum face score
     * @param nmsThresh IoU threshold for non-maximum suppression
     */
    public void init(int topK, double confThresh, double nmsThresh)
            throws MalformedModelException, ModelNotFoundException, IOException {
        this.model = ModelZoo.loadModel(detectCriteria(topK, confThresh, nmsThresh));
    }

    /** Returns the loaded model, or {@code null} before {@link #init} has been called. */
    public ZooModel<Image, DetectedObjects> getModel() {
        return model;
    }

    /** Releases the model; safe to call when it was never loaded or is already closed. */
    public void close() {
        if (model != null) {
            model.close();
            model = null;
        }
    }

    private Criteria<Image, DetectedObjects> detectCriteria(
            int topK, double confThresh, double nmsThresh) {
        // Double literals: float literals (0.1f) would store float-rounded values here.
        double[] variance = {0.1, 0.2};
        int[][] scales = {{10, 16, 24}, {32, 48}, {64, 96}, {128, 192, 256}};
        int[] steps = {8, 16, 32, 64};
        FaceDetectionTranslator translator =
                new FaceDetectionTranslator(confThresh, nmsThresh, variance, topK, scales, steps);

        return Criteria.builder()
                .setTypes(Image.class, DetectedObjects.class)
                .optModelUrls(MODEL_PATH)
                .optTranslator(translator)
                .optEngine("PyTorch") // Use PyTorch engine
                .optProgress(new ProgressBar())
                .build();
    }
}

package com.xgc.aideep.face.model;

import ai.djl.MalformedModelException;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import com.xgc.aideep.face.translator.FaceFeatureTranslator;

import java.io.IOException;

/**
 * Holder for the face feature (embedding) model ({@code models/face_feature.zip}, PyTorch).
 *
 * @author gc.x
 * @date 2022-04
 */
public final class FaceFeatureModel {

    private static final String MODEL_PATH = "models/face_feature.zip";

    private ZooModel<Image, float[]> model;

    /** Loads the model. */
    public void init() throws MalformedModelException, ModelNotFoundException, IOException {
        this.model = ModelZoo.loadModel(featureCriteria());
    }

    /** Returns the loaded model, or {@code null} before {@link #init} has been called. */
    public ZooModel<Image, float[]> getModel() {
        return model;
    }

    /** Releases the model; safe to call when it was never loaded or is already closed. */
    public void close() {
        if (model != null) {
            model.close();
            model = null;
        }
    }

    private Criteria<Image, float[]> featureCriteria() {
        return Criteria.builder()
                .setTypes(Image.class, float[].class)
                .optModelName("face_feature") // specify model file prefix
                .optModelUrls(MODEL_PATH)
                .optTranslator(new FaceFeatureTranslator())
                .optEngine("PyTorch") // Use PyTorch engine
                .optProgress(new ProgressBar())
                .build();
    }
}

四、目标检测

package com.xgc.aideep.face.service.impl;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import com.xgc.aideep.face.entity.FaceObject;
import com.xgc.aideep.face.model.FaceDetectionModel;
import com.xgc.aideep.face.service.FaceDetectService;
import com.xgc.aideep.face.service.FaceFeatureService;
import com.xgc.aideep.face.util.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacv.Java2DFrameUtils;
import org.bytedeco.opencv.opencv_core.Mat;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Face detection service: detects faces, crops each one and extracts its feature vector.
 *
 * @author gc.x
 * @date 2022-04
 */
@Slf4j
@RequiredArgsConstructor
@Service
public class DetectServiceImpl implements FaceDetectService {

    /** Maximum side length (pixels) of the aligned face crop passed to the feature model. */
    private static final int ALIGNED_FACE_SIZE = 112;

    @Autowired
    private FaceFeatureService featureService;

    @Autowired
    private FaceDetectionModel faceDetectionModel;

    /**
     * Detects faces in {@code image} and extracts a feature vector from each (unaligned) crop.
     *
     * @param image the image to search
     * @return one {@link FaceObject} per detected face, holding its feature and crop rectangle
     */
    public List<FaceObject> faceDetect(BufferedImage image)
            throws IOException, ModelException, TranslateException {
        ZooModel<Image, DetectedObjects> model = faceDetectionModel.getModel();
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            Image djlImg = ImageFactory.getInstance().fromImage(image);
            DetectedObjects detections = predictor.predict(djlImg);
            List<DetectedObjects.DetectedObject> list = detections.items();

            List<FaceObject> faceObjects = new ArrayList<>();
            for (DetectedObjects.DetectedObject detectedObject : list) {
                Rectangle rectangle = detectedObject.getBoundingBox().getBounds();
                // Crop the face region (factor 0f; see FaceUtil.getSubImageRect for its meaning)
                Rectangle subImageRect = FaceUtil.getSubImageRect(
                        image, rectangle, djlImg.getWidth(), djlImg.getHeight(), 0f);
                BufferedImage subImage = crop(image, subImageRect);
                Image img = DJLImageUtil.bufferedImage2DJLImage(subImage);

                // Extract the feature vector
                List<Float> feature = featureService.faceFeature(img);
                FaceObject faceObject = new FaceObject();
                faceObject.setFeature(feature);
                faceObject.setBoundingBox(subImageRect);
                faceObjects.add(faceObject);
            }
            return faceObjects;
        }
    }

    /**
     * Detects faces in {@code djlImg}, aligns each face with its five landmarks, crops the
     * aligned face to at most 112x112, and extracts a feature vector from it.
     *
     * <p>{@code djlImg} must wrap a {@link BufferedImage}; any other backing fails with a
     * {@link ClassCastException}.
     *
     * @param djlImg the image to search
     * @return one {@link FaceObject} per detected face, holding its feature, crop rectangle and
     *     detection score
     */
    public List<FaceObject> faceDetect(Image djlImg)
            throws TranslateException, ModelException, IOException {
        ZooModel<Image, DetectedObjects> model = faceDetectionModel.getModel();
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detections = predictor.predict(djlImg);
            List<DetectedObjects.DetectedObject> list = detections.items();
            BufferedImage image = (BufferedImage) djlImg.getWrappedImage();

            List<FaceObject> faceObjects = new ArrayList<>();
            for (DetectedObjects.DetectedObject detectedObject : list) {
                BoundingBox box = detectedObject.getBoundingBox();
                Rectangle rectangle = box.getBounds();
                // Crop the face with a margin so the alignment below does not cut part of it off.
                // NOTE(review): the original comment described "factor = 0.1f, i.e. enlarge by
                // 10%", but 1.0f is passed here — confirm the intended factor against FaceUtil.
                Rectangle subImageRect = FaceUtil.getSubImageRect(
                        image, rectangle, djlImg.getWidth(), djlImg.getHeight(), 1.0f);
                BufferedImage subImage = crop(image, subImageRect);

                // Re-express the face landmark coordinates relative to the cropped sub-image
                List<Point> points = (List<Point>) box.getPath();
                double[][] pointsArray = FaceUtil.pointsArray(subImageRect, points);

                Mat mat = Java2DFrameUtils.toMat(toBgr(subImage));
                try (NDManager manager = NDManager.newBaseManager()) {
                    NDArray srcPoints = manager.create(pointsArray);
                    NDArray dstPoints = SVDUtil.point112x112(manager);
                    // Custom 5-point affine transform matrix (landmarks -> reference points)
                    Mat svdMat = NDArrayUtil.toOpenCVMat(manager, srcPoints, dstPoints);
                    // Apply the affine transform
                    mat = FaceAlignment.get5WarpAffineImg(mat, svdMat);

                    // Convert the Mat back to a BufferedImage, keeping at most 112x112
                    BufferedImage aligned = OpenCVImageUtil.mat2BufferedImage(mat);
                    int width = Math.min(aligned.getWidth(), ALIGNED_FACE_SIZE);
                    int height = Math.min(aligned.getHeight(), ALIGNED_FACE_SIZE);
                    aligned = aligned.getSubimage(0, 0, width, height);
                    Image img = DJLImageUtil.bufferedImage2DJLImage(aligned);

                    // Extract the feature vector
                    List<Float> feature = featureService.faceFeature(img);
                    FaceObject faceObject = new FaceObject();
                    faceObject.setFeature(feature);
                    faceObject.setBoundingBox(subImageRect);
                    faceObject.setScore((float) detectedObject.getProbability());
                    faceObjects.add(faceObject);
                }
            }
            return faceObjects;
        }
    }

    /** Returns the region of {@code image} covered by {@code rect} (shares the image's raster). */
    private static BufferedImage crop(BufferedImage image, Rectangle rect) {
        int x = (int) rect.getX();
        int y = (int) rect.getY();
        int w = (int) rect.getWidth();
        int h = (int) rect.getHeight();
        return image.getSubimage(x, y, w, h);
    }

    /** Copies {@code src} into a new TYPE_3BYTE_BGR image, disposing the Graphics afterwards. */
    private static BufferedImage toBgr(BufferedImage src) {
        BufferedImage bgr =
                new BufferedImage(src.getWidth(), src.getHeight(), BufferedImage.TYPE_3BYTE_BGR);
        Graphics2D g = bgr.createGraphics();
        try {
            g.drawImage(src, 0, 0, null);
        } finally {
            g.dispose();
        }
        return bgr;
    }
}

五、特征提取

package com.xgc.aideep.face.service.impl;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import com.xgc.aideep.face.model.FaceFeatureModel;
import com.xgc.aideep.face.service.FaceFeatureService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

/**
 * Feature extraction service.
 *
 * @author gc.x
 * @date 2022-04
 */
@Slf4j
@Service
public class FeatureServiceImpl implements FaceFeatureService {

    @Autowired
    private FaceFeatureModel faceFeatureModel;

    /**
     * Runs the feature model on a face image.
     *
     * @param img the (cropped) face image
     * @return the feature vector, or {@code null} if the model produced no output
     */
    @Override
    public List<Float> faceFeature(Image img) throws TranslateException {
        ZooModel<Image, float[]> model = faceFeatureModel.getModel();
        try (Predictor<Image, float[]> predictor = model.newPredictor()) {
            float[] embeddings = predictor.predict(img);
            if (embeddings == null) {
                return null;
            }
            List<Float> feature = new ArrayList<>(embeddings.length);
            for (float value : embeddings) {
                feature.add(value);
            }
            return feature;
        }
    }
}

六、特征比对(存储)

使用向量数据库 Milvus 存储上一步提取到的人脸特征向量；业务中进行人脸识别时，先提取待识别人脸的特征，再在 Milvus 中检索与之最相似的已存特征，完成比对。

image.png

七、源代码开源地址

https://gitee.com/giteeClass/ai-face



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有