ConcurrentHashMap实现原理
ConcurrentHashMap是Java 5中支持高并发、高吞吐量的线程安全HashMap实现。在这之前我对ConcurrentHashMap只有一些肤浅的理解,仅知道它采用了多个锁,大概也足够了。但是在经过一次惨痛的面试经历之后,我觉得必须深入研究它的实现。面试中被问到读是否要加锁,因为读写会发生冲突,我说必须要加锁,我和面试官也因此发生了冲突,结果可想而知。还是闲话少说,通过仔细阅读源代码,现在总算理解ConcurrentHashMap实现机制了,其实现之精巧,令人叹服,与大家共享之。
实现原理
锁分离 (Lock Striping)
ConcurrentHashMap允许多个修改操作并发进行,其关键在于使用了锁分离技术。它使用了多个锁来控制对hash表的不同部分进行的修改。ConcurrentHashMap内部使用段(Segment)来表示这些不同的部分,每个段其实就是一个小的hash table,它们有自己的锁。只要多个修改操作发生在不同的段上,它们就可以并发进行。
有些方法需要跨段,比如size()和containsValue(),它们可能需要锁定整个表而不仅仅是某个段,这需要按顺序锁定所有段,操作完毕后,又按顺序释放所有段的锁。这里“按顺序”是很重要的,否则极有可能出现死锁。在ConcurrentHashMap内部,段数组是final的,并且其成员变量实际上也是final的,但是,仅仅是将数组声明为final的并不保证数组成员也是final的,这需要实现上的保证。这可以确保不会出现死锁,因为获得锁的顺序是固定的。不变性在多线程编程中占有很重要的地位,下面还要谈到。
Java代码
/**
 * The segments, each of which is a specialized hash table.
 *
 * <p>Only the array reference is declared {@code final}; the declaration by
 * itself does not make the array elements immutable.
 */
final Segment<K,V>[] segments;
static final class HashEntry<K,V> { final K key; final int hash; volatile V value; final HashEntry<K,V> next; } private static int hash(int h) { // Spread bits to regularize both segment and index locations, // using variant of single-word Wang/Jenkins hash. h += (h << 15) ^ 0xffffcd7d; h ^= (h >>> 10); h += (h << 3); h ^= (h >>> 6); h += (h << 2) + (h << 14); return h ^ (h >>> 16); } final Segment<K,V> segmentFor(int hash) { return segments[(hash >>> segmentShift) & segmentMask]; } public class ConcurrentHashMap<K, V> extends AbstractMap<K, V> implements ConcurrentMap<K, V>, Serializable { /** * Mask value for indexing into segments. The upper bits of a * key's hash code are used to choose the segment. */ final int segmentMask; /** * Shift value for indexing within segments. */ final int segmentShift; /** * The segments, each of which is a specialized hash table */ final Segment<K,V>[] segments; } static final class Segment<K,V> extends ReentrantLock implements Serializable { private static final long serialVersionUID = 2249069246763182397L; /** * The number of elements in this segment's region. */ transient volatile int count; /** * Number of updates that alter the size of the table. This is * used during bulk-read methods to make sure they see a * consistent snapshot: If modCounts change during a traversal * of segments computing size or checking containsValue, then * we might have an inconsistent view of state so (usually) * must retry. */ transient int modCount; /** * The table is rehashed when its size exceeds this threshold. * (The value of this field is always <tt>(int)(capacity * * loadFactor)</tt>.) */ transient int threshold; /** * The per-segment table. */ transient volatile HashEntry<K,V>[] table; /** * The load factor for the hash table. Even though this value * is same for all segments, it is replicated to avoid needing * links to outer object. 
* @serial */ final float loadFactor; } public V remove(Object key) { hash = hash(key.hashCode()); return segmentFor(hash).remove(key, hash, null); } V remove(Object key, int hash, Object value) { lock(); try { int c = count - 1; HashEntry<K,V>[] tab = table; int index = hash & (tab.length - 1); HashEntry<K,V> first = tab[index]; HashEntry<K,V> e = first; while (e != null && (e.hash != hash || !key.equals(e.key))) e = e.next; V oldValue = null; if (e != null) { V v = e.value; if (value == null || value.equals(v)) { oldValue = v; // All entries following removed node can stay // in list, but all preceding ones need to be // cloned. ++modCount; HashEntry<K,V> newFirst = e.next; for (HashEntry<K,V> p = first; p != e; p = p.next) newFirst = new HashEntry<K,V>(p.key, p.hash, newFirst, p.value); tab[index] = newFirst; count = c; // write-volatile } } return oldValue; } finally { unlock(); } } V put(K key, int hash, V value, boolean onlyIfAbsent) { lock(); try { int c = count; if (c++ > threshold) // ensure capacity rehash(); HashEntry<K,V>[] tab = table; int index = hash & (tab.length - 1); HashEntry<K,V> first = tab[index]; HashEntry<K,V> e = first; while (e != null && (e.hash != hash || !key.equals(e.key))) e = e.next; V oldValue; if (e != null) { oldValue = e.value; if (!onlyIfAbsent) e.value = value; } else { oldValue = null; ++modCount; tab[index] = new HashEntry<K,V>(key, hash, first, value); count = c; // write-volatile } return oldValue; } finally { unlock(); } } V get(Object key, int hash) { if (count != 0) { // read-volatile HashEntry<K,V> e = getFirst(hash); while (e != null) { if (e.hash == hash && key.equals(e.key)) { V v = e.value; if (v != null) return v; return readValueUnderLock(e); // recheck } e = e.next; } } return null; } V readValueUnderLock(HashEntry<K,V> e) { lock(); try { return e.value; } finally { unlock(); } } boolean containsKey(Object key, int hash) { if (count != 0) { // read-volatile HashEntry<K,V> e = getFirst(hash); while (e != null) { if 
(e.hash == hash && key.equals(e.key)) return true; e = e.next; } } return false; } public int size() { final Segment<K,V>[] segments = this.segments; long sum = 0; long check = 0; int[] mc = new int[segments.length]; // Try a few times to get accurate count. On failure due to // continuous async changes in table, resort to locking. for (int k = 0; k < RETRIES_BEFORE_LOCK; ++k) { check = 0; sum = 0; int mcsum = 0; for (int i = 0; i < segments.length; ++i) { sum += segments[i].count; mcsum += mc[i] = segments[i].modCount; } if (mcsum != 0) { for (int i = 0; i < segments.length; ++i) { check += segments[i].count; if (mc[i] != segments[i].modCount) { check = -1; // force retry break; } } } if (check == sum) break; } if (check != sum) { // Resort to locking all segments sum = 0; for (int i = 0; i < segments.length; ++i) segments[i].lock(); for (int i = 0; i < segments.length; ++i) sum += segments[i].count; for (int i = 0; i < segments.length; ++i) segments[i].unlock(); } if (sum > Integer.MAX_VALUE) return Integer.MAX_VALUE; else return (int)sum; } public boolean containsValue(Object value) { if (value == null) throw new NullPointerException(); // See explanation of modCount use above final Segment<K,V>[] segments = this.segments; int[] mc = new int[segments.length]; // Try a few times without locking for (int k = 0; k < RETRIES_BEFORE_LOCK; ++k) { int sum = 0; int mcsum = 0; for (int i = 0; i < segments.length; ++i) { int c = segments[i].count; mcsum += mc[i] = segments[i].modCount; if (segments[i].containsValue(value)) return true; } boolean cleanSweep = true; if (mcsum != 0) { for (int i = 0; i < segments.length; ++i) { int c = segments[i].count; if (mc[i] != segments[i].modCount) { cleanSweep = false; break; } } } if (cleanSweep) return false; } // Resort to locking all segments for (int i = 0; i < segments.length; ++i) segments[i].lock(); boolean found = false; try { for (int i = 0; i < segments.length; ++i) { if (segments[i].containsValue(value)) { found = true; 
break; } } } finally { for (int i = 0; i < segments.length; ++i) segments[i].unlock(); } return found; } abstract class Hash_Iterator{ int nextSegmentIndex; int nextTableIndex; HashEntry<K,V>[] currentTable; HashEntry<K, V> nextEntry; HashEntry<K, V> lastReturned; } final void advance() { if (nextEntry != null && (nextEntry = nextEntry.next) != null) return; while (nextTableIndex >= 0) { if ( (nextEntry = currentTable[nextTableIndex--]) != null) return; } while (nextSegmentIndex >= 0) { Segment<K,V> seg = segments[nextSegmentIndex--]; if (seg.count != 0) { currentTable = seg.table; for (int j = currentTable.length - 1; j >= 0; --j) { if ( (nextEntry = currentTable[j]) != null) { nextTableIndex = j - 1; return; } } } } }