• 跨线程池共享的ThreadLocal


    背景

    在实际开发中,我们经常会用线程池处理大量任务,但是线程池的使用会让线程变量ThreadLocal无法访问,会很不爽.
    举栗,当我们想提高性能,用线程池同时调用多个服务,又不想修改原本代码,实现无侵入的特性,就会很有用.

    原理

    我们希望使ThreadLocal线程变量跨线程共享,这就要打破jdk提供的访问限制.
    ThreadLocal的线程隔离是通过在每个线程内部维护一个ThreadLocalMap的映射表,每次获取都是从当前线程或者父线程的map中(对于InheritableThreadLocal)取值,从而实现的线程间变量访问的隔离.

    // ThreadLocal 的部分源码
    // 获取线程的ThreadLocalMap 
    ThreadLocalMap getMap(Thread t) {
      return t.threadLocals;
    }
    
    // 先获取线程的ThreadLocalMap,再往对应的map中设置值
    public void set(T value) {
    	Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    为此,我们可以通过维护一个静态变量,记录下当前线程所使用的需要跨线程共享的ThreadLocal表,然后再创建线程运行上下文复制线程变量,等线程运行时再其前后以需要的线程变量替换,运行完之后再还原.

    // 用该结构包围实际运行的方法
    public void run() {
    	Map<MyThreadLocal<Object>, Object> replace = null;
        try {
            replace = replace();
            // 设置上下文
            runnable.run();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // 还原上下文
            restore(replace);
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    测试代码

    import com.alibaba.ttl.TransmittableThreadLocal;
    import com.alibaba.ttl.TtlRunnable;
    
    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    
    /**
     * @author Lion Zhou
     * @date 2022/9/15
     */
    public class Test {
        public static void test5() {
            ExecutorService executorService = Executors.newFixedThreadPool(1);
            // 用一个空任务让线程池创建好线程
            executorService.submit(() -> {
            });
    
            // 使用我们定义好的线程变量
            MyThreadLocal<Integer> mtl1 = new MyThreadLocal<>();
            mtl1.set(333);
    
            executorService.submit(MyThreadLocalContext.go(() -> {
                System.out.print("1:");
                System.out.println(mtl1.get());
                // 修改线程变量,因为是副本,不影响其他线程中的值
                mtl1.set(111);
            }));
    
            executorService.submit(() -> {
                System.out.print("2:");
                // 正常使用为 null
                System.out.println(mtl1.get());
            });
    
            executorService.submit(MyThreadLocalContext.go(() -> {
                System.out.print("3:");
                // 还是 333
                System.out.println(mtl1.get());
            }));
    
            executorService.shutdown();
            System.out.println("end:" + mtl1.get());
        }
    
        public static void main(String[] args) {
            test5();
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    源码

    import java.util.WeakHashMap;
    
    /**
     * @author Lion Zhou
     * @date 2022/9/15
     */
    public class MyThreadLocal<T> extends InheritableThreadLocal<T> {
    
        // 维护每个线程所持有的 MyThreadLocal 为后续跨线程传递使用
        static InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder = new InheritableThreadLocal<>();
    
        @Override
        public T get() {
            // 直接调用原本的 get 方法
            T t = super.get();
            if (null == t && null != holder.get()) {
                // 对应key的值已经不存在了,删除当前的持有数据
                holder.get().remove(this);
            }
            return t;
        }
    
        @Override
        public void set(T value) {
            super.set(value);
            if (holder.get() == null) {
                holder.set(new WeakHashMap<>(8));
            }
            holder.get().put((MyThreadLocal<Object>) this, null);
        }
    
        @Override
        public void remove() {
            super.remove();
            if (holder.get() != null) {
                holder.get().remove(this);
            }
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    import java.util.Collections;
    import java.util.Map;
    import java.util.Optional;
    import java.util.WeakHashMap;
    
    /**
     * @author Lion Zhou
     * @date 2022/9/15
     */
    public class MyThreadLocalContext {
    
        public static Runnable go(Runnable runnable) {
            InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder = MyThreadLocal.holder;
    
            Map<MyThreadLocal<Object>, Object> map = Collections.emptyMap();
            if (null != holder.get()) {
                map = new WeakHashMap<>(holder.get().size());
    //            System.out.println("start");
                for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.get().entrySet()) {
    //                System.out.println(entry.getKey().get());
                    map.put(entry.getKey(), entry.getKey().get());
                }
    //            System.out.println("end");
            }
            return new Context(map, runnable);
        }
    
        public static class Context implements Runnable {
            Map<MyThreadLocal<Object>, Object> holder;
            Runnable runnable;
    
            public Context(Map<MyThreadLocal<Object>, Object> holder, Runnable runnable) {
                this.holder = holder;
                this.runnable = runnable;
            }
    
            public Map<MyThreadLocal<Object>, Object> replace() {
                // 保留原本的线程本地变量
                Map<MyThreadLocal<Object>, Object> replace = new WeakHashMap<>();
    
                // 将复制过来的值重新赋值给当前上下文环境
    //            System.out.println("context start");
                // 上下文切换
                for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
    //                System.out.println(String.format("old: %s, new: %s", Optional.ofNullable(entry.getKey().get()).orElse("null").toString(),
    //                        entry.getValue()));
    
                    // 保存 线程本地变量 的现场
                    replace.put(entry.getKey(), entry.getKey().get());
                    // 替换需要的上下文
                    entry.getKey().set(entry.getValue());
                }
    //            System.out.println("context end");
                return replace;
            }
    
            public void restore(Map<MyThreadLocal<Object>, Object> restore) {
                if (null == restore) {
                    return;
                }
                for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
                    // 原本的值
                    Object old = restore.get(entry.getKey());
                    if (null == old) {
                        // 原本就为null
                        entry.getKey().remove();
                    } else {
                        entry.getKey().set(old);
                    }
                }
            }
    
            @Override
            public void run() {
                Map<MyThreadLocal<Object>, Object> replace = null;
                try {
                    replace = replace();
                    // 设置上下文
                    runnable.run();
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    // 还原上下文
                    restore(replace);
                }
            }
        }
    
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90

    总结

    源码名字不好听,见谅.
    代码很简单,只是为了演示,实际还存在一些问题,比如在替换上下文时没有使用 deepcopy等.
    WeakHashMap使用就是基本问题了,因为线程变量是跨线程的,并非线程独有值,因此不能破坏原本变量的生命周期(由此导致内存泄露),所以要用弱引用.

    相关资料

    • 阿里开源的线程间上下文传递解决方案 支持编程和java agent的形式
  • 相关阅读:
    网络安全(黑客)自学
    SaaSBase:什么是汇思?
    马斯克:虽然我是Rust的粉丝,但我选择C,其次是C++和Python
    CSDN里的常用网址(2)
    Linux目录权限修改-2
    ESP8266-Arduino编程实例-SPIFFS及数据上传(Arduino IDE和PlatformIO IDE)
    【Vue官方教程】Vue官方教程之我的笔记--20221026
    Linux的screen工具库实现多终端
    十五. 实战——mysql建库建表 字符集 和 排序规则
    融云:让银行轻松上“云”
  • 原文地址:https://blog.csdn.net/weixin_46080554/article/details/126872872