use lang:asm;
use core:asm;
use core:sync;
use progvis:lang:cpp;
use progvis:lang:cpp:ptr;
use progvis:lang:cpp:impl;
use lang:bs:unsafe;


/**
 * Pintos-like semaphore.
 *
 * The sema_* functions below access these fields through raw offsets (a pointer followed by an
 * Int), so the order and types of the fields must not change.
 */
value semaphore {
	// Note: Layout here is important!
	// Implementation object. Null until 'sema_init' stores one; 'sema_destroy' resets it to null.
	SemaImpl? impl;
	// Copy of the semaphore value, kept up to date by 'sema_up'/'sema_down'.
	// 'sema_init' writes the initial value and 'sema_destroy' writes -1.
	progvis:program:MaybePositive count;

	init() {
		init {
			// Initializes to "nothing".
			count();
		}
	}
}

// Make sure 'sema' refers to valid memory that is large enough to hold an entire semaphore,
// i.e. one pointer (the implementation) followed by one Int (the count).
void checkSyncPtr(Ptr<semaphore> sema) {
	Size needed = sPtr + sInt;
	checkPtr(sema.base, sema.offset, needed.current);
}

// Initialize 'sema' with the initial value 'value'.
// Throws SyncError if 'value' is negative or if 'sema' already has an implementation attached.
void sema_init(Ptr<semaphore> sema, Int value) {
	if (value < 0)
		throw SyncError("Cannot initialize a semaphore with a negative value!");

	SemaImpl created(value.nat);
	checkSyncPtr(sema);

	// The implementation pointer comes first, the visible count right after it.
	Nat implOffset = sema.offset - sPtr.current * 2;
	Nat countOffset = implOffset + sPtr.current;

	if (sema.base.readPtr(implOffset).any)
		throw SyncError("Tried to initialize a semaphore multiple times!");

	sema.base.writePtr(implOffset, unsafe:RawPtr(created));
	sema.base.writeInt(countOffset, value);
}

// Detach the implementation from 'sema' (stores a null pointer) and write -1 to its count.
void sema_destroy(Ptr<semaphore> sema) {
	checkSyncPtr(sema);

	Nat implOffset = sema.offset - sPtr.current * 2;
	Nat countOffset = implOffset + sPtr.current;

	sema.base.writePtr(implOffset, unsafe:RawPtr());
	sema.base.writeInt(countOffset, -1);
}

// Increment 'sema', possibly waking a waiting thread.
// Throws SyncError if 'sema' has no implementation attached.
void sema_up(Ptr<semaphore> sema) : cppBarrier(release) {
	checkSyncPtr(sema);

	Nat implOffset = sema.offset - sPtr.current * 2;
	Nat countOffset = implOffset + sPtr.current;

	if (semaImpl = sema.base.readPtr(implOffset).asObject() as SemaImpl) {
		// Update the visible count before signalling the implementation.
		Int incremented = sema.base.readInt(countOffset) + 1;
		sema.base.writeInt(countOffset, incremented);
		semaImpl.up();
		return;
	}

	throw SyncError("Can not operate on an uninitialized semaphore!");
}

// Decrement 'sema', waiting inside the implementation until that is possible.
// Throws SyncError if 'sema' has no implementation attached.
void sema_down(Ptr<semaphore> sema) : cppBarrier(acquire) {
	checkSyncPtr(sema);

	Nat implOffset = sema.offset - sPtr.current * 2;
	Nat countOffset = implOffset + sPtr.current;

	if (semaImpl = sema.base.readPtr(implOffset).asObject() as SemaImpl) {
		// The visible count is only decremented after 'down' has returned.
		semaImpl.down(Variant(sema));
		Int decremented = sema.base.readInt(countOffset) - 1;
		sema.base.writeInt(countOffset, decremented);
		return;
	}

	throw SyncError("Can not operate on an uninitialized semaphore!");
}

/**
 * Pintos-like lock.
 *
 * The lock_* and cond_* functions access these fields through raw offsets (a pointer followed by a
 * 32-bit thread id), so the order and types of the fields must not change.
 */
value lock {
	// Note: Layout here is important!
	// Semaphore (initialized to 1 by 'lock_init') that implements the lock. Null until initialized.
	SemaImpl? impl;
	// Id of the thread currently holding the lock. 'lock_acquire' stores the caller's id,
	// 'lock_release' and 'cond_wait' store 0 while the lock is not held by the caller.
	progvis:program:ThreadId held_by;

	init() {
		init { held_by = progvis:program:ThreadId:noInit(); }
	}
}

// Make sure 'lock' refers to valid memory that is large enough to hold an entire lock,
// i.e. one pointer (the implementation) followed by the 'held_by' thread id.
void checkSyncPtr(Ptr<lock> lock) {
	Size needed = sPtr + sInt;
	checkPtr(lock.base, lock.offset, needed.current);
}

// Initialize 'lock' so that it can be acquired. The lock is backed by a semaphore with the
// initial value 1, and starts out not held by anyone.
// Throws SyncError if 'lock' already has an implementation attached.
void lock_init(Ptr<lock> lock) {
	SemaImpl s(1);

	checkSyncPtr(lock);

	Nat offset = lock.offset - sPtr.current * 2;
	if (lock.base.readPtr(offset).any)
		throw SyncError("Tried to initialize a lock multiple times!");

	lock.base.writePtr(offset, unsafe:RawPtr(s));
	// 'held_by' is a thread id and is read with readNat everywhere else; use the matching
	// Nat accessor. 0 is the "not held" marker used by lock_release and cond_wait.
	lock.base.writeNat(offset + sPtr.current, 0);
}

// Detach the implementation from 'lock' (stores a null pointer) and reset 'held_by' to the
// same value the default constructor uses (ThreadId:noInit).
void lock_destroy(Ptr<lock> lock) {
	checkSyncPtr(lock);

	Nat implOffset = lock.offset - sPtr.current * 2;
	Nat holderOffset = implOffset + sPtr.current;

	lock.base.writePtr(implOffset, unsafe:RawPtr());
	lock.base.writeNat(holderOffset, progvis:program:ThreadId:noInit.v);
}

// Acquire 'lock', waiting until it becomes available, and record the calling thread as holder.
// Throws SyncError on recursive acquisition or if 'lock' is uninitialized.
void lock_acquire(Ptr<lock> lock) : cppBarrier(acquire) {
	checkSyncPtr(lock);

	Nat implOffset = lock.offset - sPtr.current * 2;
	Nat holderOffset = implOffset + sPtr.current;

	if (semaImpl = lock.base.readPtr(implOffset).asObject() as SemaImpl) {
		Nat caller = progvis:program:findThisThreadId();

		// Locks are not recursive: acquiring one we already hold would deadlock.
		if (lock.base.readNat(holderOffset) == caller)
			throw SyncError("Trying to acquire a lock we're already holding!");

		semaImpl.down(Variant(lock));
		lock.base.writeNat(holderOffset, caller);
		return;
	}

	throw SyncError("Can not operate on an uninitialized lock!");
}

// Release 'lock', which must be held by the calling thread, and wake up a waiter if any.
// Throws SyncError if the caller does not hold the lock or if 'lock' is uninitialized.
void lock_release(Ptr<lock> lock) : cppBarrier(release) {
	checkSyncPtr(lock);

	Nat offset = lock.offset - sPtr.current * 2;

	var implRaw = lock.base.readPtr(offset);
	if (l = implRaw.asObject() as SemaImpl) {
		offset += sPtr.current;

		Nat myId = progvis:program:findThisThreadId();
		if (lock.base.readNat(offset) != myId)
			throw SyncError("Trying to release a lock we're not holding!");

		// 'held_by' is a thread id (read with readNat above); clear it with the matching Nat
		// accessor. 0 is the "not held" marker, as in cond_wait.
		lock.base.writeNat(offset, 0);

		l.up();
		return;
	}

	throw SyncError("Can not operate on an uninitialized lock!");
}


/**
 * Pintos-like condition variable.
 *
 * The cond_* functions access these fields through raw offsets (a pointer followed by an Int),
 * so the order and types of the fields must not change.
 */
value condition {
	// Note: Layout here is important!
	// Implementation object. Null until 'cond_init' stores one; 'cond_destroy' resets it to null.
	CondImpl? impl;
	// Number of waiting threads, updated from the implementation's 'waitingCount' by
	// cond_wait/cond_signal/cond_broadcast. 'cond_init' writes 0 and 'cond_destroy' writes -1.
	progvis:program:MaybePositive waiting;

	init() {
		init {
			// Initialized to "nothing".
			waiting();
		}
	}
}

// Make sure 'cond' refers to valid memory that is large enough to hold an entire condition
// variable, i.e. one pointer (the implementation) followed by one Int (the waiting count).
void checkSyncPtr(Ptr<condition> cond) {
	Size needed = sPtr + sInt;
	checkPtr(cond.base, cond.offset, needed.current);
}

// Initialize 'cond' with a fresh implementation and a waiting count of 0.
// Throws SyncError if 'cond' already has an implementation attached.
void cond_init(Ptr<condition> cond) {
	CondImpl created;
	checkSyncPtr(cond);

	Nat implOffset = cond.offset - sPtr.current * 2;
	Nat waitingOffset = implOffset + sPtr.current;

	if (cond.base.readPtr(implOffset).any)
		throw SyncError("Tried to initialize a condition variable multiple times!");

	cond.base.writePtr(implOffset, unsafe:RawPtr(created));
	cond.base.writeInt(waitingOffset, 0);
}

// Detach the implementation from 'cond' (stores a null pointer) and write -1 to its waiting count.
void cond_destroy(Ptr<condition> cond) {
	checkSyncPtr(cond);

	Nat implOffset = cond.offset - sPtr.current * 2;
	Nat waitingOffset = implOffset + sPtr.current;

	cond.base.writePtr(implOffset, unsafe:RawPtr());
	cond.base.writeInt(waitingOffset, -1);
}

// Wait on 'cond'. The calling thread must hold 'lock'. The lock's 'held_by' field is cleared while
// waiting and set back to the caller once 'CondImpl.wait' returns.
// Throws SyncError if the caller does not hold 'lock', or if either object is uninitialized.
void cond_wait(Ptr<condition> cond, Ptr<lock> lock) : cppBarrier(full) {
	checkSyncPtr(cond);
	checkSyncPtr(lock);

	// Offsets of the implementation pointers; the second field follows after sPtr.current bytes.
	Nat cOffset = cond.offset - sPtr.current * 2;
	Nat lOffset = lock.offset - sPtr.current * 2;

	Nat ourId = progvis:program:findThisThreadId();
	// NOTE(review): this reads 'held_by' before checking that the lock is initialized, so an
	// uninitialized lock reports this error rather than the "uninitialized" one below.
	if (lock.base.readNat(lOffset + sPtr.current) != ourId)
		throw SyncError("The lock passed to 'cond_wait' must be held by the current thread!");

	if (cImpl = cond.base.readPtr(cOffset).asObject() as CondImpl) {
		if (lImpl = lock.base.readPtr(lOffset).asObject() as SemaImpl) {
			// We're not holding the lock anymore.
			lock.base.writeNat(lOffset + sPtr.current, 0);

			// Publish the waiting count including this thread before blocking in 'wait'.
			cond.base.writeInt(cOffset + sPtr.current, cImpl.waitingCount + 1);
			// The lock's SemaImpl is passed along; presumably 'wait' releases and reacquires it.
			cImpl.wait(Variant(cond), Variant(lock), lImpl);

			// We're holding the lock now!
			lock.base.writeNat(lOffset + sPtr.current, ourId);
			return;
		}
	}

	throw SyncError("Can not operate on an uninitialized condition and/or lock!");
}

// Signal one waiter on 'cond'. The calling thread must hold 'lock'. Afterwards, the visible waiting
// count is refreshed from the implementation.
// Throws SyncError if the caller does not hold 'lock', or if either object is uninitialized.
void cond_signal(Ptr<condition> cond, Ptr<lock> lock) {
	checkSyncPtr(cond);
	checkSyncPtr(lock);

	Nat condImplOffset = cond.offset - sPtr.current * 2;
	Nat lockImplOffset = lock.offset - sPtr.current * 2;
	Nat waitingOffset = condImplOffset + sPtr.current;
	Nat holderOffset = lockImplOffset + sPtr.current;

	Nat holder = lock.base.readNat(holderOffset);
	if (holder != progvis:program:findThisThreadId())
		throw SyncError("The lock passed to 'cond_signal' must be held by the current thread!");

	if (condImpl = cond.base.readPtr(condImplOffset).asObject() as CondImpl) {
		if (lockImpl = lock.base.readPtr(lockImplOffset).asObject() as SemaImpl) {
			condImpl.signal(lockImpl);
			cond.base.writeInt(waitingOffset, condImpl.waitingCount);
			return;
		}
	}

	throw SyncError("Can not operate on an uninitialized condition and/or lock!");
}

// Wake all waiters on 'cond'. The calling thread must hold 'lock'. Afterwards, the visible waiting
// count is refreshed from the implementation.
// Throws SyncError if the caller does not hold 'lock', or if either object is uninitialized.
void cond_broadcast(Ptr<condition> cond, Ptr<lock> lock) {
	checkSyncPtr(cond);
	checkSyncPtr(lock);

	Nat condImplOffset = cond.offset - sPtr.current * 2;
	Nat lockImplOffset = lock.offset - sPtr.current * 2;
	Nat waitingOffset = condImplOffset + sPtr.current;
	Nat holderOffset = lockImplOffset + sPtr.current;

	Nat holder = lock.base.readNat(holderOffset);
	if (holder != progvis:program:findThisThreadId())
		throw SyncError("The lock passed to 'cond_broadcast' must be held by the current thread!");

	if (condImpl = cond.base.readPtr(condImplOffset).asObject() as CondImpl) {
		if (lockImpl = lock.base.readPtr(lockImplOffset).asObject() as SemaImpl) {
			condImpl.broadcast(lockImpl);
			cond.base.writeInt(waitingOffset, condImpl.waitingCount);
			return;
		}
	}

	throw SyncError("Can not operate on an uninitialized condition and/or lock!");
}

// No-op counterpart of the C library's "prevent_optimization", so that programs calling it work here.
void prevent_optimization() {
	// Intentionally empty.
}
