Skip to content

Commit 04f142d

Browse files
committed
[SPARK-20547][REPL] Throw RemoteClassLoaderError for transient errors in ExecutorClassLoader
## What changes were proposed in this pull request? `ExecutorClassLoader`'s `findClass` may fail to fetch a class due to transient exceptions. For example, when a task is interrupted, if `ExecutorClassLoader` is fetching a class, you may see `InterruptedException` or `IOException` wrapped by `ClassNotFoundException`, even if this class can be loaded. Then the result of `findClass` will be cached by the JVM, and later, when the same class is being loaded in the same executor, it will just throw `NoClassDefFoundError` even if the class can be loaded. I found that the JVM only caches `LinkageError` and `ClassNotFoundException`. Hence in this PR, I changed ExecutorClassLoader to throw `RemoteClassLoaderError` if we cannot get a response from the driver. ## How was this patch tested? New unit tests. Closes apache#24683 from zsxwing/SPARK-20547-fix. Authored-by: Shixiong Zhu <[email protected]> Signed-off-by: Shixiong Zhu <[email protected]>
1 parent 4e61de4 commit 04f142d

File tree

5 files changed

+214
-11
lines changed

5 files changed

+214
-11
lines changed

common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java

+2
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ private void processStreamRequest(final StreamRequest req) {
140140
streamManager.streamSent(req.streamId);
141141
});
142142
} else {
143+
// org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated
144+
// when the following error message is changed.
143145
respond(new StreamFailure(req.streamId, String.format(
144146
"Stream '%s' was not found.", req.streamId)));
145147
}

repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala

+39-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream,
2121
import java.net.{URI, URL, URLEncoder}
2222
import java.nio.channels.Channels
2323

24+
import scala.util.control.NonFatal
25+
2426
import org.apache.hadoop.fs.{FileSystem, Path}
2527
import org.apache.xbean.asm7._
2628
import org.apache.xbean.asm7.Opcodes._
@@ -106,7 +108,17 @@ class ExecutorClassLoader(
106108
parentLoader.loadClass(name)
107109
} catch {
108110
case e: ClassNotFoundException =>
109-
val classOption = findClassLocally(name)
111+
val classOption = try {
112+
findClassLocally(name)
113+
} catch {
114+
case e: RemoteClassLoaderError =>
115+
throw e
116+
case NonFatal(e) =>
117+
// Wrap the error to include the class name
118+
// scalastyle:off throwerror
119+
throw new RemoteClassLoaderError(name, e)
120+
// scalastyle:on throwerror
121+
}
110122
classOption match {
111123
case None => throw new ClassNotFoundException(name, e)
112124
case Some(a) => a
@@ -115,23 +127,31 @@ class ExecutorClassLoader(
115127
}
116128
}
117129

130+
// See org.apache.spark.network.server.TransportRequestHandler.processStreamRequest.
131+
private val STREAM_NOT_FOUND_REGEX = s"Stream '.*' was not found.".r.pattern
132+
118133
private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = {
119-
val channel = env.rpcEnv.openChannel(s"$classUri/$path")
134+
val channel = env.rpcEnv.openChannel(s"$classUri/${urlEncode(path)}")
120135
new FilterInputStream(Channels.newInputStream(channel)) {
121136

122137
override def read(): Int = toClassNotFound(super.read())
123138

124-
override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b))
125-
126139
override def read(b: Array[Byte], offset: Int, len: Int) =
127140
toClassNotFound(super.read(b, offset, len))
128141

129142
private def toClassNotFound(fn: => Int): Int = {
130143
try {
131144
fn
132145
} catch {
133-
case e: Exception =>
146+
case e: RuntimeException if e.getMessage != null
147+
&& STREAM_NOT_FOUND_REGEX.matcher(e.getMessage).matches() =>
148+
// Convert a stream not found error to ClassNotFoundException.
149+
// Driver sends this explicit acknowledgment to tell us that the class was missing.
134150
throw new ClassNotFoundException(path, e)
151+
case NonFatal(e) =>
152+
// scalastyle:off throwerror
153+
throw new RemoteClassLoaderError(path, e)
154+
// scalastyle:on throwerror
135155
}
136156
}
137157
}
@@ -163,7 +183,12 @@ class ExecutorClassLoader(
163183
case e: Exception =>
164184
// Something bad happened while checking if the class exists
165185
logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
166-
None
186+
if (userClassPathFirst) {
187+
// Allow to try to load from "parentLoader"
188+
None
189+
} else {
190+
throw e
191+
}
167192
} finally {
168193
if (inputStream != null) {
169194
try {
@@ -237,3 +262,11 @@ extends ClassVisitor(ASM7, cv) {
237262
}
238263
}
239264
}
265+
266+
/**
267+
* An error when we cannot load a class due to exceptions. We don't know if this class exists, so
268+
* throw a special one that's neither [[LinkageError]] nor [[ClassNotFoundException]] to make JVM
269+
* retry to load this class later.
270+
*/
271+
private[repl] class RemoteClassLoaderError(className: String, cause: Throwable)
272+
extends Error(className, cause)

repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala

+141-4
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.repl
1919

20-
import java.io.File
20+
import java.io.{File, IOException}
21+
import java.lang.reflect.InvocationTargetException
2122
import java.net.{URI, URL, URLClassLoader}
22-
import java.nio.channels.FileChannel
23+
import java.nio.channels.{FileChannel, ReadableByteChannel}
2324
import java.nio.charset.StandardCharsets
2425
import java.nio.file.{Paths, StandardOpenOption}
2526
import java.util
@@ -30,13 +31,15 @@ import scala.io.Source
3031
import scala.language.implicitConversions
3132

3233
import com.google.common.io.Files
33-
import org.mockito.ArgumentMatchers.anyString
34+
import org.mockito.ArgumentMatchers.{any, anyString}
3435
import org.mockito.Mockito._
3536
import org.mockito.invocation.InvocationOnMock
37+
import org.mockito.stubbing.Answer
3638
import org.scalatest.BeforeAndAfterAll
3739
import org.scalatest.mockito.MockitoSugar
3840

3941
import org.apache.spark._
42+
import org.apache.spark.TestUtils.JavaSourceFromString
4043
import org.apache.spark.internal.Logging
4144
import org.apache.spark.rpc.RpcEnv
4245
import org.apache.spark.util.Utils
@@ -193,7 +196,14 @@ class ExecutorClassLoaderSuite
193196
when(rpcEnv.openChannel(anyString())).thenAnswer((invocation: InvocationOnMock) => {
194197
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
195198
val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
196-
FileChannel.open(path, StandardOpenOption.READ)
199+
if (path.toFile.exists()) {
200+
FileChannel.open(path, StandardOpenOption.READ)
201+
} else {
202+
val channel = mock[ReadableByteChannel]
203+
when(channel.read(any()))
204+
.thenThrow(new RuntimeException(s"Stream '${uri.getPath}' was not found."))
205+
channel
206+
}
197207
})
198208

199209
val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
@@ -218,4 +228,131 @@ class ExecutorClassLoaderSuite
218228
}
219229
}
220230

231+
test("nonexistent class and transient errors should cause different errors") {
232+
val conf = new SparkConf()
233+
.setMaster("local")
234+
.setAppName("executor-class-loader-test")
235+
.set("spark.network.timeout", "11s")
236+
.set("spark.repl.class.outputDir", tempDir1.getAbsolutePath)
237+
val sc = new SparkContext(conf)
238+
try {
239+
val replClassUri = sc.conf.get("spark.repl.class.uri")
240+
241+
// Create an RpcEnv for executor
242+
val rpcEnv = RpcEnv.create(
243+
SparkEnv.executorSystemName,
244+
"localhost",
245+
"localhost",
246+
0,
247+
sc.conf,
248+
new SecurityManager(conf), 0, clientMode = true)
249+
250+
try {
251+
val env = mock[SparkEnv]
252+
when(env.rpcEnv).thenReturn(rpcEnv)
253+
254+
val classLoader = new ExecutorClassLoader(
255+
conf,
256+
env,
257+
replClassUri,
258+
getClass().getClassLoader(),
259+
false)
260+
261+
// Test loading a nonexistent class
262+
intercept[java.lang.ClassNotFoundException] {
263+
classLoader.loadClass("NonexistentClass")
264+
}
265+
266+
// Stop SparkContext to simulate transient errors in executors
267+
sc.stop()
268+
269+
val e = intercept[RemoteClassLoaderError] {
270+
classLoader.loadClass("ThisIsAClassName")
271+
}
272+
assert(e.getMessage.contains("ThisIsAClassName"))
273+
// RemoteClassLoaderError must not be LinkageError nor ClassNotFoundException. Otherwise,
274+
// JVM will cache it and doesn't retry to load a class.
275+
assert(!e.isInstanceOf[LinkageError] && !e.isInstanceOf[ClassNotFoundException])
276+
} finally {
277+
rpcEnv.shutdown()
278+
rpcEnv.awaitTermination()
279+
}
280+
} finally {
281+
sc.stop()
282+
}
283+
}
284+
285+
test("SPARK-20547 ExecutorClassLoader should not throw ClassNotFoundException without " +
286+
"acknowledgment from driver") {
287+
val tempDir = Utils.createTempDir()
288+
try {
289+
// Create two classes, "TestClassB" calls "TestClassA", so when calling "TestClassB.foo", JVM
290+
// will try to load "TestClassA".
291+
val sourceCodeOfClassA =
292+
"""public class TestClassA implements java.io.Serializable {
293+
| @Override public String toString() { return "TestClassA"; }
294+
|}""".stripMargin
295+
val sourceFileA = new JavaSourceFromString("TestClassA", sourceCodeOfClassA)
296+
TestUtils.createCompiledClass(
297+
sourceFileA.name, tempDir, sourceFileA, Seq(tempDir.toURI.toURL))
298+
299+
val sourceCodeOfClassB =
300+
"""public class TestClassB implements java.io.Serializable {
301+
| public String foo() { return new TestClassA().toString(); }
302+
| @Override public String toString() { return "TestClassB"; }
303+
|}""".stripMargin
304+
val sourceFileB = new JavaSourceFromString("TestClassB", sourceCodeOfClassB)
305+
TestUtils.createCompiledClass(
306+
sourceFileB.name, tempDir, sourceFileB, Seq(tempDir.toURI.toURL))
307+
308+
val env = mock[SparkEnv]
309+
val rpcEnv = mock[RpcEnv]
310+
when(env.rpcEnv).thenReturn(rpcEnv)
311+
when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() {
312+
private var count = 0
313+
314+
override def answer(invocation: InvocationOnMock): ReadableByteChannel = {
315+
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
316+
val classFileName = uri.getPath().stripPrefix("/")
317+
if (count == 0 && classFileName == "TestClassA.class") {
318+
count += 1
319+
// Let the first attempt to load TestClassA fail with an IOException
320+
val channel = mock[ReadableByteChannel]
321+
when(channel.read(any())).thenThrow(new IOException("broken pipe"))
322+
channel
323+
}
324+
else {
325+
val path = Paths.get(tempDir.getAbsolutePath(), classFileName)
326+
FileChannel.open(path, StandardOpenOption.READ)
327+
}
328+
}
329+
})
330+
331+
val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
332+
getClass().getClassLoader(), false)
333+
334+
def callClassBFoo(): String = {
335+
// scalastyle:off classforname
336+
val classB = Class.forName("TestClassB", true, classLoader)
337+
// scalastyle:on classforname
338+
val instanceOfTestClassB = classB.newInstance()
339+
assert(instanceOfTestClassB.toString === "TestClassB")
340+
classB.getMethod("foo").invoke(instanceOfTestClassB).asInstanceOf[String]
341+
}
342+
343+
// Reflection will wrap the exception with InvocationTargetException
344+
val e = intercept[InvocationTargetException] {
345+
callClassBFoo()
346+
}
347+
// "TestClassA" cannot be loaded because of IOException
348+
assert(e.getCause.isInstanceOf[RemoteClassLoaderError])
349+
assert(e.getCause.getCause.isInstanceOf[IOException])
350+
assert(e.getCause.getMessage.contains("TestClassA"))
351+
352+
// We should be able to re-load TestClassA for IOException
353+
assert(callClassBFoo() === "TestClassA")
354+
} finally {
355+
Utils.deleteRecursively(tempDir)
356+
}
357+
}
221358
}

repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala

+16-1
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,27 @@ import java.io._
2222
import scala.tools.nsc.interpreter.SimpleReader
2323

2424
import org.apache.log4j.{Level, LogManager}
25+
import org.scalatest.BeforeAndAfterAll
2526

2627
import org.apache.spark.{SparkContext, SparkFunSuite}
2728
import org.apache.spark.sql.SparkSession
2829
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
2930

30-
class ReplSuite extends SparkFunSuite {
31+
class ReplSuite extends SparkFunSuite with BeforeAndAfterAll {
32+
33+
private var originalClassLoader: ClassLoader = null
34+
35+
override def beforeAll(): Unit = {
36+
originalClassLoader = Thread.currentThread().getContextClassLoader
37+
}
38+
39+
override def afterAll(): Unit = {
40+
if (originalClassLoader != null) {
41+
// Reset the class loader to not affect other suites. REPL will set its own class loader but
42+
// doesn't reset it.
43+
Thread.currentThread().setContextClassLoader(originalClassLoader)
44+
}
45+
}
3146

3247
def runInterpreter(master: String, input: String): String = {
3348
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"

repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala

+16
Original file line numberDiff line numberDiff line change
@@ -390,4 +390,20 @@ class SingletonReplSuite extends SparkFunSuite {
390390
assertDoesNotContain("error:", output)
391391
assertDoesNotContain("Exception", output)
392392
}
393+
394+
test("create encoder in executors") {
395+
val output = runInterpreter(
396+
"""
397+
|case class Foo(s: String)
398+
|
399+
|import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
400+
|
401+
|val r =
402+
| sc.parallelize(1 to 1).map { i => ExpressionEncoder[Foo](); Foo("bar") }.collect.head
403+
""".stripMargin)
404+
405+
assertContains("r: Foo = Foo(bar)", output)
406+
assertDoesNotContain("error:", output)
407+
assertDoesNotContain("Exception", output)
408+
}
393409
}

0 commit comments

Comments (0)