ThreadLocal


一、简介

  多线程访问同一个共享变量的时候容易出现并发问题,特别是多个线程对一个变量进行写入的时候,为了保证线程安全,一般使用者在访问共享变量的时候需要进行额外的同步措施才能保证线程安全性。ThreadLocal是除了加锁这种同步方式之外的一种保证一种规避多线程访问出现线程不安全的方法,当我们在创建一个变量后,如果每个线程对其进行访问的时候访问的都是线程自己的变量这样就不会存在线程不安全问题。

  ThreadLocal是JDK包提供的,它提供线程本地变量,如果创建一个ThreadLocal变量,那么访问这个变量的每个线程都会有这个变量的一个副本,在实际多线程操作的时候,操作的是自己本地内存中的变量,从而规避了线程安全问题。

​ 变量是同一个,但是每个线程都使用同一个初始值,也就是使用同一个变量的一个新的副本。这种情况之下ThreadLocal就非常实用,比如说DAO的数据库连接,我们知道DAO是单例的,那么他的属性Connection就不是一个线程安全的变量。而我们每个线程都需要使用他,并且各自使用各自的。这种情况,ThreadLocal就比较好的解决了这个问题。

特点总结

  • 每个线程内都有自己的实例副本,且该副本只能由当前线程调用
  • 线程不能访问其他线程的实例副本,就不会存在多线程之间的共享变量问题
  • 统一设置初始值,但每个线程对这个值的修改都是独立的

二、API

1、基本方法

ThreadLocal 对外暴露的非静态方法只有有 get()set()remove() 三个

// 返回当前线程的此线程局部变量的副本中的值。 
public T get();

// 将当前线程的此线程局部变量的副本设置为指定的值。
public void set(T value);

// 删除此线程局部变量的当前线程的值。删除之后再次调用为初始值,若未设置初始值则为null
public void remove();

用例

public static ThreadLocal<Integer> threadLocal = new ThreadLocal<>();

@Test
public void test3(){

    System.out.println(threadLocal.get());

    threadLocal.set(1);
    System.out.println(threadLocal.get());

    threadLocal.remove();
}

2、初始化值-旧

ThreadLocal 有一个 protected 修饰的方法,此方法不推荐使用了

protected T initialValue() {
    return null;
}

返回此线程局部变量的当前线程的“初始值”。该方法将在第一次使用get()方法访问变量时被调用,如果线程先调用set(T),在调用 get()initialValue方法将不会被调用

通常情况下,这种方法最多每个线程调用一次,但**如果线程调用remove() 后在调用 get(),则会再次执行 initialValue()**,因为remove()ThreadLocal 将回到未初始化状态 。

这个实现默认返回null ; 如果希望线程局部变量具有除null之外的初始值,则必须写一个类继承自ThreadLocal,并重写该方法。 通常,将使用匿名内部类。

下面是阿里巴巴开发手册中对 ThreadLocal 初始化相关的案例

【强制】SimpleDateFormat 是线程不安全的类,一般不要定义为 static 变量,如果定义为 static,必须加锁,或者使用 DateUtils 工具类。

正例:注意线程安全,使用 DateUtils。亦推荐如下处理:

private static final ThreadLocal<DateFormat> df = new ThreadLocal<DateFormat>() {       @Override        
    protected DateFormat initialValue() {            
        return new SimpleDateFormat("yyyy-MM-dd");        
    }    
};    

说明:如果是 JDK8 的应用,可以使用 Instant 代替 DateLocalDateTime 代替 CalendarDateTimeFormatter 代替 SimpleDateFormat,官方给出的解释:simple beautiful strong immutable thread-safe

【参考】ThreadLocal 无法解决共享对象的更新问题,ThreadLocal 对象建议使用 static 修饰。这个变量是针对一个线程内所有操作共享的,所以设置为静态变量,所有此类实例共享 此静态变量 ,也就是说在类第一次被使用时装载,只分配一块存储空间,所有此类的对象(只 要是这个线程内定义的)都可以操控这个变量。

3、初始化值-新

该方法是 jdk1.8 之后新增的

public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);
}

使用方法如下

public static ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 0);

与上面旧版的初始化方式的区别是改方式创建的 ThreadLocal 对象其实是 ThreadLocal 的一个静态内部类 SuppliedThreadLocal

/**
* 该静态内部类继承自 ThreadLocal,并且使用供给型接口定义了初始值,并重写了旧版的初始化方法
*/
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

    private final Supplier<? extends T> supplier;

    SuppliedThreadLocal(Supplier<? extends T> supplier) {
        this.supplier = Objects.requireNonNull(supplier);
    }

    @Override
    protected T initialValue() {
        return supplier.get();
    }
}

三、源码

1、ThreadLocalMap

首先需要知道 ThreadThreadLocalThreadLocalMap 三者之间的关系:

Thread 类中有两个变量 threadLocalsinheritableThreadLocals ,二者都是 ThreadLocal内部类 ThreadLocalMap 类型的变量,我们通过查看内部类 ThreadLocalMap 可以发现实际上它类似于一个HashMap。在默认情况下,每个线程中的这两个变量都为null,只有当线程第一次调用 ThreadLocalset() 或者 get() 方法的时候才会创建他们。除此之外,每个线程的本地变量不是存放在 ThreadLocal 实例中,而是放在调用线程的 ThreadLocals 变量里面。也就是说,ThreadLocal 类型的本地变量是存放在具体的线程空间上,其本身相当于一个装载本地变量的工具壳,通过 set()方法将 value 添加到调用线程的 threadLocals 中,当调用线程调用 get() 方法时候能够从它的 threadLocals 中取出变量。如果调用线程一直不终止,那么这个本地变量将会一直存放在他的 threadLocals 中,所以不使用本地变量的时候需要调用 remove()方法将 threadLocals 中删除不用的本地变量。

image-20220813183804852

image-20220814140956628

ThreadLocal是线程本地变量,ThreadLocalMapThreadLocal 的静态内部类,ThreadLocalMap 类似 HashMap 拥有 key-value 的结构,同时也是 Thread 的一个成员变量。

  • ThreadLocalMapThread 的成员变量
public class Thread implements Runnable {    
    ...

    ThreadLocal.ThreadLocalMap threadLocals = null;

    ...
}
  • ThreadLocalMapThreadLocal 的静态内部类
public class ThreadLocal<T> {
    ...

    static class ThreadLocalMap {

        static class Entry extends WeakReference<ThreadLocal<?>> {
            // 与这个ThreadLocal关联的值
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        // 初始容量
        private static final int INITIAL_CAPACITY = 16;

        // 实际数据存储的数组
        private Entry[] table;

        ...
    }
}

下面是官方关于 ThreadLocalMap 内部类的说明

ThreadLocalMap 是一个定制的哈希映射,仅适用于维护线程本地值。没有操作被导出到 ThreadLocal类之外。这个类是包私有的,允许在类 Thread 中声明字段(ThreadLocalThread 在同一个包下)。为了帮助处理非常大和长期的使用,散列表条目使用 WeakReferences 作为键。但是,由于没有使用引用队列,所以只有在表空间耗尽时才保证删除过时的条

ThreadLocalMap 是一个保存 ThreadLocal 对象的 map(以 ThreadLocal 实例为 key,其存储的数据为 value 的),不过是经过了两层包装的 ThreadLocal 对象:

  • 第一层是使用 WeakReference<ThreadLocal<?>>ThreadLocal 对象变成弱引用对象;
  • 第二层是定义了一个专门的类 Entry 来扩展 WeakReference<ThreadLocal<?>>

下面是官方关于Entry内部类的说明:

这个散列映射中的条目扩展WeakReference,使用它的主引用字段作为键(始终是ThreadLocal对象)。注意,空键(即entry.get() == null)意味着该键不再被引用,因此条目可以从表中删除。这样的条目在下面的代码中称为“陈旧条目(stale entries)”。

下面是一个演示代码,确认ThreadLocalMap的结构

public static ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
public static ThreadLocal<Long> threadLocal2 = new ThreadLocal<>();
public static ThreadLocal<String> threadLocal3 = new ThreadLocal<>();

@Test
public void test3(){

    threadLocal.set(1);

    threadLocal2.set(2L);

    threadLocal3.set("3");

    Thread thread = Thread.currentThread();
}

通过断点调试模式,可观察到当前线程 thread 内部有 ThreadLocal.ThreadLocalMap 类型的变量threadLocals,并且 threadLocals 的本质是一个类似 Map 类型的结构,其中元素存储于 table 数组中。

image-20220813191733803

2、set()

public void set(T value) {
    //(1)获取当前线程(调用者线程)
    Thread t = Thread.currentThread();
    //(2)从当前线程的成员变量中获取ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    //(3)如果map不为null,就直接添加本地变量,key为当前定义的ThreadLocal变量的this引用,值为添加的本地变量值
    if (map != null)
        map.set(this, value);
    //(4)如果map为null,说明首次添加,需要首先创建出对应的map
    else
        createMap(t, value);
}

在上面的代码中,(2)处调用getMap方法获得当前线程对应的threadLocals,该方法代码如下

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals; //获取线程自己的变量threadLocals,并绑定到当前调用线程的成员变量threadLocals上
}

如果调用getMap方法返回值不为null,就直接将value值设置到threadLocals中(key为当前线程引用,值为本地变量);如果getMap方法返回null说明是第一次调用set方法(前面说到过,threadLocals默认值为null,只有调用set方法的时候才会创建map),这个时候就需要调用createMap方法创建threadLocals,该方法如下所示

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

createMap方法不仅创建了threadLocals,同时也将要添加的本地变量值添加到了threadLocals中。

3、get()

在get方法的实现中,首先获取当前调用者线程,如果当前线程的 threadLocals 不为null,就直接返回当前线程绑定的本地变量值,否则执行setInitialValue方法初始化threadLocals变量。在setInitialValue方法中,类似于set方法的实现,都是判断当前线程的threadLocals变量是否为null,是则添加本地变量(这个时候由于是初始化,所以添加的值为null),否则创建threadLocals变量,同样添加的值为null。

public T get() {
    //(1)获取当前线程
    Thread t = Thread.currentThread();
    //(2)获取当前线程的threadLocals变量
    ThreadLocalMap map = getMap(t);
    //(3)如果threadLocals变量不为null,就可以在map中查找到本地变量的值
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    //(4)执行到此处,threadLocals为null,说明ThreadLocal未初始化,即没有调用set()方法
    // 则需要先初始化,如果设置了初始值则返回,如果没设置,返回null
    return setInitialValue();
}

private T setInitialValue() {
    //protected T initialValue() {return null;}
    T value = initialValue();
    //获取当前线程
    Thread t = Thread.currentThread();
    //以当前线程作为key值,去查找对应的线程变量,找到对应的map
    ThreadLocalMap map = getMap(t);
    //如果map不为null,就直接添加本地变量,key为当前线程,值为添加的本地变量值
    if (map != null)
        map.set(this, value);
    //如果map为null,说明首次添加,需要首先创建出对应的map
    else
        createMap(t, value);
    return value;
}

4、remove()

  remove方法判断该当前线程对应的threadLocals变量是否为null,不为null就直接删除当前线程中指定的threadLocals变量

public void remove() {
    //获取当前线程绑定的threadLocals
     ThreadLocalMap m = getMap(Thread.currentThread());
     //如果map不为null,就移除当前线程中指定ThreadLocal实例的本地变量
     if (m != null)
         m.remove(this);
 }

remove() 方法实现如下

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

5、threadLocals

如下图所示:每个线程内部有一个名为threadLocals的成员变量,该变量的类型为ThreadLocal.ThreadLocalMap类型(类似于一个HashMap),其中的key为当前定义的ThreadLocal变量的this引用,value为我们使用set方法设置的值。每个线程的本地变量存放在自己的本地内存变量threadLocals中,如果当前线程一直不消亡,那么这些本地变量就会一直存在(所以可能会导致内存溢出),因此使用完毕需要将其remove掉。

img

四、内存泄漏问题

1、为何会引发内存泄漏

内存泄漏:不在被使用的对象一直占用内存不能被回收

下面是阿里巴巴开发手册对 ThreadLocal 内存泄漏的描述。

【强制】必须回收自定义的 ThreadLocal 变量,尤其是在线程池的场景下,线程经常被复用,如果不清理自定义的 ThreadLocal 变量,可能会影响后续的业务逻辑和造成内存泄漏的现象。尽量在代理中使用 try-finally 回收

threadLocal.set(1);

try{
    //...
}finally {
    threadLocal.remove();
}

2、为何要用弱引用

首先要知道弱引用的特性:当对象只有弱引用引用他时,会被 gc 回收

function1 执行完毕后,栈帧销毁强引用 tl 也就没有了。但此时线程的 ThreadLocalMap 里的某个 key 引用还指向这个对象。

  • 若 key 的引用是强引用,就会导致 key 指向的 ThreadLocal 对象及对应的 value 值无法被回收,造成内存泄漏
  • 若 key 的引用是弱引用,当方法执行完毕后,该 key 指向的 ThreadLocal 只剩下 key 的弱引用,当下次 gc 时,便会回收该ThreadLocal ,从而使 Entry 的 key 指向置为 null

image-20220814191006344

  • 使用线程池复用线程的情况下:Entry 中的 key 是弱引用,当 ThreadLocal 外部的强引用被置为 null 后,系统发生 gc 时,根据可达性分析,这个 ThreadLocal 实例没有一条链路能够引用他,就会被回收。这样一来,ThreadLocalMap 中就会出现 key 为 null 脏 Entry ,且这些 脏 Entry 的 value 一直强引用对应的对象,无法被回收,从而导致内存泄漏。(所以必须手动调用 remove()

  • 如果当线程运行结束时,ThreadLocal,ThreadLocalMap,Entry 没有引用链可达时,就会被回收。

3、清除脏Entry

虽然弱引用保证了 key 指向的 ThreadLocal 对象能及时被回收,但 value 指向的对象是需要在 ThreadLocalMap 中调用 get()set()remove()方法遍历整个 Map 发现 key 为 null 时才会进行回收,因此弱引用也不能保证 100% 不出现内存泄漏的情况,不仅仅是内存泄漏的情况,因线程池中的线程是重复使用的,意味着线程的 ThreadLocalMap 也是重复使用的,如果不手动调用 remove() 方法,那么后面的线程可能会获取到之前线程遗留的 value 值,引发 BUG

由于Key是弱引用,因此ThreadLocal可以通过key.get()==null来判断Key是否已经被回收,如果Key被回收,就说明当前Entry是一个废弃的过期节点,ThreadLocal会自发的将其清理掉。

ThreadLocal会在以下过程中清理过期节点:

  • 调用set()方法时,采样清理、全量清理,扩容时还会继续检查。

  • 调用get()方法,没有直接命中,向后环形查找时。

  • 调用remove()时,除了清理当前Entry,还会向后继续清理。

(1)set()的清理逻辑

当线程调用ThreadLocal.set(T value)时,它会将ThreadLocal对象作为Key,值作为value设置到ThreadLocalMap中,源码如下:

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 计算下标,算法:hashCode & (len - 1),和HashMap一样,这里不详叙。
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         /*
         如果下标元素不是null,有两种情况:
         1.同一个Key,覆盖value。
         2.哈希冲突了。
          */
         e != null;
         /*
         哈希冲突的解决方式:开放定址法的线性探测。
         当前下标被占用了,就找next,找到尾巴还没找到就从头开始找。
         直到找到没有被占用的下标。
          */
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            // 相同的Key,则覆盖value。
            e.value = value;
            return;
        }

        if (k == null) {
            /*
            下标被占用,但是Key.get()为null。说明ThreadLocal被回收了。
            需要进行替换。
             */
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    /*
    1.判断是否可以清理一些槽位。
    2.如果清理成功,就无需扩容了,因为已经腾出一些位置留给下次使用。
    3.如果清理失败,则要判断是否需要扩容。
     */
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

如果Entry.get()==null说明发生哈希冲突了,且旧Key已经被回收了,此时ThreadLocal会替换掉旧的value,避免发生「内存泄漏」。
如果没有哈希冲突,ThreadLocal仍然会调用cleanSomeSlots来清理部分节点,源码如下:

/*
清理部分槽位。
1.如果清理成功,就不用扩容了,因为已经腾出一部分位置了。
2.出于性能考虑,不会做所有元素做清理工作,而是采样清理。
set()时,n=size,搜索范围较小。
 */
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 一旦搜索到了过期元素,则n=len,扩大搜索范围
            n = len;
            removed = true;
            // 真正清理的逻辑
            i = expungeStaleEntry(i);
        }
        /*
        采样规则: n >>>= 1 (折半)
        例:100 > 50 > 25 > 12 > 6 > 3 > 1
         */
    } while ( (n >>>= 1) != 0);
    return removed;
}

真正的清理逻辑在expungeStaleEntry()中,源码如下:

/*
删除过期的元素:占用下标,但是ThreadLocal实例已经被回收的元素。
 */
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 清理当前Entry
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    // 继续往后寻找,直到遇到null结束
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            // 再次发现过期元素,清理掉
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 处理重新哈希的逻辑
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

清理时,并不是只清理掉当前Entry就结束了,而是会往后环形的继续寻找过期的Entry,只要找到了就清理,直到遇到tab[i]==null就结束,清理的过程中还会对元素做一个rehash的操作。

(2)get()的清理逻辑

线程调用ThreadLocal.get()时,会从ThreadLocalMap.getEntry(this)去查找,源码如下:

/*
通过Key获取Entry
 */
private Entry getEntry(ThreadLocal<?> key) {
    // 计算下标
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key) {
        // 如果对应下标节点不为null,且Key相等,则命中直接返回
        return e;
    } else {
        /*
        否则有两种情况:
        1.Key不存在。
        2.哈希冲突了,需要向后环形查找。
         */
        return getEntryAfterMiss(key, i, e);
    }
}

如果命中则直接返回,如果没有命中则可能是哈希冲突了、或者Key不存在/已被回收,接着调用getEntryAfterMiss()查找,这里也会进行过期节点的清理:

/*
无法直接命中的查找逻辑
 */
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {// e==null说明Key不存在,直接返回null
        ThreadLocal<?> k = e.get();
        if (k == key)
            // 找到了,说明是哈希冲突
            return e;
        if (k == null)
            // Key存在,但是过期了,需要清理掉,并且返回null
            expungeStaleEntry(i);
        else
            // 向后环形查找
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

(3)remove()的清理逻辑

线程调用ThreadLocal.remove()本身就是清理当前节点的,但是为了避免发生「内存泄漏」,ThreadLocal还会检查容器中是否还有其他过期节点,如果发现也会一并清理,主要逻辑在ThreadLocalMap.remove()中:

// 通过Key删除Entry
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    // 计算下标
    int i = key.threadLocalHashCode & (len-1);
    /*
    删除也是一样,由于存在哈希冲突,不能直接定位到下标后直接删除。
    删除前需要确认Key是否相等,如果不等需要往后环形查找。
     */
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            /*
            找到了就清理掉。
            这里并没有直接清理,而是将Key的Reference引用清空了,
            然后再调用expungeStaleEntry()清理过期元素。
            顺便还可以清理后续节点。
             */
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

4、内存泄漏案例

(1)非静态变量

public class Test2 {

    // byte数组大小 50M
    private final static Integer size = 50 * 1024 * 1024;

    public static void main(String[] args) throws InterruptedException {
        MyTask myTask = new MyTask();
        myTask.getLocal();
        myTask.getByte();

        System.gc();
        TimeUnit.SECONDS.sleep(600);
    }

    static class MyTask{

        public void getLocal(){
            ThreadLocal<Object> threadLocal = new ThreadLocal<>();
            threadLocal.set(new byte[50 * 1024 * 1025]);
        }

        public void getByte(){
            byte[] bytes = new byte[size];
        }
    }
}

(2)线程池

可通过 jconsole 查看内存使用情况

public class Test1 {

    // byte数组大小 50M
    private final static Integer size = 50 * 1024 * 1024;

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(5);

        // 提交任务
        for (int i = 0; i < 10; i++) {
            executorService.execute(new MyTask());
        }

        // 确保任务执行完毕
        TimeUnit.SECONDS.sleep(3);

        System.gc();

        TimeUnit.SECONDS.sleep(600);
    }

    /**
     * 演示内存泄漏
     */
    static class MyTask implements Runnable {

        private static ThreadLocal<Object> threadLocal = new ThreadLocal<>();

        @Override
        public void run() {
            // 在任务中设置ThreadLocal的值
            threadLocal.set(new byte[size]);

            // 执行任务逻辑

            // 忘记清理ThreadLocal的值
        }
    }

    /**
     * 改进后
     */
    static class MyTask2 implements Runnable {

        private static ThreadLocal<Object> threadLocal = new ThreadLocal<>();

        @Override
        public void run() {
            // 在任务中设置ThreadLocal的值
            threadLocal.set(new byte[size]);

            try {
                // 执行任务逻辑
            } finally {
                // 清理ThreadLocal的值
                threadLocal.remove();
            }
        }
    }
}

五、InheritableThreadLocal

1、继承性

同一个ThreadLocal变量在父线程中被设置值后,在子线程中是获取不到的。(threadLocals中为当前调用线程对应的本地变量,所以二者自然是不能共享的),所以 ThreadLocal 不支持继承性

public class ThreadLocalTest {

    public static void main(String[] args) {

        ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
        threadLocal.set(111);

        new Thread(() -> System.out.println(threadLocal.get())).start();
    }
}

上面这理所当然的取不出来值,打印结果为null,因为子线程的threadLocals并没有存主线程的值,而要想解决这个问题,就需要认识一个新的工具 InheritableThreadLocal

2、API

这个类继承了 ThreadLocal 来提供从父线程到子线程的值的继承:当创建子线程时,子线程接收父线程有值的所有可继承的线程本地变量的初始值。通常,子进程的值与父进程的值是相同的;但是,通过重写这个类中的 childValue 方法,子对象的值可以成为父对象的任意函数。

当在变量中维护的每个线程属性(例如,用户ID、事务ID)必须自动传输给创建的任何子线程时,可继承的线程本地变量优先于普通线程本地变量

InheritableThreadLocal 的使用方法基本和 ThreadLocal 相同,同样可以使用 2 种构造方法

@Test
public void test6(){

    //InheritableThreadLocal<Integer> inheritableThreadLocal = new InheritableThreadLocal<>();
    ThreadLocal<Integer> inheritableThreadLocal = InheritableThreadLocal.withInitial(() -> 1);

    inheritableThreadLocal.set(1);

    new Thread(()->{
        System.out.println(inheritableThreadLocal.get());
    }).start();
}

3、源码

之前 ThreadLocalMap 中说到 Thread 中有以下 2 个成员变量:

ThreadLocal.ThreadLocalMap threadLocals = null;

ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

ThreadLocal 操作的是threadLocals,而 InheritableThreadLocal 操作的是 inheritableThreadLocals 变量。

InheritableThreadLocal 类源码如下:

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    /**
     * 为这个可继承的线程局部变量计算子线程的初始值,作为子线程创建时父线程值的函数。在子线程启动之前,从父线程内部调用此方法。
     * 此方法仅返回其输入参数,如果需要不同的行为,则应该重写此方法。
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

从上面代码可以看出,InheritableThreadLocal 类继承了 ThreadLocal 类,并重写了 childValuegetMapcreateMap 三个方法。其中 createMap 方法在被调用(当前线程调用set方法时得到的map为null的时候需要调用该方法)的时候,创建的是 inheritableThreadLocal 而不是threadLocals 。同理,getMap 方法在当前调用者线程调用 get 方法的时候返回的也不是threadLocals 而是 inheritableThreadLocal

下面我们看看重写的 childValue 方法在什么时候执行,怎样让子线程访问父线程的本地变量值。我们首先从 Thread 类的初始化方法 init 开始说起

public class Thread implements Runnable {

    ThreadLocal.ThreadLocalMap threadLocals = null;

    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

    /**
    * 参数inheritThreadLocals:如果为真,则从构造线程继承可继承的线程局部变量的初始值
    */
    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {

        // ...

        Thread parent = currentThread();    

        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

        // ...
    }
}

从上面的线程初始化代码中可以看出,线程初始化时会先判断父线程的 inheritableThreadLocals 是否为空,如果不为空则将父线程的 inheritableThreadLocals 复制到当前线程,createInheritedMap 方法源码如下:

public class ThreadLocal<T> {    

    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }
}

其实总的来说 inheritableThreadLocalsThreadLocal 的不同就只是将数据存储到 Thread 类的inheritableThreadLocals 变量下,并且在创建线程时,将父线程的对应的变量复制到子线程。

六、实战案例

通过过滤器获取请求头中存储的 所属园区id ,并将其存储到 ThreadLocal 中,以便在该次请求的任何地方可获取该园区id

过滤器

/**
 * 从请求的 header 中获取园区id
 */
public class FactoryZoneFilter implements Filter {

    public static final ThreadLocal<String> FACTORY_ZONE_ID = new ThreadLocal<>();

    public static final ThreadLocal<String> CURRENT_ID = new ThreadLocal<>();

    private Logger logger = LoggerFactory.getLogger(FactoryZoneFilter.class);

    public static String getFactoryZoneId(){
        return FACTORY_ZONE_ID.get();
    }

    public static String getCurrentId(){
        return CURRENT_ID.get();
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;

        String factoryZoneId = request.getHeader("appCategory");
        String currentId = request.getHeader("current-id");

        FACTORY_ZONE_ID.set(factoryZoneId);
        CURRENT_ID.set(currentId);

        logger.debug("当前园区id:{}",factoryZoneId);
        filterChain.doFilter(servletRequest,servletResponse);
        FACTORY_ZONE_ID.remove();
    }

    @Override
    public void destroy() {
        logger.debug("FactoryZoneFilter destory....");
        FACTORY_ZONE_ID.remove();
    }
}

配置过滤器

@Configuration
public class FilterConfig {

    @Bean
    public FilterRegistrationBean factoryZoneFilter(){
        FilterRegistrationBean factoryZoneFilter = new FilterRegistrationBean();
        factoryZoneFilter.setFilter(new FactoryZoneFilter());
        factoryZoneFilter.addUrlPatterns("/*");
        factoryZoneFilter.setName("factoryZoneFilter");
        factoryZoneFilter.setOrder(1);
        return factoryZoneFilter;
    }
}

  目录