| 
18 | 18 | package org.apache.toree.magic.builtin  | 
19 | 19 | 
 
  | 
20 | 20 | import java.io.{File, PrintStream}  | 
21 |  | -import java.net.URL  | 
 | 21 | +import java.net.{URL, URI}  | 
22 | 22 | import java.nio.file.{Files, Paths}  | 
23 |  | - | 
24 | 23 | import org.apache.toree.magic._  | 
25 | 24 | import org.apache.toree.magic.builtin.AddJar._  | 
26 | 25 | import org.apache.toree.magic.dependencies._  | 
27 | 26 | import org.apache.toree.utils.{ArgumentParsingSupport, DownloadSupport, LogLike, FileUtils}  | 
28 | 27 | import com.typesafe.config.Config  | 
 | 28 | +import org.apache.hadoop.fs.Path  | 
29 | 29 | import org.apache.toree.plugins.annotations.Event  | 
30 | 30 | 
 
  | 
31 | 31 | object AddJar {  | 
 | 32 | +  val HADOOP_FS_SCHEMES = Set("hdfs", "s3", "s3n", "file")  | 
32 | 33 | 
 
  | 
33 | 34 |   private var jarDir:Option[String] = None  | 
34 | 35 | 
 
  | 
@@ -63,18 +64,18 @@ class AddJar  | 
63 | 64 |   private def printStream = new PrintStream(outputStream)  | 
64 | 65 | 
 
  | 
/**
 * Retrieves file name from a URI.
 *
 * The name is the last "/"-separated segment of the URI's path. Query
 * strings and fragments are not part of the path and are therefore not
 * included in the result.
 *
 * @param location a URI
 * @return The file name of the remote URI, or an empty string if one does not exist
 * @throws java.net.URISyntaxException if location cannot be parsed as a URI
 */
def getFileFromLocation(location: String): String = {
  val uri = new URI(location)
  // Opaque URIs (e.g. "file:foo.jar") have no path component, so getPath
  // returns null; fall back to the scheme-specific part, which is never null.
  val path = Option(uri.getPath).getOrElse(uri.getSchemeSpecificPart)
  // String.split drops trailing empty segments, so a path of "/" yields an
  // empty array; lastOption covers that case without a separate branch.
  path.split("/").lastOption.getOrElse("")
}
80 | 81 | 
 
  | 
@@ -122,10 +123,27 @@ class AddJar  | 
122 | 123 |       // Report beginning of download  | 
123 | 124 |       printStream.println(s"Starting download from $jarRemoteLocation")  | 
124 | 125 | 
 
  | 
125 |  | -      downloadFile(  | 
126 |  | -        new URL(jarRemoteLocation),  | 
127 |  | -        new File(downloadLocation).toURI.toURL  | 
128 |  | -      )  | 
 | 126 | +      val jar = URI.create(jarRemoteLocation)  | 
 | 127 | +      if (HADOOP_FS_SCHEMES.contains(jar.getScheme)) {  | 
 | 128 | +        val conf = kernel.sparkContext.hadoopConfiguration  | 
 | 129 | +        val jarPath = new Path(jarRemoteLocation)  | 
 | 130 | +        val fs = jarPath.getFileSystem(conf)  | 
 | 131 | +        val destPath = if (downloadLocation.startsWith("file:")) {  | 
 | 132 | +          new Path(downloadLocation)  | 
 | 133 | +        } else {  | 
 | 134 | +          new Path("file:" + downloadLocation)  | 
 | 135 | +        }  | 
 | 136 | + | 
 | 137 | +        fs.copyToLocalFile(  | 
 | 138 | +          false /* keep original file */,  | 
 | 139 | +          jarPath, destPath,  | 
 | 140 | +          true /* don't create checksum files */)  | 
 | 141 | +      } else {  | 
 | 142 | +        downloadFile(  | 
 | 143 | +          new URL(jarRemoteLocation),  | 
 | 144 | +          new File(downloadLocation).toURI.toURL  | 
 | 145 | +        )  | 
 | 146 | +      }  | 
129 | 147 | 
 
  | 
130 | 148 |       // Report download finished  | 
131 | 149 |       printStream.println(s"Finished download of $jarName")  | 
 | 
0 commit comments