使用JDK8computeIfAbsent遇到的问题

doMore 605 2023-02-18

java.util.Map#computeIfAbsent问题

JDK8 HashMap

问题1

HashMap 有两个 computeIfAbsent 调用,不同的 Key 但是 hashCode 相同,equals 不相等。

代码

// 测试代码
public class ComputeIfAbsentTest {

    public static void main(String[] args) {
        Map<Key, String> m = new HashMap<>();
        m.computeIfAbsent(new Key("firstKey"), k -> {
            m.computeIfAbsent(new Key("secondKey"), sk -> "secondKey"); 
            return "firstValue";
        });

        System.out.println("Map.size(): " + m.size()); // size == 2

		// Map.entrySet().toArray().length: 1
		// firstKey
        System.out.println("Map.entrySet().toArray().length: " + m.entrySet().toArray().length);
        for (Key k : m.keySet()) {
            System.out.println(k.value);
        }
    }

    private static class Key {
        private final String value;

        private Key(String val) {
            value = val;
        }

        @Override
        public int hashCode() {
            return 1; // 模拟 hashCode 相同的情况
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            Key other = (Key) obj;
            return Objects.equals(value, other.value);
        }
    }

}

源码分析

// java.util.HashMap#computeIfAbsent
// jdk 8
public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
        // ......
        V v = mappingFunction.apply(key);
        if (v == null) {  // 如果旧的值和新值都为null 直接返回
            return null;
        } else if (old != null) { // 替换旧的值
            old.value = v;
            afterNodeAccess(old);
            return v;
        }
        else if (t != null) // 如果是 红黑树 节点 进行特殊操作
            t.putTreeVal(this, tab, hash, key, v);
        else {
            // 对 相应 数组位置直接更换
            // 上述问题出现的关键就是该行代码,两个 Key 的HashCode相同时,此处已经有值,会出现覆盖情况
            tab[i] = newNode(hash, key, v, first);
            if (binCount >= TREEIFY_THRESHOLD - 1)
                treeifyBin(tab, hash);
        }
        ++modCount;
	    // SIZE +1 
        ++size;
        afterNodeInsertion(true);
        return v;
    }


// 实际上 jdk 9 已经修复问题,下面截取 jdk 17 的代码做对比
// jdk 17

public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
        // ......
    	// 此处先 保留一份 快照,如果 apply 操作 使值发生了变化,则抛出异常,该操作不被允许
   		// modCount 作用: 数量或者其他方式使结构发生变化的标记
        int mc = modCount;
        V v = mappingFunction.apply(key);
        if (mc != modCount) { throw new ConcurrentModificationException(); }
        if (v == null) {
            return null;
        } else if (old != null) {
            old.value = v;
            afterNodeAccess(old);
            return v;
        }
        else if (t != null)
            t.putTreeVal(this, tab, hash, key, v);
        else {
            tab[i] = newNode(hash, key, v, first);
            if (binCount >= TREEIFY_THRESHOLD - 1)
                treeifyBin(tab, hash);
        }
        modCount = mc + 1;
        ++size;
        afterNodeInsertion(true);
        return v;
    }


JDK8 ConcurrentHashMap

问题1

频繁加锁,性能问题

源码分析

// java.util.concurrent.ConcurrentHashMap#computeIfAbsent
// jdk 8
public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
        if (key == null || mappingFunction == null)
            throw new NullPointerException();
        int h = spread(key.hashCode());
        V val = null;
        int binCount = 0;
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            if (tab == null || (n = tab.length) == 0)
                tab = initTable();
            else if ((f = tabAt(tab, i = (n - 1) & h)) == null) {
                // key 不存在 且 hash 对应的位置还没值
                Node<K,V> r = new ReservationNode<K,V>();
                synchronized (r) {
                    // cas 操作设置值
                    if (casTabAt(tab, i, null, r)) {
                        binCount = 1;
                        Node<K,V> node = null;
                        try {
                            if ((val = mappingFunction.apply(key)) != null)
                                node = new Node<K,V>(h, key, val, null);
                        } finally {
                            setTabAt(tab, i, node);
                        }
                    }
                }
                if (binCount != 0)
                    break;
            }
            else if ((fh = f.hash) == MOVED)
                // 扩容
                tab = helpTransfer(tab, f);
            else {
                boolean added = false;
                // 可以看到,每次的查找都会加锁(synchronized),
                synchronized (f) {
                    if (tabAt(tab, i) == f) {
                        if (fh >= 0) {
                            // 在链表中查找 key
                        }
                        else if (f instanceof TreeBin) {
                            // 在红黑树中查找 key
                        }
                    }
                }
                // 树化判断
            }
        }
        // 增加数量操作
        return val;
    }

// 如果在 Java 8 的环境下使用 ConcurrentHashMap,一定要注意是否会并发对同一个 key 调用 computeIfAbsent,如果存在需要先尝试调用 get。
//       Map<String, String> m = new ConcurrentHashMap<>();
//       Object result = m.get("Key");
//       if (result == null){
//            result = m.computeIfAbsent("key",k->"value");
//       }



// jdk 17
public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
        if (key == null || mappingFunction == null)
            throw new NullPointerException();
        int h = spread(key.hashCode());
        V val = null;
        int binCount = 0;
        for (Node<K,V>[] tab = table;;) {
            // ......
            else if (fh == h    // 增加1: 对首个节点进行判断,如果是目标值直接返回
                     && ((fk = f.key) == key || (fk != null && key.equals(fk)))
                     && (fv = f.val) != null)
                return fv;
            else {
                boolean added = false;
                synchronized (f) {
                    if (tabAt(tab, i) == f) {
                        if (fh >= 0) {
                            // ......
                            if ((e = e.next) == null) {
                                    if ((val = mappingFunction.apply(key)) != null) {
                                        // 增加2: 如果节点类型是 ReservationNode,直接抛出异常
                                        if (pred.next != null)
                                            throw new IllegalStateException("Recursive update");
                                        // ......
                                    }
                                    break;
                                }
                        }
                        else if (f instanceof TreeBin) {
                            // ......
                        }
                        // 增加2: 如果节点类型是 ReservationNode,直接抛出异常
                        else if (f instanceof ReservationNode)
                            throw new IllegalStateException("Recursive update");
                    }
                }
                // ......
            }
        }
        // ......
        return val;
    }

相关测试资料

https://blog.csdn.net/wu_weijie/article/details/121899160