先来看看源码
public void set(T value) {
// 获取当前线程
Thread t = Thread.currentThread();
// 从当前线程中获取map对象
ThreadLocalMap map = getMap(t);
if (map != null)
// 存入map中,下面再详细介绍
map.set(this, value);
else
// 创建并初始化map对象
createMap(t, value);
}
// 尝试获取当前线程的map对象
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
// 创建并初始化map对象并赋值给当前线程
void createMap(Thread t, T firstValue) {
// 下面再详细介绍
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
可以看到set方法是将value对象存入当前线程属性名为threadLocals的键值对对象类型内,所以也能解释为什么可以支持多线程隔离存储,因为不同的线程存的map是不同的
get方法也比较简单,通过获取线程内ThreadLocalMap对象再通过以ThreadLocal实例为key获取对应的value,如果不存在则会通过setInitialValue方法初始化一个值,默认为null
public T get() {
// 获取当前线程
Thread t = Thread.currentThread();
// 从当前线程中获取map对象
ThreadLocalMap map = getMap(t);
if (map != null) {
// 从map中获取键值对,下面再详细介绍
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
// 返回值对象
return result;
}
}
// 如果没有初始化过则初始化value
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue();
// 此处源码为直接使用set方法内所有代码,并非直接调用,结果同直接调用一样
set(value);
return value;
}
// SuppliedThreadLocal类有重写该方法,后面解释
protected T initialValue() {
return null;
}
remove方法也很简单,获取线程ThreadLocalMap对象再删除对象内key等于ThreadLocal实例的键值对
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
// 下面再详细介绍
m.remove(this);
}
withInitial方法接收一个Supplier类型函数,返回SuppliedThreadLocal实例
SuppliedThreadLocal内部又重写了initialValue方法返回Supplier.get()值
而initialValue方法只会在调用ThreadLocal类的get方法时通过setInitialValue方法调用,所以SuppliedThreadLocal类相当于拥有初始值的ThreadLocal类,可以直接使用ThreadLocal.get()
// ThreadLocal 类
protected T initialValue() {
return null;
}
//
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}
// 继承ThreadLocal类并提供新的构造方法,支持提供供应者函数实例化
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
@Override
protected T initialValue() {
return supplier.get();
}
}
应用场景,我们知道SimpleDateFormat类是线程不安全的,不能定义为全局变量,定义为局部变量会导致每次重新实例化,不过我们可以使用下面的方式为每个线程仅实例化一次
private ThreadLocal<DateFormat> dateFormat = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
String dateString = dateFormat.get().format(new Date());
如果我们希望创建一个可以被子线程继承ThreadLocal值的对象,则可以使用InheritableThreadLocal类,该类将ThreadLocalMap对象存储至线程的inheritableThreadLocals属性,该属性在创建线程时不为空则会通过ThreadLocal.createInheritedMap方法复制父线程值存入子线程
// 该方法需要被重写
T childValue(T parentValue) {
throw new UnsupportedOperationException();
}
// 内部静态方法,仅用于调用ThreadLocalMap(ThreadLocalMap parentMap)构造方法
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}
// 可继承的本地线程类,重写createMap方法,该类被赋值给线程的inheritableThreadLocals属性上
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
// createInheritedMap方法调用的构造函数内会调用该方法,为了在复制value的时候可以做一些操作
protected T childValue(T parentValue) {
return parentValue;
}
// 重写父类方法,为了调用get、set、remove方法时获取到正确的属性
ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}
// 重写父类方法,为了调用get、set方法时获取到正确的属性
void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}
// Thread类init方法部分代码
public class Thread implements Runnable {
private void init(ThreadGroup g, Runnable target, String name, long stackSize, AccessControlContext acc) {
// 获取当前线程,也就是运行创建线程的线程,称为创建后的线程的父线程
Thread parent = currentThread();
if (parent.inheritableThreadLocals != null)
// 如果父线程的 inheritableThreadLocals 属性不为空则子线程会通过该方法复制一份
this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
}
}
代码测试
public static void main(String[] args) throws Exception {
InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
inheritableThreadLocal.set("子线程会继承");
CountDownLatch countDownLatch = new CountDownLatch(1);
new Thread(() -> {
System.out.println(inheritableThreadLocal.get());
countDownLatch.countDown();
}).start();
countDownLatch.await();
}
那么问题来了,子线程里再创建子线程,子线程的子线程能通过 inheritableThreadLocal.get() 获取到数据吗?
ThreadLocalMap有两个构造函数,一个提供子线程使用,一个提供用户使用,这里可以看到和HashMap不一样的地方是value没有使用链表存储,而是通过nextIndex方法获取下一个可储存的下标,划重点,后面讲
// 初始数组大小,2^n
private static final int INITIAL_CAPACITY = 16;
// 存储数组
private Entry[] table;
// 存储数量
private int size = 0;
// 阈值
private int threshold;
// 计算阈值,hashMap负载因子为0.75,这里为2/3约等于0.67
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
// 获取下一个下标,如果超出则从0开始
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
// 使用父线程ThreadLocalMap实例化子线程对象,该构造函数只能通过**ThreadLocal.createInheritedMap**方法调用
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
// 设置阈值
setThreshold(len);
// 创建新数组
table = new Entry[len];
for (int j = 0; j < len; j++) {
// 获取父线程数据
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
// 该方法默认直接返回e.value,可以通过子类重写做一些事情
Object value = key.childValue(e.value);
// 生成新的Entry
Entry c = new Entry(key, value);
// 计算该ThreadLocal对象新下标,这里并没有重新计算,因为父类计算的时候可能是扩容前的len,所以会与现在计算的不一样
int h = key.threadLocalHashCode & (len - 1);
// 获取到下一个可存入的下标内,这里一定能找到为空的下标,注意nextIndex累加到尾部后会从0开始
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}
// 该构造函数通过createMap方法调用
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 初始化长度为16的数组
table = new Entry[INITIAL_CAPACITY];
// 计算下标
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 存入value
table[i] = new Entry(firstKey, firstValue);
// 设置初始数量为1
size = 1;
// 设置扩容阈值
setThreshold(INITIAL_CAPACITY);
}
Entry的key是WeakReference(弱引用)类型的,即如果key除了Entry对象没有其他对象引用时,就会被垃圾回收器回收,回收后referent就会等于null,而referent为null时就会在get/set方法被调用时删除该Enrty
static class ThreadLocalMap {
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
// 使用父类(WeakReference)构造器
super(k);
value = v;
}
}
}
public class WeakReference<T> extends Reference<T> {
// 使用父类(Reference)构造器
public WeakReference(T referent) {
super(referent);
}
}
public abstract class Reference<T> {
// k
private T referent;
Reference(T referent) {
this(referent, null);
}
Reference(T referent, ReferenceQueue<? super T> queue) {
this.referent = referent;
this.queue = (queue == null) ? ReferenceQueue.NULL : queue;
}
public void clear() {
this.referent = null;
}
}
关于 ThreadLoca内存泄漏问题就出在Entry上,我们看一个图
图中 local 和 object 属性我们可以让其等于null,此时key是弱引用所以最终key会等于null,而value是强引用,所以Object对象不会被回收,但是由于在调用set/get方法时会判断key是否等于null来删除整个Entry对象,这样value也会被回收,所以我们推断同1个线程对同1个ThreadLocal对象在被调用set/get前会存在1个value不会被回收问题,我们通过代码验证一下
// 限制最大内存为25m(-Xms25m -Xmx25m)
public static void main(String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(30);
// 每次
ThreadLocal<byte[]> local = new ThreadLocal<>();
for (int i = 0; i < 30; i++) {
executorService.execute(() -> {
// 放入1m大小的对象
local.set(new byte[1024 * 1024]);
System.out.println(Thread.currentThread().getName());
// 不调用remove就会OOM
//local.remove();
});
// 等待上一个线程运行完可以被回收
Thread.sleep(50);
}
}
ThreadLocalMap的set方法包含了所有核心逻辑,可以结合该流程图看源码
ThreadLocalMap.set方法流程图
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
// 计算存储下标
int i = key.threadLocalHashCode & (len-1);
// 如果数组对应下标已经存在
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 判断该ThreadLocal的key和要set的key是否相等
if (k == key) {
// 如果相等则将值替换为新value
e.value = value;
return;
}
// 如果k为空则表示存在被垃圾回收器回收了的ThreadLocal,则可以直接使用该下标做为存储
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 循环到为空的下标则写入
tab[i] = new Entry(key, value);
int sz = ++size;
// 如果没有删除过失效的数据并且达到了需要扩容的数量
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
// 定义下标staleSlot的前后连续不为空的数据第一个过期的元素的下标
int slotToExpunge = staleSlot;
// 找出指定下标往前连续不为空的数据内最远一个存放过数据但key被回收的下标
for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
if (e.get() == null) {
slotToExpunge = i;
}
// 因为set方法找到的是第一个过期的元素调用的该方法,所以还要往后再遍历看看是不是有相等的key,不然直接存放的话会导致有两个key被存放在不同的下标
for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 如果找到了则交换前面被删除元素的位置
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// 交换之后如果前面没有需要被删除的元素,则记录后面这个需要被删除的元素的下标
if (slotToExpunge == staleSlot) {
slotToExpunge = i;
}
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// 如果遇到被回收的并且上一个循环没有遇到,则记录
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 如果前或后续存在过期的元素则执行清除操作并尝试清理下一个内环过期的元素
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
// 计算前一个位置的下标
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
// 删除指定被回收的下标,并删除后续连续不为空的被回收的数据,对没被回收的数据重新计算下标
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 删除
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
// 遍历后续连续不为空的下标
Entry e;
int i;
for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 如果被回收了就顺便删除
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
// 如果没被回收且下标不是原下标,则再次计算第一个为空的下标并移动
if (h != i) {
tab[i] = null;
while (tab[h] != null) {
h = nextIndex(h, len);
}
tab[h] = e;
}
}
}
// 返回下一个为空的下标
return i;
}
// 尝试性的寻找一下可能过期的数据,返回是否有被删除的数据
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
// 如果存在被回收的数据,则要删除掉,让下标空出来
i = expungeStaleEntry(i);
}
// 无符号左移1,相当于 n = n / 2;
} while ( (n >>>= 1) != 0);
return removed;
}
// 删除所有被回收的数据,如果只有小部分(threshold / 4)被删除,则触发扩容
private void rehash() {
expungeStaleEntries();
if (size >= threshold - threshold / 4)
resize();
}
// 删除所有key被回收的数据
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
// 找到被删除的下标后丢给删除指定下标的方法操作
expungeStaleEntry(j);
}
}
// 两倍扩容
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
// 遍历原数组,过滤被回收的数据,使用新数组长度重新计算下标
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
getEntry方法就很简单了,先计算一下下标,然后判断一下是不是要找的key,如果不是则一直往后找,直到找到或者遇到null为止
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// 如果计算出来的下标不是指定的key,说明发送了hash冲突,则往下找
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
// 从指定下标 i 开始往后找到指定 key 的 Entry
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key) {
return e;
}
if (k == null) {
// 删除指定下标的元素
expungeStaleEntry(i);
} else {
i = nextIndex(i, len);
}
e = tab[i];
}
return null;
}
remove 方法也很简单,从指定下标开始找,找到了就丢给 expungeStaleEntry 方法删除指定下标元素
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 从下标i开始一直到下一个为空的下标,如果对应存储的key和要删除的key相等则删除
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// 清除对象引用
e.clear();
// 删除指定下标的元素
expungeStaleEntry(i);
return;
}
}
}
public abstract class Reference<T> {
public void clear() {
this.referent = null;
}
}
// 定义一个integer型原子操作类
private static AtomicInteger nextHashCode = new AtomicInteger();
// 该值可以为计算下标提供散列作用(十六进制)
private static final int HASH_INCREMENT = 0x61c88647;
// 计算下一个hashCode
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
// 初始化当前实例的hashCoe
private final int threadLocalHashCode = nextHashCode();
这里重点讲一下threadLocalHashCode属性,用于计算当前ThreadLocal对象在ThreadLocalMap.table数组的下标,每次实例化ThreadLocal对象属性通过原子类nextHashCode累加,该行为是为了让每个实例的下标尽量散列,因为ThreadLocalMap数据结构是一个环形数组,而当hash冲突后存储方式是找到下一个为空的位置存储,所以让元素尽量散列可以提高查找效率,如下面两种情况当需要再插入一个1时,左边只需要遍历一次就能得到空数组,而右边需要遍历3次,下面是测试代码和输出结果
public static void main(String[] args) {
final int HASH_INCREMENT = 0x61c88647;
int size = 1 << 4;
int hashCode = 0;
for (int i = 0; i < size; i++) {
hashCode += HASH_INCREMENT;
System.out.printf("%32s\t", Integer.toBinaryString(hashCode));
System.out.print(hashCode & size - 1);
System.out.println();
}
}
private int threshold;
private void setThreshold(int len) {
threshold = len * 2 / 3;
}