在实际开发中,我们经常会用线程池处理大量任务,但是线程池的使用会让线程变量ThreadLocal
无法访问,会很不爽.
举栗,当我们想提高性能,用线程池同时调用多个服务,又不想修改原本代码,实现无侵入的特性,就会很有用.
我们希望使ThreadLocal
线程变量跨线程共享,这就要打破jdk提供的访问限制.
ThreadLocal
的线程隔离是通过在每个线程内部维护一个ThreadLocalMap
的映射表,每次获取都是从当前线程或者父线程的map中(对于InheritableThreadLocal
)取值,从而实现的线程间变量访问的隔离.
// ThreadLocal 的部分源码
// 获取线程的ThreadLocalMap
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
// 先获取线程的ThreadLocalMap,再往对应的map中设置值
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
为此,我们可以通过维护一个静态变量,记录下当前线程所使用的需要跨线程共享的ThreadLocal
表,然后再创建线程运行上下文复制线程变量,等线程运行时再其前后以需要的线程变量替换,运行完之后再还原.
// 用该结构包围实际运行的方法
public void run() {
Map<MyThreadLocal<Object>, Object> replace = null;
try {
replace = replace();
// 设置上下文
runnable.run();
} catch (Exception e) {
e.printStackTrace();
} finally {
// 还原上下文
restore(replace);
}
}
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* @author Lion Zhou
* @date 2022/9/15
*/
public class Test {
public static void test5() {
ExecutorService executorService = Executors.newFixedThreadPool(1);
// 用一个空任务让线程池创建好线程
executorService.submit(() -> {
});
// 使用我们定义好的线程变量
MyThreadLocal<Integer> mtl1 = new MyThreadLocal<>();
mtl1.set(333);
executorService.submit(MyThreadLocalContext.go(() -> {
System.out.print("1:");
System.out.println(mtl1.get());
// 修改线程变量,因为是副本,不影响其他线程中的值
mtl1.set(111);
}));
executorService.submit(() -> {
System.out.print("2:");
// 正常使用为 null
System.out.println(mtl1.get());
});
executorService.submit(MyThreadLocalContext.go(() -> {
System.out.print("3:");
// 还是 333
System.out.println(mtl1.get());
}));
executorService.shutdown();
System.out.println("end:" + mtl1.get());
}
public static void main(String[] args) {
test5();
}
}
import java.util.WeakHashMap;
/**
* @author Lion Zhou
* @date 2022/9/15
*/
public class MyThreadLocal<T> extends InheritableThreadLocal<T> {
// 维护每个线程所持有的 MyThreadLocal 为后续跨线程传递使用
static InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder = new InheritableThreadLocal<>();
@Override
public T get() {
// 直接调用原本的 get 方法
T t = super.get();
if (null == t && null != holder.get()) {
// 对应key的值已经不存在了,删除当前的持有数据
holder.get().remove(this);
}
return t;
}
@Override
public void set(T value) {
super.set(value);
if (holder.get() == null) {
holder.set(new WeakHashMap<>(8));
}
holder.get().put((MyThreadLocal<Object>) this, null);
}
@Override
public void remove() {
super.remove();
if (holder.get() != null) {
holder.get().remove(this);
}
}
}
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.WeakHashMap;
/**
* @author Lion Zhou
* @date 2022/9/15
*/
public class MyThreadLocalContext {
public static Runnable go(Runnable runnable) {
InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder = MyThreadLocal.holder;
Map<MyThreadLocal<Object>, Object> map = Collections.emptyMap();
if (null != holder.get()) {
map = new WeakHashMap<>(holder.get().size());
// System.out.println("start");
for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.get().entrySet()) {
// System.out.println(entry.getKey().get());
map.put(entry.getKey(), entry.getKey().get());
}
// System.out.println("end");
}
return new Context(map, runnable);
}
public static class Context implements Runnable {
Map<MyThreadLocal<Object>, Object> holder;
Runnable runnable;
public Context(Map<MyThreadLocal<Object>, Object> holder, Runnable runnable) {
this.holder = holder;
this.runnable = runnable;
}
public Map<MyThreadLocal<Object>, Object> replace() {
// 保留原本的线程本地变量
Map<MyThreadLocal<Object>, Object> replace = new WeakHashMap<>();
// 将复制过来的值重新赋值给当前上下文环境
// System.out.println("context start");
// 上下文切换
for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
// System.out.println(String.format("old: %s, new: %s", Optional.ofNullable(entry.getKey().get()).orElse("null").toString(),
// entry.getValue()));
// 保存 线程本地变量 的现场
replace.put(entry.getKey(), entry.getKey().get());
// 替换需要的上下文
entry.getKey().set(entry.getValue());
}
// System.out.println("context end");
return replace;
}
public void restore(Map<MyThreadLocal<Object>, Object> restore) {
if (null == restore) {
return;
}
for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
// 原本的值
Object old = restore.get(entry.getKey());
if (null == old) {
// 原本就为null
entry.getKey().remove();
} else {
entry.getKey().set(old);
}
}
}
@Override
public void run() {
Map<MyThreadLocal<Object>, Object> replace = null;
try {
replace = replace();
// 设置上下文
runnable.run();
} catch (Exception e) {
e.printStackTrace();
} finally {
// 还原上下文
restore(replace);
}
}
}
}
源码名字不好听,见谅.
代码很简单,只是为了演示,实际还存在一些问题,比如在替换上下文时没有使用 deepcopy等.
WeakHashMap
使用就是基本问题了,因为线程变量是跨线程的,并非线程独有值,因此不能破坏原本变量的生命周期(由此导致内存泄露),所以要用弱引用.