diff --git a/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java b/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java index 3e2abb866c..032a250608 100644 --- a/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java +++ b/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java @@ -50,6 +50,8 @@ public class ThreadLocalFury extends AbstractThreadSafeFury { private Consumer factoryCallback; private final Map allFury; + private ClassLoader classLoader; + public ThreadLocalFury(Function furyFactory) { factoryCallback = f -> {}; allFury = Collections.synchronizedMap(new WeakHashMap<>()); @@ -58,7 +60,11 @@ public ThreadLocalFury(Function furyFactory) { () -> { LoaderBinding binding = new LoaderBinding(furyFactory); binding.setBindingCallback(factoryCallback); - binding.setClassLoader(Thread.currentThread().getContextClassLoader()); + ClassLoader cl = + classLoader == null + ? Thread.currentThread().getContextClassLoader() + : classLoader; + binding.setClassLoader(cl); allFury.put(binding, null); return binding; }); @@ -258,6 +264,7 @@ public void setClassLoader(ClassLoader classLoader) { @Override public void setClassLoader(ClassLoader classLoader, StagingType stagingType) { + this.classLoader = classLoader; bindingThreadLocal.get().setClassLoader(classLoader, stagingType); } diff --git a/java/fury-core/src/main/java/org/apache/fury/pool/FuryPooledObjectFactory.java b/java/fury-core/src/main/java/org/apache/fury/pool/FuryPooledObjectFactory.java index 17e56a6348..9f8e2a2840 100644 --- a/java/fury-core/src/main/java/org/apache/fury/pool/FuryPooledObjectFactory.java +++ b/java/fury-core/src/main/java/org/apache/fury/pool/FuryPooledObjectFactory.java @@ -48,10 +48,15 @@ public class FuryPooledObjectFactory { */ final Cache classLoaderFuryPooledCache; + private volatile ClassLoader classLoader = null; + /** ThreadLocal: ClassLoader. */ private final ThreadLocal classLoaderLocal = ThreadLocal.withInitial( () -> { + if (classLoader != null) { + return classLoader; + } ClassLoader loader = Thread.currentThread().getContextClassLoader(); if (loader == null) { loader = Fury.class.getClassLoader(); @@ -111,6 +116,7 @@ public void setClassLoader(ClassLoader classLoader, LoaderBinding.StagingType st // may be used to clear some classloader classLoader = Fury.class.getClassLoader(); } + this.classLoader = classLoader; classLoaderLocal.set(classLoader); getOrAddCache(classLoader); } diff --git a/java/fury-core/src/test/java/org/apache/fury/classloader/ThreadSafeFuryClassLoaderTest.java b/java/fury-core/src/test/java/org/apache/fury/classloader/ThreadSafeFuryClassLoaderTest.java new file mode 100644 index 0000000000..f01ff02464 --- /dev/null +++ b/java/fury-core/src/test/java/org/apache/fury/classloader/ThreadSafeFuryClassLoaderTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fury.classloader; + +import org.apache.fury.Fury; +import org.apache.fury.ThreadSafeFury; +import org.apache.fury.config.Language; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ThreadSafeFuryClassLoaderTest { + + static class MyClassLoader extends ClassLoader {} + + @Test + void testFuryThreadLocalUseProvidedClassLoader() throws InterruptedException { + final MyClassLoader myClassLoader = new MyClassLoader(); + final ThreadSafeFury fury = + Fury.builder() + .withClassLoader(myClassLoader) + .withLanguage(Language.JAVA) + .requireClassRegistration(false) + .buildThreadLocalFury(); + fury.setClassLoader(myClassLoader); + + Thread thread = + new Thread( + () -> { + final ClassLoader t = fury.getClassLoader(); + Assert.assertEquals(t, myClassLoader); + }); + thread.start(); + thread.join(); + } + + @Test + void testFuryPoolUseProvidedClassLoader() throws InterruptedException { + final MyClassLoader myClassLoader = new MyClassLoader(); + final ThreadSafeFury fury = + Fury.builder() + .withClassLoader(myClassLoader) + .withLanguage(Language.JAVA) + .requireClassRegistration(false) + .buildThreadSafeFuryPool(1, 1); + fury.setClassLoader(myClassLoader); + + Thread thread = + new Thread( + () -> { + final ClassLoader t = fury.getClassLoader(); + Assert.assertEquals(t, myClassLoader); + }); + thread.start(); + thread.join(); + } +}