Sfoglia il codice sorgente

孙浩博,fixed:传统算法训练训练

seamew 11 mesi fa
parent
commit
a1c0c9732b

+ 11 - 4
src/main/java/io/renren/common/utils/DockerClientUtils.java

@@ -144,7 +144,7 @@ public class DockerClientUtils {
     public static String execPython(String containerId,String filePath,String fileName) throws DockerException, InterruptedException {
 
         //创建要执行的命令
-        String[] execPython1={"nohup","python","-u","/opt/"+filePath+"/"+fileName,">output.log","2>&1","&"};
+        String[] execPython1={"nohup","python","-u","/opt/"+filePath+"/"+getOriginFileName(fileName),">output.log","2>&1","&"};
         ExecCreation execCreation=docker.execCreate(containerId,execPython1,DockerClient.ExecCreateParam.attachStdout(),
                 DockerClient.ExecCreateParam.attachStderr(), DockerClient.ExecCreateParam.attachStdin(),
                 DockerClient.ExecCreateParam.tty(), DockerClient.ExecCreateParam.detach());
@@ -165,6 +165,11 @@ public class DockerClientUtils {
         return execPythonOutput;
     }
 
+    private static String getOriginFileName(String fileName) {
+        String[] names = fileName.split("/");
+        return names[names.length - 1];
+    }
+
     /**
      * Description 在docker容器中执行运行python文件的命令,并附带参数
      * @param containerId 容器id
@@ -226,9 +231,11 @@ public class DockerClientUtils {
      * @param toFile 复制到的路径
      */
     public static void copyFile(String containerId,String formFile,String toFile) throws DockerException, InterruptedException {
-        String execCp=docker.execCreate(containerId,new String[]{"cp","-r",formFile,toFile}).id();
-        try (LogStream stream=docker.execStart(execCp)){
-            stream.readFully();
+        if (docker.inspectContainer(containerId).state().running()) {
+            String execCp=docker.execCreate(containerId,new String[]{"cp","-r",formFile,toFile}).id();
+            try (LogStream stream=docker.execStart(execCp)){
+                stream.readFully();
+            }
         }
     }
 

+ 38 - 18
src/main/java/io/renren/common/utils/FTPUtils.java

@@ -41,17 +41,17 @@ public class FTPUtils {
         initFTPUtil();
     }
 
-    public void initFTPUtil(){
-        FTPClient ftp=new FTPClient();
+    public static void initFTPUtil() {
+        FTPClient ftp = new FTPClient();
         try {
-            JSch jsch=new JSch();
+            JSch jsch = new JSch();
             System.out.println("是否需要秘钥"+key_needed);
-            if(key_needed){
+            if (key_needed) {
                 jsch.addIdentity(key_location);
             }
-            //获取sshSession  账号-ip-端口
+            // 获取sshSession  账号-ip-端口
             Session sshSession=jsch.getSession(username,host,port);
-            //添加密码
+            // 添加密码
             sshSession.setPassword(password);
             Properties sshConfig=new Properties();
 
@@ -60,26 +60,35 @@ public class FTPUtils {
 
             sshSession.setConfig(sshConfig);
 
-            //开启sshSession连接
+            // 开启sshSession连接
             sshSession.connect();
-            //获取sftp通道
+            // 获取sftp通道
             channel=sshSession.openChannel("sftp");
             channel.connect();
 
-            sftp=(ChannelSftp) channel;
-
+            sftp = (ChannelSftp) channel;
             ftp.enterLocalPassiveMode();
         } catch (JSchException e) {
             e.printStackTrace();
         }
     }
 
+    /**
+     * 检查连接是否建立,如何没有则重新获取链接
+     */
+    private static void checkIsConnect() {
+        if (!sftp.isConnected()) {
+            initFTPUtil();
+        }
+    }
+
     /**
      * Description: 在服务器创建文件夹
      * @param filePath 要创建的文件夹名,创建位置在basePath下
      * @return 成功返回true,否则false
      */
-    public static boolean mkdir(String filePath){
+    public static boolean mkdir(String filePath) {
+        checkIsConnect();
         boolean result=false;
         try {
             sftp.cd("/");
@@ -100,13 +109,14 @@ public class FTPUtils {
      * @param input 输入流
      * @return 成功返回true,否则false
      */
-    public static boolean uploadFile(String filePath,String filename, InputStream input){
+    public static boolean uploadFile(String filePath,String filename, InputStream input) {
+        checkIsConnect();
         boolean result=false;
         try {
             sftp.cd("/");
             sftp.cd(basePath);
             sftp.cd(filePath);
-            sftp.put(input,filename);
+            sftp.put(input,getOriginFileName(filename));
             return true;
         } catch (SftpException e) {
             e.printStackTrace();
@@ -114,14 +124,21 @@ public class FTPUtils {
         return result;
     }
 
+    private static String getOriginFileName(String filename) {
+        checkIsConnect();
+        String[] names = filename.split("/");
+        return names[names.length - 1];
+    }
+
     /**
      * Description: 从FTP服务器下载文件
      * @param filePath FTP服务器文件存放路径。文件的路径为basePath+filePath
      * @return 成功返回true,否则false
      */
-    public static InputStream downloadFile(String filePath){
+    public static InputStream downloadFile(String filePath) {
+        checkIsConnect();
         try {
-            InputStream inputStream=sftp.get(filePath);
+            InputStream inputStream = sftp.get(filePath);
             return inputStream;
         } catch (SftpException e) {
             e.printStackTrace();
@@ -134,7 +151,8 @@ public class FTPUtils {
      * @param filePath
      * @return
      */
-    public static boolean isDirExist(String filePath){
+    public static boolean isDirExist(String filePath) {
+        checkIsConnect();
         boolean isDirExistFlag=false;
         try {
             SftpATTRS sftpATTRS=sftp.lstat(basePath+"/"+filePath);
@@ -153,7 +171,8 @@ public class FTPUtils {
      * @param filePath 要删除的文件夹
      * @return 成功返回true,否则false
      */
-    public static boolean removeDir(String filePath){
+    public static boolean removeDir(String filePath) {
+        checkIsConnect();
         boolean result=false;
         FTPClient ftp=new FTPClient();
         try {
@@ -198,7 +217,8 @@ public class FTPUtils {
      * @param filePath
      * @return
      */
-    public static Vector<ChannelSftp.LsEntry> showFiles(String filePath){
+    public static Vector<ChannelSftp.LsEntry> showFiles(String filePath) {
+        checkIsConnect();
         Vector vector=new Vector();
         try {
             vector=sftp.ls(basePath+"/"+filePath);

+ 95 - 6
src/main/java/io/renren/common/utils/MinIoUtils.java

@@ -8,18 +8,20 @@ import io.minio.http.Method;
 import io.minio.messages.Item;
 import io.renren.modules.sys.entity.algs.FileTest;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.springframework.web.multipart.MultipartFile;
 
-import java.io.IOException;
-import java.io.InputStream;
+import java.io.*;
+import java.nio.charset.StandardCharsets;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
 import java.text.DecimalFormat;
 import java.time.ZoneId;
 import java.time.format.DateTimeFormatter;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
+import java.util.*;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipOutputStream;
 
 /**
  * @Author: Ivan Q
@@ -88,6 +90,57 @@ public class MinIoUtils {
         return objectUrl;
     }
 
+    public static String getFileContent(String bucket,String fileName) {
+        try {
+            InputStream inputStream = minioClient.getObject(bucket, fileName);
+            StringWriter writer = new StringWriter();
+            IOUtils.copy(inputStream, writer, StandardCharsets.UTF_8);
+            return writer.toString();
+        } catch (Exception e) {
+            e.printStackTrace();
+            return null;
+        }
+    }
+
+    /**
+     * Description 下载目标文件夹并打包成ZIP文件
+     *
+     * @param bucket 文件所在桶名称
+     * @param folderName 文件夹名
+     * @return ZIP文件的URL
+     */
+    public static byte[] downloadAndZipFolder(String bucket, String folderName) {
+        // 创建一个内存中的ZIP输出流
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ZipOutputStream zipOutputStream = new ZipOutputStream(outputStream);
+        try {
+            Iterable<Result<Item>> minioObjects = minioClient.listObjects(bucket, folderName);
+            for (Result<Item> result : minioObjects) {
+                Item item = result.get();
+                String fileName = item.objectName();
+                // 创建ZIP条目
+                ZipEntry zipEntry = new ZipEntry(fileName);
+                zipOutputStream.putNextEntry(zipEntry);
+                // 下载文件内容到ZIP输出流
+                InputStream inputStream = minioClient.getObject(bucket, fileName);
+                byte[] buffer = new byte[1024 * 10];
+                BufferedInputStream bufferedInputStream = new BufferedInputStream(inputStream, 1024 * 10);
+                int bytesRead;
+                while ((bytesRead = bufferedInputStream.read(buffer)) != -1) {
+                    zipOutputStream.write(buffer, 0, bytesRead);
+                }
+                bufferedInputStream.close();
+                zipOutputStream.closeEntry();
+            }
+            // 关闭ZIP输出流
+            zipOutputStream.close();
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+
+        return outputStream.toByteArray();
+    }
+
     /**
      * Description:获取指定文件的InputStream
      * @param bucketName
@@ -158,7 +211,8 @@ public class MinIoUtils {
         for(Result<Item> result:minioObjects){
             Item item=result.get();
             FileTest fileTest=new FileTest();
-            fileTest.setFilename(item.objectName().substring(prefix.length()+1));
+            // fileTest.setFilename(item.objectName().substring(prefix.length()+1));
+            fileTest.setFilename(item.objectName());
             //设置大小显示
             DecimalFormat df=new DecimalFormat("#.00");
             if (item.size()>=1024*1024*1024){
@@ -181,6 +235,41 @@ public class MinIoUtils {
         return fileList;
     }
 
+    /**
+     * Description 根据桶名和前缀(即文件夹名)查询文件夹
+     * @param bucketName 桶名称
+     * @param prefix 前缀(文件夹名)
+     * @return List<FileTest> 文件夹列表
+     */
+    public static List<FileTest> listFolders(String bucketName, String prefix) throws XmlParserException, NoSuchAlgorithmException, InsufficientDataException, InternalException, InvalidResponseException, InvalidKeyException, InvalidBucketNameException, ErrorResponseException, IOException {
+        Iterable<Result<Item>> minioObjects = minioClient.listObjects(bucketName, prefix, false);
+        Set<FileTest> fileSet = new HashSet<>();
+        for (Result<Item> result : minioObjects) {
+            Item item = result.get();
+            String folder = getFolderFromObjectName(item.objectName(), prefix);
+            if (StringUtils.isNotEmpty(folder)) {
+                FileTest fileTest=new FileTest();
+                fileTest.setFilename(prefix);
+                fileSet.add(fileTest);
+            }
+        }
+        return new ArrayList<>(fileSet);
+    }
+
+    /**
+     * 从对象名称中提取文件夹名
+     * @param objectName 对象名称
+     * @param prefix 前缀(文件夹名)
+     * @return 文件夹名
+     */
+    private static String getFolderFromObjectName(String objectName, String prefix) {
+        int index = objectName.indexOf('/', prefix.length()); // 搜索第一个目录分隔符
+        if (index > 0) {
+            return objectName.substring(0, prefix.length());
+        }
+        return "";
+    }
+
     /**
      * Description 复制文件
      * @param bucketName 存放复制文件的存储桶

+ 2 - 2
src/main/java/io/renren/modules/dataSet/enumeration/DataSetType.java

@@ -10,8 +10,8 @@ package io.renren.modules.dataSet.enumeration;
 public enum DataSetType {
     STATIC_DATASET("dataset", ""),
     DYNAMIC_DATASET("dydataset", ""),
-    FILE_DATASET("filedataset", "单文件"),
-    DIR_DATASET("dirdataset", "文件夹");
+    FILE_DATASET("dataset", "单文件"),
+    DIR_DATASET("dataset", "文件夹");
 
     private final String bucketName;
 

+ 1 - 1
src/main/java/io/renren/modules/sys/controller/VisiWorkflowController.java

@@ -235,7 +235,7 @@ public class VisiWorkflowController extends AbstractController {
     public R Submit(@RequestBody Map<String, Object> params) {
         System.out.println(params.toString());
         System.out.println(params.get("addorupdate"));
-        Boolean addorupdate = (Boolean) params.get("addorupdate");
+        Boolean addorupdate = Boolean.valueOf(params.get("addorupdate").toString());
         System.out.println("//true表示update,新增  false表示add "+ addorupdate);
         Long  workflowId = null;
         if(addorupdate){

+ 32 - 15
src/main/java/io/renren/modules/sys/controller/algs/algTrainController.java

@@ -26,6 +26,7 @@ import org.springframework.transaction.annotation.Transactional;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.web.multipart.MultipartHttpServletRequest;
 
+import javax.validation.constraints.Min;
 import java.io.*;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
@@ -117,10 +118,15 @@ public class algTrainController {
         algTrain.setUid(Long.parseLong(map.get("uid")));
         algTrain.setAlgorithmId(Long.parseLong(map.get("algorithmId")));
         ValidatorUtils.validateEntity(algTrain, AddGroup.class);
-        List<AlgTrain> list = algTrainService.list();
-        for(AlgTrain train : list){
-            if(train.getMissName().equals(algTrain.getMissName())) return R.error("已经存在重复名称的任务");
+
+        if (algTrainService.selectByMissName(algTrain.getMissName()) != null) {
+            return R.error("已经存在重复名称的任务");
         }
+
+        // List<AlgTrain> list = algTrainService.list();
+        // for(AlgTrain train : list){
+        //     if(train.getMissName().equals(algTrain.getMissName())) return R.error("已经存在重复名称的任务");
+        // }
         Algorithm alg=algsService.getById(algTrain.getAlgorithmId());
         //如果是智能算法
         if(alg.getFrameId()!=-1){
@@ -160,7 +166,7 @@ public class algTrainController {
             List<FileTest> fileList=MinIoUtils.listFiles("algorithm","alg"+map.get("algorithmId"));
             for(FileTest file:fileList){
                 //从minio中获得目标文件的输入流
-                InputStream input= MinIoUtils.getFileInputStream("algorithm","alg"+map.get("algorithmId")+"/"+file.getFilename());
+                InputStream input= MinIoUtils.getFileInputStream("algorithm",file.getFilename());
                 //从minio中获取的文件上传至服务器
                 FTPUtils.uploadFile("algTrain"+algTrain.getAlgorithmTrainingId(),file.getFilename(),input);
             }
@@ -538,7 +544,8 @@ public class algTrainController {
     @GetMapping("/getAlgIdVersionId")
     public R getAlgIdVersionId(String algorithmTrainingId){
         AlgTrain algTrain=algTrainService.selectByPrimaryKey(Long.parseLong(algorithmTrainingId));
-        return R.ok().put("algorithmNameToVersion",algTrain.getAlgorithmId()).put("verisionToFile",algTrain.getVersionId());
+        CategoryEntity category = categoryService.getById(algTrain.getCategoryId());
+        return R.ok().put("algTrain", algTrain).put("category", category);
     }
 
     /**
@@ -577,17 +584,27 @@ public class algTrainController {
      */
     @GetMapping("/listFiles")
     public R listFiles(@RequestParam Map<String,Object> params) throws IOException, InvalidKeyException, NoSuchAlgorithmException, InsufficientDataException, InvalidResponseException, ErrorResponseException, XmlParserException, InvalidBucketNameException, InternalException {
-        String algorithmNameToVersion=String.valueOf(params.get("algorithmNameToVersion"));
-        String verisionToFile=String.valueOf(params.get("verisionToFile"));
-
-        List<FileTest> fileList= MinIoUtils.listFiles("algorithm","alg"+algorithmNameToVersion+"/version"+verisionToFile);
-        List<FileTest> pythonFiles=new ArrayList<>();
-        List<FileTest> datasetFiles=new ArrayList<>();
-        for(FileTest file:fileList){
-            String suffix=file.getFilename().substring(file.getFilename().lastIndexOf("."));
-            if(suffix.equals(".py")){
+        String algorithmNameToVersion = String.valueOf(params.get("algorithmNameToVersion"));
+        String verisionToFile = String.valueOf(params.get("verisionToFile"));
+        long algFrameId = Long.parseLong((String) params.get("algFrameId"));
+        List<FileTest> fileList;
+        if (algFrameId == -1) {
+            fileList = MinIoUtils.listFiles("algorithm", "alg" + algorithmNameToVersion);
+        } else {
+            fileList = MinIoUtils.listFiles("algorithm","alg"+algorithmNameToVersion+"/version"+verisionToFile);
+        }
+        List<FileTest> pythonFiles =new ArrayList<>();
+        List<FileTest> datasetFiles =new ArrayList<>();
+        for (FileTest file : fileList) {
+            String suffix = file.getFilename().substring(file.getFilename().lastIndexOf("."));
+            // if (suffix.equals(".py")){
+            //     pythonFiles.add(file);
+            // } else if (suffix.equals(".csv")) {
+            //     datasetFiles.add(file);
+            // }
+            if (suffix.equals(".py")){
                 pythonFiles.add(file);
-            }else if(suffix.equals(".csv")){
+            } else {
                 datasetFiles.add(file);
             }
         }

+ 53 - 29
src/main/java/io/renren/modules/sys/controller/algs/algsController.java

@@ -9,11 +9,13 @@ import io.renren.common.utils.R;
 import io.renren.common.validator.ValidatorUtils;
 import io.renren.common.validator.group.AddGroup;
 import io.renren.common.validator.group.UpdateGroup;
+import io.renren.modules.dataSet.enumeration.DataSetType;
 import io.renren.modules.sys.entity.algs.*;
 import io.renren.modules.sys.entity.dataset.DataSet;
 import io.renren.modules.sys.service.*;
 import io.renren.modules.sys.service.impl.AlgsModelsServiceImpl;
 import io.renren.modules.sys.service.impl.AlgsServiceImpl;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.shiro.authz.annotation.RequiresPermissions;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.web.bind.annotation.*;
@@ -107,10 +109,13 @@ public class algsController{
         alg.setNumber(Long.parseLong(request.getParameter("number")));
 
         ValidatorUtils.validateEntity(alg, AddGroup.class);
-        List<Algorithm> list = algsService.list();
-        for(Algorithm algorithm : list){
-            if(algorithm.getAlgorithmName().equals(alg.getAlgorithmName())) return R.error("已经存在重复名称的算法");
+        if (algsService.selectByAlgName(alg.getAlgorithmName()) != null) {
+            return R.error("已经存在重复名称的算法");
         }
+        // List<Algorithm> list = algsService.list();
+        // for(Algorithm algorithm : list){
+        //     if(algorithm.getAlgorithmName().equals(alg.getAlgorithmName())) return R.error("已经存在重复名称的算法");
+        // }
 //        user.setCreateUserId(getUserId());
 
         algsService.save(alg);
@@ -193,54 +198,73 @@ public class algsController{
      * @param files 算法文件
      */
     public void saveTraditionalAlg(Algorithm alg,List<MultipartFile> files,String algParameterNameString,String algParameterTypeString,String algResultNameString,String algResultLocationString,String dataSetsString) throws IOException, InvalidKeyException, NoSuchAlgorithmException, InsufficientDataException, InvalidResponseException, ErrorResponseException, XmlParserException, InvalidBucketNameException, InternalException {
-        Long algId=alg.getAlgorithmId();
+        Long algId = alg.getAlgorithmId();
 
-        System.out.println("算法参数名是");
-        String[] algorithmParameterNames=algParameterNameString.split(",");
-        for(String s:algorithmParameterNames){
-            System.out.println(s);
+        String[] algorithmParameterNames = new String[0];
+        if (StringUtils.isNotEmpty(algParameterNameString)) {
+            algorithmParameterNames=algParameterNameString.split(",");
         }
 
-        System.out.println("算法参数类别是");
-        String[] algorithmParameterTypes=algParameterTypeString.split(",");
-        for(String s:algorithmParameterTypes){
-            System.out.println(s);
+        // System.out.println("算法参数名是");
+        // String[] algorithmParameterNames=algParameterNameString.split(",");
+        // for(String s:algorithmParameterNames){
+        //     System.out.println(s);
+        // }
+
+        // System.out.println("算法参数类别是");
+        // String[] algorithmParameterTypes=algParameterTypeString.split(",");
+        // for(String s:algorithmParameterTypes){
+        //     System.out.println(s);
+        // }
+
+        String[] algorithmParameterTypes = new String[0];
+        if (StringUtils.isNotEmpty(algParameterTypeString)) {
+            algorithmParameterTypes = algParameterTypeString.split(",");
         }
-        for(int i=0;i<algorithmParameterNames.length;i++){
-            AlgorithmParameter algorithmParameter=new AlgorithmParameter();
+        for (int i = 0; i < algorithmParameterNames.length; i++) {
+            AlgorithmParameter algorithmParameter = new AlgorithmParameter();
             algorithmParameter.setAlgorithmId(algId);
             algorithmParameter.setAlgorithmParameterName(algorithmParameterNames[i]);
             algorithmParameter.setAlgorithmParameterType(algorithmParameterTypes[i]);
             algorithmParameterService.save(algorithmParameter);
         }
 
-        System.out.println("算法结果名是");
-        String[] algorithmResultNames=algResultNameString.split(",");
-        for(String s:algorithmResultNames){
-            System.out.println(s);
+        String[] algorithmResultNames = new String[0];
+        if (StringUtils.isNotEmpty(algResultNameString)) {
+            algorithmResultNames = algResultNameString.split(",");
         }
-        System.out.println("算法结果文件位置是");
-        String[] algorithmResultLocations=algResultLocationString.split(",");
-        for(String s:algorithmResultLocations){
-            System.out.println(s);
+        // System.out.println("算法结果名是");
+        // String[] algorithmResultNames=algResultNameString.split(",");
+        // for(String s:algorithmResultNames){
+        //     System.out.println(s);
+        // }
+
+        String[] algorithmResultLocations = new String[0];
+        if (StringUtils.isNotEmpty(algResultLocationString)) {
+            algorithmResultLocations = algResultLocationString.split(",");
         }
-        for(int i=0;i<algorithmResultNames.length;i++){
-            AlgorithmResult algorithmResult=new AlgorithmResult();
+        // System.out.println("算法结果文件位置是");
+        // String[] algorithmResultLocations=algResultLocationString.split(",");
+        // for(String s:algorithmResultLocations){
+        //     System.out.println(s);
+        // }
+        for(int i = 0; i < algorithmResultNames.length; i++){
+            AlgorithmResult algorithmResult = new AlgorithmResult();
             algorithmResult.setAlgorithmId(algId);
             algorithmResult.setAlgorithmResultName(algorithmResultNames[i]);
             algorithmResult.setAlgorithmResultLocation(algorithmResultLocations[i]);
             algorithmResultService.save(algorithmResult);
         }
         //将算法文件上传
-        for(MultipartFile file:files){
+        for (MultipartFile file : files) {
             if(file.isEmpty())  break;
             MinIoUtils.uploadMultipartFile(file,"algorithm","alg"+algId+"/"+file.getOriginalFilename());
         }
         //将数据集上传
-        if(dataSetsString.length()!=0){
-            String[] dataSets=dataSetsString.split(",");
-            for(String fileName:dataSets){
-                MinIoUtils.copyFile("algorithm","alg"+algId+"/"+fileName,"dataset",fileName);
+        if (StringUtils.isNotEmpty(dataSetsString)){
+            String[] dataSets = dataSetsString.split(",");
+            for (String fileName : dataSets) {
+                MinIoUtils.copyFile("algorithm","alg"+algId+"/"+fileName, DataSetType.STATIC_DATASET.getBucketName(), fileName);
             }
         }
     }

+ 26 - 11
src/main/java/io/renren/modules/sys/controller/dataset/DataSetController.java

@@ -17,10 +17,12 @@ import io.renren.modules.sys.service.CategoryService;
 import io.renren.modules.sys.service.DataSetService;
 import io.renren.modules.sys.service.impl.AlgsModelsServiceImpl;
 import lombok.SneakyThrows;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.shiro.authz.annotation.RequiresPermissions;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.web.bind.annotation.*;
 
+import javax.servlet.http.HttpServletResponse;
 import java.io.IOException;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
@@ -164,19 +166,32 @@ public class DataSetController {
      * @return
      */
     @RequestMapping("/downloadDataset")
-    public R downloadStaticDataset(String datasetName){
-        //从minio获得数据集的url,直接返回给前端
-        List<FileTest> dataset = new ArrayList<>();
-        try {
-            dataset = MinIoUtils.listFiles("dataset", datasetName.split("\\.")[0]);
-
-        } catch (Exception e) {
-            e.printStackTrace();
+    public R downloadStaticDataset(@RequestParam String datasetName, @RequestParam(required = false) String categoryName, @RequestParam(required = false) String bucketName){
+        // 从minio获得数据集的url,直接返回给前端
+        // List<FileTest> dataset = new ArrayList<>();
+        // try {
+        //     if (DataSetType.FILE_DATASET.getClassificationName().equals(categoryName)) {
+        //         dataset = MinIoUtils.listFiles(DataSetType.STATIC_DATASET.getBucketName(), datasetName);
+        //     } else if (DataSetType.DIR_DATASET.getClassificationName().equals(categoryName)) {
+        //         dataset = MinIoUtils.listFolders(DataSetType.STATIC_DATASET.getBucketName(), datasetName);
+        //     } else {
+        //         dataset = MinIoUtils.listFiles(DataSetType.STATIC_DATASET.getBucketName(), datasetName.split("\\.")[0]);
+        //     }
+        // } catch (Exception e) {
+        //     return R.error(e.getMessage());
+        // }
+        // if (dataset.isEmpty()){
+        //     return R.error(404,"文件不存在");
+        // }
+        if (StringUtils.isNotEmpty(bucketName)) {
+            return R.ok().put("downloadUrl", MinIoUtils.getFileUrl(bucketName, datasetName));
         }
-        if (dataset.size() ==0){
-            return R.error(404,"文件不存在");
+
+        String downloadUrl;
+        if (DataSetType.DIR_DATASET.getClassificationName().equals(categoryName)) {
+            return R.ok().put("download", MinIoUtils.downloadAndZipFolder(DataSetType.STATIC_DATASET.getBucketName(), datasetName));
         }
-        String downloadUrl= MinIoUtils.getFileUrl("dataset",datasetName);
+        downloadUrl = MinIoUtils.getFileUrl(DataSetType.STATIC_DATASET.getBucketName(), datasetName);
         return R.ok().put("downloadUrl",downloadUrl);
     }
 

+ 5 - 1
src/main/java/io/renren/modules/sys/controller/minIo/MinioController.java

@@ -136,11 +136,15 @@ public class MinioController {
      * @return
      */
     @GetMapping("/readUrlContent")
-    public String readUrlContent(String algorithmNameToVersion,String verisionToFile,String fileName) throws Exception {
+    public String readUrlContent(String algorithmNameToVersion, String verisionToFile, long algFrameId, String fileName) throws Exception {
 
 
         //String requestUrl = minioClient.getPresignedObjectUrl(Method.GET, bucket, fileName, 1000, null);
         //使用minio工具类获取指定文件url
+        if (algFrameId == -1) {
+            return MinIoUtils.getFileContent("algorithm", fileName);
+        }
+
         String requestUrl= MinIoUtils.getFileUrl("algorithm","alg"+algorithmNameToVersion+"/version"+verisionToFile+"/"+fileName);
         HttpURLConnection conn=null;
         BufferedReader br=null;