Skip to content

Commit 1ca8bc3

Browse files
committed
Add concurrent pool implementation.
1 parent fbaaf78 commit 1ca8bc3

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package org.learningconcurrency
2+
package ch3
3+
4+
5+
6+
import java.util.concurrent.ConcurrentHashMap
7+
import java.util.concurrent.atomic.AtomicReference
8+
import scala.annotation.tailrec
9+
10+
11+
12+
object LockFreePool {
13+
class Pool[T] {
14+
val parallelism = Runtime.getRuntime.availableProcessors * 32
15+
val buckets = new Array[AtomicReference[(List[T], Long)]](parallelism)
16+
for (i <- 0 until buckets.length)
17+
buckets(i) = new AtomicReference((Nil, 0L))
18+
19+
def add(x: T): Unit = {
20+
val i = (Thread.currentThread.getId * x.## % buckets.length).toInt
21+
@tailrec def retry() {
22+
val bucket = buckets(i)
23+
val v = bucket.get
24+
val (lst, stamp) = v
25+
val nlst = x :: lst
26+
val nstamp = stamp + 1
27+
val nv = (nlst, nstamp)
28+
if (!bucket.compareAndSet(v, nv)) retry()
29+
}
30+
retry()
31+
}
32+
33+
def get(): Option[T] = {
34+
val start = (Thread.currentThread.getId % buckets.length).toInt
35+
@tailrec def scan(witness: Long): Option[T] = {
36+
var i = (start + 1) % buckets.length
37+
var sum = 0L
38+
while (i != start) {
39+
val bucket = buckets(i)
40+
41+
@tailrec def retry(): Option[T] = {
42+
bucket.get match {
43+
case (Nil, stamp) =>
44+
sum += stamp
45+
None
46+
case v @ (lst, stamp) =>
47+
val nv = (lst.tail, stamp + 1)
48+
if (bucket.compareAndSet(v, nv)) Some(lst.head)
49+
else retry()
50+
}
51+
}
52+
retry() match {
53+
case Some(v) => return Some(v)
54+
case None =>
55+
}
56+
57+
i = (i + 1) % buckets.length
58+
}
59+
if (sum == witness) None
60+
else scan(sum)
61+
}
62+
scan(-1L)
63+
}
64+
}
65+
66+
def main(args: Array[String]) {
67+
val check = new ConcurrentHashMap[Int, Unit]()
68+
val pool = new Pool[Int]
69+
val p = 8
70+
val num = 1000000
71+
val inserters = for (i <- 0 until p) yield ch2.thread {
72+
for (j <- 0 until num) pool.add(i * num + j)
73+
}
74+
inserters.foreach(_.join())
75+
val removers = for (i <- 0 until p) yield ch2.thread {
76+
for (j <- 0 until num) {
77+
pool.get() match {
78+
case Some(v) => check.put(v, ())
79+
case None => sys.error("Should be non-empty.")
80+
}
81+
}
82+
}
83+
removers.foreach(_.join())
84+
for (i <- 0 until (num * p)) assert(check.containsKey(i))
85+
}
86+
}

0 commit comments

Comments
 (0)