`

多线程备忘笔记

阅读更多

整理《多处理器编程的艺术》(The Art of Multiprocessor Programming)附录A中提到的多线程基本问题,分别给出 Java、C# 和 C(pthreads)的实现,涵盖:线程创建、管程、线程局部对象,以及生产者-消费者问题的解决方案(仅供参考)。

 

一、C#版,用VS2008测试。

 

 

using System;
//using System.Collections.Generic;
//using System.Linq;
//using System.Text;
using System.Threading;

//see 《多处理器编程的艺术》附录A
//The Art of Multiprocessor Programming
namespace t1
{
    /// <summary>
    /// Bounded blocking FIFO queue backed by a circular array and guarded by a
    /// monitor. <see cref="Enq"/> blocks while the queue is full and
    /// <see cref="Deq"/> blocks while it is empty.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    class Queue<T>
    {
        // Private lock object: locking on "this" would let any outside code that
        // locks the queue instance interfere with the Wait/Pulse protocol below.
        private readonly object gate = new object();
        private readonly T[] call;
        private int head;  // index of the next slot to read
        private int tail;  // index of the next slot to write
        private int count; // number of elements currently stored

        /// <summary>Creates an empty queue holding at most <paramref name="capacity"/> elements.</summary>
        /// <param name="capacity">Maximum number of elements; must be positive.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="capacity"/> is zero or negative.</exception>
        public Queue(int capacity)
        {
            // A zero-capacity queue would block every Enq and every Deq forever.
            if (capacity <= 0)
            {
                throw new ArgumentOutOfRangeException("capacity", "capacity must be positive");
            }
            call = new T[capacity];
        }

        /// <summary>Appends <paramref name="x"/>, blocking while the queue is full.</summary>
        /// <param name="x">Element to append.</param>
        public void Enq(T x)
        {
            lock (gate)
            {
                while (count == call.Length)
                {
                    Monitor.Wait(gate); // queue full: wait for a Deq
                }
                call[tail] = x;
                // Wrap the index explicitly; an ever-growing counter would eventually
                // overflow into a negative array index.
                tail = (tail + 1) % call.Length;
                count++;
                // Enqueuers and dequeuers share one monitor, so Pulse could wake a
                // thread of the wrong kind and lose the wakeup (deadlock). PulseAll
                // lets every waiter re-check its own condition.
                Monitor.PulseAll(gate);
            }
        }

        /// <summary>Removes and returns the oldest element, blocking while the queue is empty.</summary>
        /// <returns>The element that has been in the queue the longest.</returns>
        public T Deq()
        {
            lock (gate)
            {
                while (count == 0)
                {
                    Monitor.Wait(gate); // queue empty: wait for an Enq
                }
                T y = call[head];
                call[head] = default(T); // do not keep the dequeued object reachable
                head = (head + 1) % call.Length;
                count--;
                Monitor.PulseAll(gate); // wake every waiter (see Enq)
                return y;
            }
        }
    }

    // Thread-local objects: a [ThreadStatic] field gives every thread its own
    // copy, which is used here to hand out a unique per-thread identifier.
    /// <summary>
    /// Assigns each thread a distinct id (0, 1, 2, ...) in the order in which
    /// threads first call <see cref="get"/>; later calls return the same id.
    /// </summary>
    class ThreadID
    {
        // Per-thread cache of (id + 1); the default 0 means "no id assigned yet".
        [ThreadStatic] static int t_slot;
        // Highest (id + 1) handed out so far, across all threads.
        static int s_lastIssued;

        /// <summary>Returns the calling thread's id, assigning the next one on first use.</summary>
        public static int get()
        {
            int cached = t_slot;
            if (cached != 0)
            {
                // This thread already received an id on an earlier call.
                return cached - 1;
            }
            // Atomic increment: several threads may arrive here concurrently.
            cached = Interlocked.Increment(ref s_lastIssued);
            t_slot = cached;
            return cached - 1;
        }
    }

    // Shared counter protected by a critical section.
    /// <summary>
    /// Shared integer counter whose read-and-increment runs as one critical section.
    /// </summary>
    class Counter
    {
        // Dedicated lock object; lock(this) would expose the monitor to any code
        // that holds a reference to the counter.
        private readonly object gate = new object();
        private int value;

        /// <summary>Creates a counter whose first <see cref="GetAndIncrement"/> returns <paramref name="i"/>.</summary>
        /// <param name="i">Initial value.</param>
        public Counter(int i)
        {
            value = i;
        }

        /// <summary>Increments the counter and returns the value it held before the increment.</summary>
        /// <returns>The pre-increment value.</returns>
        public int GetAndIncrement()
        {
            lock (gate)
            {
                return value++;
            }
        }
    }

    // Test entry point.
    /// <summary>
    /// Runs the demos in order: thread creation, a locked counter combined
    /// with thread-local ids, and producer-consumer over a bounded queue.
    /// </summary>
    class Program
    {
        /// <summary>Named thread body used by <see cref="test1"/>.</summary>
        static void HelloWorld()
        {
            Console.WriteLine("Hello World");
        }

        // Demo 1: create, start and join a thread built from an anonymous
        // method, then one built from a named method.
        static void test1()
        {
            ThreadStart hello = new ThreadStart(delegate()
            {
                Console.WriteLine("Hello World");
            });
            Thread thread = new Thread(hello);
            thread.Start();
            thread.Join();
            thread = new Thread(new ThreadStart(HelloWorld));
            thread.Start();
            thread.Join();
        }

        // Demo 2: 8 threads each print a message, their thread-local id (twice,
        // to show it is stable) and a ticket from a shared locked counter.
        static void test2()
        {
            Counter counter = new Counter(0);
            Thread[] thread = new Thread[8];
            for (int i = 0; i < thread.Length; i++)
            {
                // Declared inside the loop so each anonymous method captures its own copy.
                String message = "Hello world from thread" + i;
                ThreadStart hello = delegate() {
                    Console.WriteLine(message);
                    Console.WriteLine(">>>ThreadID:" + ThreadID.get() + 
                        ", and get again:" + ThreadID.get());
                    Console.WriteLine(">>>>>locked counter:" + counter.GetAndIncrement());
                };
                thread[i] = new Thread(hello);
            }
            for (int i = 0; i < thread.Length; i++)
            {
                thread[i].Start();
            }
            // Wait for all threads to finish.
            for (int i = 0; i < thread.Length; i++)
            {
                thread[i].Join();
            }
            Console.WriteLine("done!");
        }

        // Demo 3: producer-consumer. One thread enqueues 0..19 and another
        // dequeues 20 values through a shared FIFO queue of capacity 10.
        static void test3()
        {
            Queue<int> queue = new Queue<int>(10);
            // One Random per thread (Random is not thread-safe), seeded explicitly
            // and differently. On .NET Framework the parameterless constructor
            // seeds from Environment.TickCount, so two instances created back to
            // back usually share a seed and produce identical sleep sequences.
            int seed = Environment.TickCount;
            Random randomProducer = new Random(seed);
            Random randomConsumer = new Random(unchecked(seed + 12345));
            ThreadStart producer = new ThreadStart(delegate()
            {
                Console.WriteLine("producer thread start");
                for (int i = 0; i < 20; i++)
                {
                    queue.Enq(i);
                    Console.WriteLine("<< Producer put:" + i);
                    Thread.Sleep(randomProducer.Next(100));
                }
            });
            ThreadStart consumer = new ThreadStart(delegate()
            {
                Console.WriteLine("consumer thread start");
                for (int i = 0; i < 20; i++)
                {
                    int value = queue.Deq();
                    Console.WriteLine(">> Consumer got:" + value);
                    Thread.Sleep(randomConsumer.Next(100));
                }
            });
            Thread[] thread = {new Thread(producer), new Thread(consumer)};
            for (int i = 0; i < thread.Length; i++)
            {
                thread[i].Start();
            }
            // Wait for both threads to finish.
            for (int i = 0; i < thread.Length; i++)
            {
                thread[i].Join();
            }
            Console.WriteLine("done!");
        }

        static void Main(string[] args)
        {
            test1();
            test2();
            test3();
            // Keep the console window open until a key is pressed.
            Console.ReadKey();
        }
    }
}

 

二、Pthreads 版(C 代码),在 Cygwin 下测试(未处理内存释放(free)问题)

 

 

 

/*
see The Art of Multiprocessor Programming

In cygwin:
> rm -f *.exe && gcc main.c && ./a.exe
*/
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <pthread.h>
#include <unistd.h>
#define NUM_THREADS 8 /* number of threads started by test2() */
#define QSIZE 10 /* capacity of the bounded queue (size of queue.buf) */
//-----------------------------------------------
/* Bounded FIFO of ints implemented as a monitor: one mutex plus two
   condition variables. head and tail are only ever incremented; the slot
   index is head/tail % QSIZE and the element count is tail - head. */
typedef struct {
	int buf[QSIZE];                     /* circular storage */
	long head, tail;                    /* total dequeues / enqueues so far */
	pthread_mutex_t *mutex;             /* guards buf, head and tail */
	pthread_cond_t *notFull, *notEmpty; /* waited on by enq when full / by deq when empty */
} queue;

/* Append item to q, blocking while q already holds QSIZE elements. */
void queue_enq(queue* q, int item) {
	pthread_mutex_t *lock = q->mutex;

	/* pthread_mutex_trylock could be used instead to return immediately */
	pthread_mutex_lock(lock);
	/* Re-test the predicate after every wakeup: pthread_cond_wait can return
	   spuriously, and it re-acquires lock before returning. */
	for (;;) {
		if (q->tail - q->head != QSIZE)
			break;
		pthread_cond_wait(q->notFull, lock);
	}
	q->buf[q->tail++ % QSIZE] = item;
	pthread_mutex_unlock(lock);

	/* Wake one waiting consumer (pthread_cond_broadcast would wake all). */
	pthread_cond_signal(q->notEmpty);
}

/* Remove and return the oldest element of q, blocking while q is empty. */
int queue_deq(queue* q) {
	int item;
	pthread_mutex_t *lock = q->mutex;

	pthread_mutex_lock(lock);
	for (;;) {
		if (q->tail != q->head)
			break;
		pthread_cond_wait(q->notEmpty, lock);
	}
	item = q->buf[q->head++ % QSIZE];
	pthread_mutex_unlock(lock);

	/* A slot was freed: wake one producer blocked in queue_enq. */
	pthread_cond_signal(q->notFull);
	return item;
}

/* Allocate an empty queue with its mutex and condition variables.
   Returns NULL if any allocation or initialization fails; in that case
   everything acquired so far is released again. */
queue* queue_init() {
	queue *q = (queue*)malloc(sizeof(queue));
	if (q == NULL)
		return NULL;
	q->head = 0;
	q->tail = 0;
	q->mutex = (pthread_mutex_t*)malloc(sizeof(pthread_mutex_t));
	q->notFull = (pthread_cond_t*)malloc(sizeof(pthread_cond_t));
	q->notEmpty = (pthread_cond_t*)malloc(sizeof(pthread_cond_t));
	if (q->mutex == NULL || q->notFull == NULL || q->notEmpty == NULL)
		goto fail_alloc;
	if (pthread_mutex_init(q->mutex, NULL) != 0)
		goto fail_alloc;
	if (pthread_cond_init(q->notFull, NULL) != 0)
		goto fail_mutex;
	if (pthread_cond_init(q->notEmpty, NULL) != 0)
		goto fail_not_full;
	return q;

	/* Unwind in reverse order of acquisition; free(NULL) is a no-op. */
fail_not_full:
	pthread_cond_destroy(q->notFull);
fail_mutex:
	pthread_mutex_destroy(q->mutex);
fail_alloc:
	free(q->notEmpty);
	free(q->notFull);
	free(q->mutex);
	free(q);
	return NULL;
}

//-----------------------------------------------
//Thread-Local Objects
/* Thread-specific-data key; each thread stores a pointer to its own id under it. */
pthread_key_t key;
/* Next id to hand out; incremented under mutex in threadID_get(). */
int counter;
/* Serializes access to counter. */
pthread_mutex_t mutex;

/* Initialize the thread-id facility; call once before any thread uses
   threadID_get(). free() is registered as the key's destructor, so the
   int that threadID_get() allocates for a thread is released when that
   thread exits. Exits the process if initialization fails. */
void threadID_init() {
	if (pthread_mutex_init(&mutex, NULL) != 0 ||
	    pthread_key_create(&key, free) != 0) {
		printf("threadID_init() error\n");
		exit(1);
	}
	counter = 0;
}

/* Return the calling thread's id: 0, 1, 2, ... in the order in which
   threads first call this function. The first call from a thread
   allocates the id and caches it under key; later calls return the
   cached value. Requires threadID_init(). Exits on allocation failure. */
int threadID_get() {
	int* id = (int*)pthread_getspecific(key);
	if (id == NULL) {
		id = (int *)malloc(sizeof(int));
		if (id == NULL) {
			printf("threadID_get() malloc error\n");
			exit(1);
		}
		pthread_mutex_lock(&mutex);
		*id = counter++;
		pthread_mutex_unlock(&mutex);
		/* The key's value is private to this thread, so no lock is needed here. */
		if (pthread_setspecific(key, id) != 0) {
			printf("threadID_get() pthread_setspecific error\n");
			exit(1);
		}
	}
	return *id;
}

//-----------------------------------------------
//Counter, locked
/* Integer counter whose read-modify-write is serialized by its own mutex. */
typedef struct {
	int value;               /* value the next getAndIncrement() returns */
	pthread_mutex_t *mutex;  /* guards value */
} locked_counter;

/* Shared counter used by hello(); created in test2(). */
locked_counter *lockedCounter;

/* Allocate a counter starting at 0. Returns NULL if an allocation or the
   mutex initialization fails; nothing is leaked in that case. */
locked_counter* locked_counter_init() {
	locked_counter* c = (locked_counter*)malloc(sizeof(locked_counter));
	if (c == NULL)
		return NULL;
	c->value = 0;
	c->mutex = (pthread_mutex_t*)malloc(sizeof(pthread_mutex_t));
	if (c->mutex == NULL || pthread_mutex_init(c->mutex, NULL) != 0) {
		free(c->mutex); /* free(NULL) is a no-op */
		free(c);
		return NULL;
	}
	return c;
}

/* Return c's current value and advance it by one, holding c->mutex
   for the whole read-modify-write. */
int getAndIncrement(locked_counter* c) {
	int previous;

	pthread_mutex_lock(c->mutex);
	previous = c->value++;
	pthread_mutex_unlock(c->mutex);

	return previous;
}

//-----------------------------------------------
/* Thread start routines have the signature void* (*)(void*). */
/* Start routine for test1(): prints a greeting. Returns NULL explicitly;
   falling off the end of a function returning void* leaves the result
   undefined. */
void* helloworld(void *arg) {
	(void)arg; /* unused */
	printf("Hello World\n");
	return NULL;
}

/* Start routine for test2(). arg carries the thread's loop index as an
   integer packed into the void* (see test2). Prints the index, the
   thread-local id twice (to show it is stable) and a ticket from the
   shared lockedCounter. */
void* hello(void *arg) {
	int first, second;

	/* Round-trip through intptr_t: casting a pointer straight to int
	   truncates (and warns) where pointers are wider than int. */
	printf("Hello from thread %d\n", (int)(intptr_t)arg);
	/* Fix the call order; C leaves argument evaluation order unspecified. */
	first = threadID_get();
	second = threadID_get();
	printf(">>>ThreadID:%d, and get again:%d\n", first, second);
	printf(">>>>>locked counter:%d\n", getAndIncrement(lockedCounter));
	return NULL;
}

/* Producer thread: enqueues 0..19 into the queue passed as arg, sleeping a
   random 0-99 ms after each item. */
void* producer(void *arg) {
	int i;
	queue *q = (queue*)arg;
	/* Per-thread generator state for rand_r(). rand()/srand() share one
	   global state that need not be thread-safe, so the consumer's srand()
	   would reseed (and race with) this thread's generator. */
	unsigned int seed = (unsigned int)time(NULL);

	printf("producer thread start\n");
	for (i = 0; i < 20; i++)
	{
		queue_enq(q, i);
		printf("<< Producer put:%d\n", i);
		/* 1000 * 1000 microseconds == 1 second */
		usleep((rand_r(&seed) % 100) * 1000);
	}
	return NULL;
}

/* Consumer thread: dequeues 20 items from the queue passed as arg and
   prints each one, sleeping a random 0-99 ms after each item. */
void* consumer(void *arg) {
	int i;
	queue *q = (queue*)arg;
	/* Own generator state with a different seed than the producer (see
	   producer() for why rand_r() is used instead of rand()/srand()). */
	unsigned int seed = (unsigned int)time(NULL) + 12345u;

	printf("consumer thread start\n");
	for (i = 0; i < 20; i++)
	{
		int value = queue_deq(q);
		/* Print the dequeued value, not the loop counter. */
		printf(">> Consumer got:%d\n", value);
		usleep((rand_r(&seed) % 100) * 1000);
	}
	return NULL;
}

//-----------------------------------------------

/* Demo 1: start one thread running helloworld() and wait for it.
   Exits the process if the thread cannot be created. */
void test1() {
	pthread_t tid;
	int rc = pthread_create(&tid, NULL, helloworld, NULL);

	if (rc != 0) {
		printf("pthread_create() error\n");
		exit(1);
	}
	pthread_join(tid, NULL);
}

/* Demo 2: start NUM_THREADS threads running hello(), each receiving its
   loop index, and wait for all of them. Sets up the shared lockedCounter
   and the thread-id key first. Exits the process on any setup failure. */
void test2() {
	pthread_t thread[NUM_THREADS];
	int i;

	lockedCounter = locked_counter_init();
	if (lockedCounter == NULL) {
		printf("locked_counter_init() error\n");
		exit(1);
	}

	threadID_init();

	for (i = 0; i < NUM_THREADS; i++) {
		/* Pack the index into the void* via intptr_t; casting an int
		   directly to a wider pointer type is not portable. */
		if (pthread_create(&thread[i], NULL, hello, (void *)(intptr_t)i) != 0) {
			printf("pthread_create() error\n");
			exit(1);
		}
	}
	for (i = 0; i < NUM_THREADS; i++) {
		pthread_join(thread[i], NULL);
	}
	printf("done!\n");
}

/* Demo 3: producer-consumer. One producer and one consumer thread share a
   bounded queue; waits for both. Exits the process on setup failure.
   The queue is not freed afterwards. */
void test3() {
	pthread_t threadProducer, threadConsumer;
	queue *q = queue_init();

	if (q == NULL) {
		printf("queue_init() error\n");
		exit(1);
	}

	/* The consumer is only created once the producer was created successfully. */
	if (pthread_create(&threadProducer, NULL, producer, (void *)q) != 0 ||
	    pthread_create(&threadConsumer, NULL, consumer, (void *)q) != 0) {
		printf("pthread_create() error\n");
		exit(1);
	}

	pthread_join(threadProducer, NULL);
	pthread_join(threadConsumer, NULL);
	printf("done!\n");
}

/* Run the three demos in sequence. */
int main(void) {
	test1();
	test2();
	test3();
	return 0;
}

 

三、Java 版,在 Eclipse + JDK 6 下测试(基于 java.util.concurrent 锁实现的带锁循环队列 LockedQueue 未经测试)

 

 

 

 

//see 《多处理器编程的艺术》附录A
//The Art of Multiprocessor Programming

import java.util.Random;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Bounded blocking FIFO queue built on the object's intrinsic lock
// (synchronized + wait/notifyAll).
class CallQueue<T> {
	private int head = 0; // next slot to read
	private int tail = 0; // next slot to write
	// Number of stored elements. head and tail wrap at calls.length, so
	// "tail - head" can never reach calls.length, and tail == head holds both
	// when the queue is empty and when it is full; only a count tells them apart.
	private int count = 0;
	private final T[] calls;

	/**
	 * @param capacity maximum number of elements; must be positive
	 * @throws IllegalArgumentException if capacity is zero or negative
	 */
	@SuppressWarnings("unchecked")
	public CallQueue(int capacity) {
		if (capacity <= 0) {
			throw new IllegalArgumentException("capacity must be positive: " + capacity);
		}
		calls = (T[]) new Object[capacity];
	}

	/**
	 * Appends x, blocking while the queue is full. If the thread is
	 * interrupted while waiting it keeps waiting, and its interrupt status is
	 * restored before returning so callers can still observe it.
	 */
	public synchronized void enq(final T x) {
		boolean interrupted = false;
		try {
			while (count == calls.length) {
				try {
					wait(); // wait until not full
				} catch (InterruptedException e) {
					interrupted = true; // remember it; re-asserted in finally
				}
			}
			calls[tail] = x;
			if (++tail == calls.length) {
				tail = 0;
			}
			++count;
			// Producers and consumers wait on the same monitor: wake them all and
			// let each re-check its own condition.
			notifyAll();
		} finally {
			if (interrupted) {
				Thread.currentThread().interrupt();
			}
		}
	}

	/**
	 * Removes and returns the oldest element, blocking while the queue is
	 * empty. Interruption is handled as in {@link #enq}.
	 */
	public synchronized T deq() {
		boolean interrupted = false;
		try {
			while (count == 0) {
				try {
					wait(); // wait until not empty
				} catch (InterruptedException e) {
					interrupted = true;
				}
			}
			T x = calls[head];
			calls[head] = null; // do not keep the dequeued object reachable
			if (++head == calls.length) {
				head = 0;
			}
			--count;
			notifyAll();
			return x;
		} finally {
			if (interrupted) {
				Thread.currentThread().interrupt();
			}
		}
	}
}

// Bounded blocking FIFO queue using an explicit ReentrantLock with two
// Conditions: producers wait on notFull and consumers on notEmpty, so a single
// signal() always wakes a thread of the right kind.
class LockedQueue<T> {
	private final Lock lock = new ReentrantLock();
	private final Condition notFull = lock.newCondition();
	private final Condition notEmpty = lock.newCondition();
	private final T[] items;
	private int count = 0; // number of stored elements
	private int head = 0; // next slot to read
	private int tail = 0; // next slot to write

	/**
	 * @param capacity maximum number of elements; must be positive
	 * @throws IllegalArgumentException if capacity is zero or negative
	 */
	@SuppressWarnings("unchecked")
	public LockedQueue(final int capacity) {
		if (capacity <= 0) {
			throw new IllegalArgumentException("capacity must be positive: " + capacity);
		}
		items = (T[]) new Object[capacity];
	}

	/**
	 * Appends x, blocking while the queue is full. Waiting is uninterruptible,
	 * so an interrupt cannot cause x to be dropped; the thread's interrupt
	 * status is still set when this method returns.
	 */
	public void enq(T x) {
		lock.lock(); // tryLock() would fail fast instead of blocking
		try {
			while (count == items.length) {
				notFull.awaitUninterruptibly(); // wait until not full
			}
			items[tail] = x;
			if (++tail == items.length) {
				tail = 0;
			}
			++count;
			notEmpty.signal(); // the queue is now non-empty
		} finally {
			lock.unlock();
		}
	}

	/**
	 * Removes and returns the oldest element, blocking while the queue is
	 * empty. Never returns a spurious null because of an interrupt; the
	 * interrupt status is preserved as in {@link #enq}.
	 */
	public T deq() {
		lock.lock();
		try {
			while (count == 0) {
				notEmpty.awaitUninterruptibly(); // wait until not empty
			}
			T x = items[head];
			items[head] = null; // do not keep the dequeued object reachable
			if (++head == items.length) {
				head = 0;
			}
			--count;
			notFull.signal(); // a slot is now free
			return x;
		} finally {
			lock.unlock();
		}
	}
}

// Thread-local objects: every thread gets a distinct, stable integer id
// (0, 1, 2, ... in the order in which threads first call get()).
class ThreadID {
	// Next id to hand out; only accessed inside claimNextID(), which is
	// synchronized on ThreadID.class, so the increment is atomic.
	private static int nextID = 0;

	private static synchronized int claimNextID() {
		return nextID++;
	}

	// One shared ThreadLocal instance; the value it returns is per-thread.
	// initialValue() runs on a thread's first get() (when no value was set).
	private static final ThreadLocal<Integer> threadID = new ThreadLocal<Integer>() {
		@Override
		protected Integer initialValue() {
			return claimNextID();
		}
	};

	/** Returns the calling thread's id, assigning the next free one on first use. */
	public static int get() {
		return threadID.get();
	}

	/**
	 * Overrides the calling thread's id. Normally unnecessary; it does not
	 * change the ids handed out to other threads.
	 */
	public static void set(int index) {
		threadID.set(index);
	}
}

// Shared counter; getAndIncrement is a critical section guarded by the
// counter's own intrinsic lock.
class Counter {
	private int value;

	/** Creates a counter whose first getAndIncrement() returns i. */
	public Counter(int i) {
		this.value = i;
	}

	/** Increments the counter and returns the value it held before the increment. */
	public synchronized int getAndIncrement() {
		int before = value;
		value = before + 1;
		return before;
	}
}

// A named class implementing Runnable, as opposed to an anonymous class.
class HelloWorld implements Runnable {
	String message; // text printed by run()

	public HelloWorld(String m) {
		this.message = m;
	}

	/** Prints the message supplied at construction. */
	@Override
	public void run() {
		System.out.println(message);
	}
}

// Test entry point: runs the three demos in sequence.
public class Test {
	// Waits for t to finish; an interrupt while waiting is ignored.
	private static void joinQuietly(Thread t) {
		try {
			t.join();
		} catch (InterruptedException e) {
			// deliberately ignored
		}
	}

	// Sleeps for millis milliseconds; an interrupt ends the sleep early and is ignored.
	private static void sleepQuietly(long millis) {
		try {
			Thread.sleep(millis);
		} catch (InterruptedException e) {
			// deliberately ignored
		}
	}

	// Demo 1: start and join a thread built from a named Runnable, then one
	// built from an anonymous Runnable.
	public static void test1() {
		Thread thread = new Thread(new HelloWorld("Hello World from thread"));
		thread.start();
		joinQuietly(thread); // block until the thread returns

		final String message = "Hello World from thread";
		thread = new Thread(new Runnable() {
			public void run() {
				System.out.println(message);
			}
		});
		thread.start();
		joinQuietly(thread);
	}

	// Demo 2: 8 threads each print a message, their thread-local id (twice,
	// to show it is stable) and a ticket from a shared locked counter.
	public static void test2() {
		final Counter counter = new Counter(0);
		Thread[] thread = new Thread[8];
		for (int i = 0; i < thread.length; i++) {
			final String message = "Hello world from thread" + i;
			thread[i] = new Thread(new Runnable() {
				public void run() {
					System.out.println(message);
					System.out.println(">>>ThreadID:" + ThreadID.get()
							+ ", and get again:" + ThreadID.get());
					System.out.println(">>>>>locked counter:"
							+ counter.getAndIncrement());
				}
			});
		}
		for (Thread t : thread) {
			t.start();
		}
		// Wait for every thread to finish.
		for (Thread t : thread) {
			joinQuietly(t);
		}
		System.out.println("done!");
	}

	// Demo 3: producer-consumer. Two threads share one bounded FIFO queue
	// of capacity 10; each moves 20 items.
	public static void test3() {
		final CallQueue<Integer> queue = new CallQueue<Integer>(10);
		Thread producer = new Thread(new Runnable() {
			public void run() {
				// Seeded from the clock.
				Random rand = new Random(System.currentTimeMillis());
				System.out.println("producer thread start");
				for (int i = 0; i < 20; i++) {
					queue.enq(i);
					System.out.println("<< Producer put:" + i);
					sleepQuietly(rand.nextInt(100));
				}
			}
		});
		Thread consumer = new Thread(new Runnable() {
			public void run() {
				// Offset seed so the consumer's delays differ from the producer's.
				Random rand = new Random(System.currentTimeMillis() + 12345);
				System.out.println("consumer thread start");
				for (int i = 0; i < 20; i++) {
					int value = queue.deq();
					System.out.println(">> Consumer got:" + value);
					sleepQuietly(rand.nextInt(100));
				}
			}
		});
		producer.start();
		consumer.start();
		joinQuietly(producer);
		joinQuietly(consumer);
		System.out.println("done!");
	}

	public static void main(String[] args) {
		test1();
		test2();
		test3();
	}
}

 

 

 

四、TODO:

(待续)

 

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics