#include "stdafx.h"
#include "Bmp.h"
#include "Dither.h"
#include "Exception.h"
#include "Utils/Bitwise.h"

namespace graphics {

	// Format detection callback: asks 'checkHeader' whether the stream matches the "BM" signature.
	static Bool CODECALL bmpApplicable(IStream *from) {
		const Bool matches = checkHeader(from, "BM", false);
		return matches;
	}

	// Factory callback: creates a default-constructed BMPOptions, allocated via 'f'.
	static FormatOptions *CODECALL bmpCreate(ImageFormat *f) {
		BMPOptions *options = new (f) BMPOptions();
		return options;
	}

	// Default options: 24-bit color.
	BMPOptions::BMPOptions() {
		mode = color24;
	}

	// Options with an explicitly chosen mode.
	BMPOptions::BMPOptions(Mode mode) {
		this->mode = mode;
	}

	// Write a description of the form "BMP: { <mode> }" to 'to'.
	void BMPOptions::toS(StrBuf *to) const {
		// Description of the current mode. Remains null for values outside of the enum, in which
		// case nothing is written between the braces.
		const wchar *name = null;
		switch (mode) {
		case unknown:
			name = S("unknown");
			break;
		case mono1:
			name = S("monochrome");
			break;
		case palette4:
			name = S("4-bit palette");
			break;
		case palette8:
			name = S("8-bit palette");
			break;
		case color16alpha:
			name = S("16-bit color, with alpha");
			break;
		case color16:
			name = S("16-bit color");
			break;
		case color24:
			name = S("24-bit color");
			break;
		case color24alpha:
			name = S("24-bit color, with 8-bit alpha");
			break;
		}

		*to << S("BMP: { ");
		if (name)
			*to << name;
		*to << S(" }");
	}

	// Create the ImageFormat object describing the BMP format: its name, its file extensions and
	// the callbacks used to detect the format and to create options for it.
	ImageFormat *bmpFormat(Engine &e) {
		// Null-terminated list of file extensions.
		const wchar *extensions[] = { S("bmp"), S("dib"), null };

		ImageFormat *format = new (e) ImageFormat(S("Bitmap"), extensions, &bmpApplicable, &bmpCreate);
		return format;
	}



	/**
	 * Structs used when parsing a Windows BMP/DIB file.
	 *
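	 * NOTE: Numbers in the file are stored in little-endian format. The structs below are copied
	 * to and from the file as raw bytes, so this only works as long as we are running on a
	 * little-endian CPU (such as x86).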
	 */

	/**
	 * Bitmap header, excluding the 2-byte type field containing "BM", since that would make the
	 * rest of the fields misaligned.
	 *
	 * This struct is read from and written to the stream as raw bytes, so the order and the sizes
	 * of the fields must match the layout in the file.
	 */
	struct FileHeader {
		// File size in bytes.
		nat size;

		// Reserved. Should be 0.
		nat reserved;

		// Offset to start of pixel data, counted from the start of the file (i.e. including the
		// 2-byte "BM" field that is not a part of this struct).
		nat pixelOffset;
	};


	/**
	 * Description of the image. Located directly after FileHeader.
	 *
	 * This struct is read from and written to the stream as raw bytes, so the order and the sizes
	 * of the fields must match the layout in the file.
	 */
	struct ImageHeader {
		// Header size in bytes. May be larger than this header. Values smaller than 40 are rejected
		// when loading.
		nat size;

		// Image width in pixels.
		nat width;

		// Image height in pixels. The decoders in this file treat the first row in the file as the
		// bottom row of the image.
		nat height;

		// Number of planes. Must be 1.
		nat16 planes;

		// Number of bits per pixel. 1, 4, 8, 16, 24, or 32.
		nat16 pixelDepth;

		// Compression type (0 = not compressed). The decoders in this file additionally accept 3
		// for 16 and 32 bits per pixel, in which case channel bit masks are read from the file.
		nat compression;

		// Image size in bytes. Possibly zero for uncompressed images.
		nat imageSize;

		// Resolution in pixels per meter.
		nat xResolution;
		nat yResolution;

		// Number of color map entries used. Zero makes the palette decoders assume the maximum
		// number of entries for the bit depth.
		nat colorsUsed;

		// Number of significant colors.
		nat colorsImportant;
	};


	/**
	 * An entry in the color table. Located after ImageHeader.
	 *
	 * Entries are read from and written to the stream as raw bytes, so the member order below
	 * (blue, green, red, padding) is the byte order in the file.
	 */
	struct ImageColor {
		// Blue component.
		byte b;
		// Green component.
		byte g;
		// Red component.
		byte r;
		// Unused byte. Written as zero by the encoders in this file.
		byte pad;
	};


	/**
	 * Helper for decoding and encoding a single color channel that is described by a bit mask.
	 *
	 * A mask of zero means that the channel is not present: 'decode' then produces 255 and
	 * 'encode' produces 0.
	 */
	class Bitfield {
	public:
		// Analyze 'mask' and precompute what 'decode' and 'encode' need:
		// - 'shift': position of the lowest set bit in the mask.
		// - 'width': number of bits from the lowest to the highest set bit (inclusive).
		// - 'scale': multiplier used to expand channels narrower than 8 bits to 8 bits.
		explicit Bitfield(Nat mask) : shift(0), width(0), scale(0), mask(mask) {
			if (mask == 0)
				return;

			// Find how far we need to shift to align the mask at bit 0.
			for (Nat tmp = mask; (tmp & 0x1) == 0; tmp >>= 1)
				shift++;

			// Find where the mask ends.
			for (Nat tmp = mask >> shift; tmp; tmp >>= 1)
				width++;

			// Channels of 8 bits or more are simply truncated in 'decode', so they do not need
			// 'scale'. Returning here also keeps the shift amount below from wrapping around for
			// very wide masks (which may appear in malformed files).
			if (width >= 8)
				return;

			// Compute 'scale'. Multiplying by a value with a one every 'width' bits repeats the
			// channel's bit pattern until at least 8 bits are covered.
			for (Nat bits = 0; bits < 8; bits += width)
				scale = (scale << width) | 0x1;

			// Align 'scale' so that the top 8 bits of the repeated pattern end up in bits 16-23
			// after the multiplication, ready for the final 16-bit shift in 'decode'.
			scale <<= (16 + 8) - roundUp(Nat(8), width);
		}

		// Extract this channel from the pixel 'value' and expand (or truncate) it to 8 bits.
		Byte decode(Nat value) const {
			if (mask == 0)
				return 255;

			value &= mask;
			value >>= shift;

			// Wide channels: keep the 8 most significant bits.
			if (width >= 8)
				return Byte(value >> (width - 8));

			value *= scale;
			value >>= 16;
			return Byte(value);
		}

		// Produce the bits of a pixel that represent the 8-bit channel value 'src'. The result can
		// be or:ed together with the results for the other channels.
		Nat encode(Byte src) const {
			if (mask == 0)
				return 0;

			Nat val = src;

			// Adjust to the width of the channel first, so that no bits are lost when moving the
			// value into position afterwards.
			if (width >= 8)
				val <<= width - 8;
			else
				val >>= 8 - width;

			val <<= shift;
			val &= mask;
			return val;
		}

	private:
		Nat shift;
		Nat width;
		Nat scale;
		Nat mask;
	};


	// Read exactly sizeof(T) bytes from 'src' and copy them into 'out'. Returns false, leaving
	// 'out' untouched, if that many bytes could not be read.
	template <class T>
	static bool fill(IStream *src, T &out) {
		GcPreArray<byte, sizeof(T)> storage;
		Buffer received = src->fill(emptyBuffer(storage));

		const bool complete = received.filled() == sizeof(T);
		if (complete)
			memcpy(&out, received.dataPtr(), sizeof(T));
		return complete;
	}

	// Read 'count' elements of type T from 'src' into a newly allocated buffer. Returns a pointer
	// to the data, or null if the stream did not contain enough data or if 'count' is too large to
	// be represented as a size in bytes.
	template <class T>
	T *read(IStream *src, Nat count) {
		// 'count' may come straight from the file. If the size in bytes does not fit in a Nat, the
		// multiplication below would wrap around and produce a buffer that is smaller than what
		// the caller expects.
		if (count > ~Nat(0) / Nat(sizeof(T)))
			return null;

		Nat size = count * Nat(sizeof(T));
		Buffer r = src->fill(buffer(src->engine(), size));
		if (r.filled() != size)
			return null;
		return (T *)r.dataPtr();
	}

	// Decode various bit depths. Each decoder reads the remainder of the file from 'from', stores
	// the pixels in 'to', and records the format it found in 'opt->mode'. 'untilStart' is the
	// number of bytes between the current position in the stream and the start of the pixel data.
	// Returns false on failure.
	static bool decode1(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart);
	static bool decode4(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart);
	static bool decode8(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart);
	static bool decode16(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart);
	static bool decode24(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart);
	static bool decode32(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart);

	// Choose a good decoder. 'Decoder' is the signature shared by all decode functions above.
	// 'pickDecoder' returns null if the bit depth and compression in 'header' are not supported.
	typedef bool (*Decoder)(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart);
	static Decoder pickDecoder(const ImageHeader &header);

	// Load a BMP image from 'from'.
	//
	// Reads the "BM" signature, the file header and the image header, picks a decoder based on the
	// bit depth and compression, and lets the decoder read the pixel data. The decoder updates
	// 'mode' to reflect the format of the file.
	//
	// Throws ImageLoadError if the headers are invalid or incomplete, if the format is not
	// supported, or if the pixel data could not be read.
	Image *BMPOptions::load(IStream *from) {
		// Keep track of file offset.
		Nat position = 0;
		const wchar *error = S("");

		{
			Buffer h = from->fill(2);
			error = S("Invalid BMP header.");
			if (h.filled() != 2)
				throw new (this) ImageLoadError(error);
			if (h[0] != 'B' || h[1] != 'M')
				throw new (this) ImageLoadError(error);

			position += 2;
		}

		// Read the rest of the header.
		FileHeader header;
		error = S("Invalid or incomplete BMP header.");
		if (!fill(from, header))
			throw new (this) ImageLoadError(error);
		position += sizeof(FileHeader);

		ImageHeader image;
		if (!fill(from, image))
			throw new (this) ImageLoadError(error);
		position += sizeof(ImageHeader);
		if (image.size < 40)
			throw new (this) ImageLoadError(error);
		if (image.planes != 1)
			throw new (this) ImageLoadError(error);

		// The pixel data can not start inside the headers we have already read. Without this
		// check, the computation of 'remaining' below would wrap around to a huge value.
		if (header.pixelOffset < position)
			throw new (this) ImageLoadError(error);

		Decoder decode = pickDecoder(image);
		error = S("Unsupported bit depth in the image.");
		if (!decode)
			throw new (this) ImageLoadError(error);

		// Now, we can create the output image and start writing to it.
		Image *result = new (from) Image(image.width, image.height);
		error = S("Failed reading the image.");

		// Number of bytes between the current position and the start of the pixel data.
		Nat remaining = header.pixelOffset - position;
		if ((*decode)(this, from, result, image, remaining)) {
			return result;
		}

		throw new (this) ImageLoadError(error);
	}

	// Pick a decoder for the bit depth in 'header'. Each bit depth is only supported together with
	// one particular compression type: 3 for 16 and 32 bits per pixel, 0 for all others. Returns
	// null for all other combinations.
	static Decoder pickDecoder(const ImageHeader &header) {
		Decoder decoder = null;
		Nat compression = 0;

		switch (header.pixelDepth) {
		case 1:
			decoder = &decode1;
			break;
		case 4:
			decoder = &decode4;
			break;
		case 8:
			decoder = &decode8;
			break;
		case 16:
			decoder = &decode16;
			compression = 3;
			break;
		case 24:
			decoder = &decode24;
			break;
		case 32:
			decoder = &decode32;
			compression = 3;
			break;
		default:
			return null;
		}

		if (header.compression == compression)
			return decoder;
		return null;
	}

	// Decode a 32-bit image where the channels are described by bit masks.
	//
	// The masks for red, green and blue are always read. A mask for alpha is read if the header is
	// large enough to contain it. The first row in the file ends up at the bottom of 'to'. Returns
	// false if the file is truncated or if the masks do not fit before the pixel data.
	static bool decode32(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart) {
		opt->mode = BMPOptions::color24alpha;

		Nat w = to->width();
		Nat h = to->height();

		// The three color masks must fit before the pixel data. Checking this first prevents
		// 'untilStart' from wrapping around to a huge value below.
		if (untilStart < 3*sizeof(Nat))
			return false;

		// Read bitfields.
		Nat r = 0, g = 0, b = 0, a = 0;
		if (!fill(from, r) || !fill(from, g) || !fill(from, b))
			return false;

		untilStart -= 3*sizeof(Nat);

		// Is there an alpha channel? The alpha mask is the fourth mask, so it is only present if
		// the header has room for four masks. A header with room for exactly three masks does not
		// contain one.
		if (header.size >= sizeof(ImageHeader) + 4*sizeof(Nat)) {
			if (untilStart < sizeof(Nat))
				return false;
			if (!fill(from, a))
				return false;
			untilStart -= sizeof(Nat);
		}

		Bitfield rBit(r);
		Bitfield gBit(g);
		Bitfield bBit(b);
		Bitfield aBit(a);

		// Skip until the start of the pixel data.
		from->fill(untilStart);

		Nat stride = w*4;
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			// Use 'fill', like the other decoders, to request the entire row.
			// NOTE(review): 'read' presumably may return a partially filled buffer even when more
			// data is available - confirm against IStream.
			src = from->fill(src);
			if (src.filled() != stride)
				return false;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				// Assemble the little-endian pixel value.
				Nat px = src[x*4 + 0];
				px |= Nat(src[x*4 + 1]) << 8;
				px |= Nat(src[x*4 + 2]) << 16;
				px |= Nat(src[x*4 + 3]) << 24;

				dest[4*x + 0] = rBit.decode(px);
				dest[4*x + 1] = gBit.decode(px);
				dest[4*x + 2] = bBit.decode(px);
				dest[4*x + 3] = aBit.decode(px);
			}
		}

		return true;
	}

	// Decode an uncompressed 24-bit image.
	//
	// Each pixel is stored as blue, green, red in the file, and each row is padded to a multiple of
	// 4 bytes. The first row in the file ends up at the bottom of 'to'. Returns false if the file
	// is truncated.
	static bool decode24(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart) {
		opt->mode = BMPOptions::color24;

		Nat w = to->width();
		Nat h = to->height();

		// Skip until the start of the pixel data.
		from->fill(untilStart);

		Nat stride = roundUp(w*3, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			// Use 'fill', like the other decoders, to request the entire row.
			// NOTE(review): 'read' presumably may return a partially filled buffer even when more
			// data is available - confirm against IStream.
			src = from->fill(src);
			if (src.filled() != stride)
				return false;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				dest[4*x + 0] = src[3*x + 2];
				dest[4*x + 1] = src[3*x + 1];
				dest[4*x + 2] = src[3*x + 0];
				dest[4*x + 3] = 255;
			}
		}

		return true;
	}


	// Decode a 16-bit image where the channels are described by bit masks.
	//
	// The masks for red, green and blue are always read. A mask for alpha is read if the header is
	// large enough to contain it. 'opt->mode' is set based on the number of bits set in the color
	// masks: 5-5-5 gives 'color16alpha', 5-6-5 gives 'color16', anything else gives 'unknown'.
	// Rows are padded to a multiple of 4 bytes, and the first row in the file ends up at the bottom
	// of 'to'. Returns false if the file is truncated or if the masks do not fit before the pixel
	// data.
	static bool decode16(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart) {
		Nat w = to->width();
		Nat h = to->height();

		// The three color masks must fit before the pixel data. Checking this first prevents
		// 'untilStart' from wrapping around to a huge value below.
		if (untilStart < 3*sizeof(Nat))
			return false;

		// Read bitfields.
		Nat r = 0, g = 0, b = 0, a = 0;
		if (!fill(from, r) || !fill(from, g) || !fill(from, b))
			return false;

		untilStart -= 3*sizeof(Nat);

		// Is there an alpha channel? The alpha mask is the fourth mask, so it is only present if
		// the header has room for four masks. A header with room for exactly three masks does not
		// contain one.
		if (header.size >= sizeof(ImageHeader) + 4*sizeof(Nat)) {
			if (untilStart < sizeof(Nat))
				return false;
			if (!fill(from, a))
				return false;
			untilStart -= sizeof(Nat);
		}

		// Figure out which format it is:
		Nat rBits = setBitCount(r);
		Nat gBits = setBitCount(g);
		Nat bBits = setBitCount(b);
		if (rBits == 5 && gBits == 5 && bBits == 5) {
			opt->mode = BMPOptions::color16alpha;
		} else if (rBits == 5 && gBits == 6 && bBits == 5) {
			opt->mode = BMPOptions::color16;
		} else {
			opt->mode = BMPOptions::unknown;
		}

		Bitfield rBit(r);
		Bitfield gBit(g);
		Bitfield bBit(b);
		Bitfield aBit(a);

		// Skip until the start of the pixel data.
		from->fill(untilStart);

		Nat stride = roundUp(w*2, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			src = from->fill(src);
			if (src.filled() != stride)
				return false;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				// Assemble the little-endian pixel value.
				Nat px = src[x*2 + 0];
				px |= Nat(src[x*2 + 1]) << 8;

				dest[4*x + 0] = rBit.decode(px);
				dest[4*x + 1] = gBit.decode(px);
				dest[4*x + 2] = bBit.decode(px);
				dest[4*x + 3] = aBit.decode(px);
			}
		}

		return true;
	}

	// Decode an uncompressed image with 8 bits per pixel, where each pixel is an index into a
	// color table.
	//
	// The color table contains 'header.colorsUsed' entries, or 256 if that field is zero. Rows are
	// padded to a multiple of 4 bytes, and the first row in the file ends up at the bottom of
	// 'to'. Returns false if the file is truncated or if the color table does not fit before the
	// pixel data.
	static bool decode8(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart) {
		opt->mode = BMPOptions::palette8;

		Nat w = to->width();
		Nat h = to->height();

		// The color table is located after the entire header, which may be larger than the part
		// described by 'ImageHeader'. Skip the remainder of the header so that the table is read
		// from the right location. If 'header.size' is too small, the subtraction wraps around and
		// the check below rejects the file.
		Nat extra = header.size - Nat(sizeof(ImageHeader));
		if (extra > untilStart)
			return false;
		from->fill(extra);
		untilStart -= extra;

		Nat used = header.colorsUsed;
		if (used == 0)
			used = 256;

		// The color table must fit before the pixel data. This also prevents 'untilStart' from
		// wrapping around below.
		if (used > untilStart / sizeof(ImageColor))
			return false;

		ImageColor *palette = read<ImageColor>(from, used);
		if (!palette)
			return false;

		untilStart -= used * sizeof(ImageColor);
		from->fill(untilStart);

		Nat stride = roundUp(w, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			src = from->fill(src);
			if (src.filled() != stride)
				return false;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				Nat color = src[x];
				if (color >= used)
					color = used - 1; // Pick the last valid color so that we do not read outside the table.
				ImageColor *c = &palette[color];

				dest[4*x + 0] = c->r;
				dest[4*x + 1] = c->g;
				dest[4*x + 2] = c->b;
				dest[4*x + 3] = 255;
			}
		}

		return true;
	}

	// Decode an uncompressed image with 4 bits per pixel, where each pixel is an index into a
	// color table.
	//
	// The color table contains 'header.colorsUsed' entries, or 16 if that field is zero. Each byte
	// contains two pixels, with the leftmost pixel in the high nibble. Rows are padded to a
	// multiple of 4 bytes, and the first row in the file ends up at the bottom of 'to'. Returns
	// false if the file is truncated or if the color table does not fit before the pixel data.
	static bool decode4(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart) {
		opt->mode = BMPOptions::palette4;

		Nat w = to->width();
		Nat h = to->height();

		// The color table is located after the entire header, which may be larger than the part
		// described by 'ImageHeader'. Skip the remainder of the header so that the table is read
		// from the right location. If 'header.size' is too small, the subtraction wraps around and
		// the check below rejects the file.
		Nat extra = header.size - Nat(sizeof(ImageHeader));
		if (extra > untilStart)
			return false;
		from->fill(extra);
		untilStart -= extra;

		Nat used = header.colorsUsed;
		if (used == 0)
			used = 16;

		// The color table must fit before the pixel data. This also prevents 'untilStart' from
		// wrapping around below.
		if (used > untilStart / sizeof(ImageColor))
			return false;

		ImageColor *palette = read<ImageColor>(from, used);
		if (!palette)
			return false;

		untilStart -= used * sizeof(ImageColor);
		from->fill(untilStart);

		Nat stride = roundUp((w + 1)/2, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			src = from->fill(src);
			if (src.filled() != stride)
				return false;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				// Even pixels are in the high nibble, odd pixels are in the low nibble.
				Nat color = src[x / 2];
				color = (color >> ((~x & 0x1)*4)) & 0xF;
				if (color >= used)
					color = used - 1; // Pick the last valid color so that we do not read outside the table.
				ImageColor *c = &palette[color];

				dest[4*x + 0] = c->r;
				dest[4*x + 1] = c->g;
				dest[4*x + 2] = c->b;
				dest[4*x + 3] = 255;
			}
		}

		return true;
	}

	// Decode an uncompressed image with 1 bit per pixel, where each pixel is an index into a color
	// table.
	//
	// The color table contains 'header.colorsUsed' entries, or 2 if that field is zero. Each byte
	// contains 8 pixels, with the leftmost pixel in the most significant bit. Rows are padded to a
	// multiple of 4 bytes, and the first row in the file ends up at the bottom of 'to'. Returns
	// false if the file is truncated or if the color table does not fit before the pixel data.
	static bool decode1(BMPOptions *opt, IStream *from, Image *to, const ImageHeader &header, Nat untilStart) {
		opt->mode = BMPOptions::mono1;

		Nat w = to->width();
		Nat h = to->height();

		// The color table is located after the entire header, which may be larger than the part
		// described by 'ImageHeader'. Skip the remainder of the header so that the table is read
		// from the right location. If 'header.size' is too small, the subtraction wraps around and
		// the check below rejects the file.
		Nat extra = header.size - Nat(sizeof(ImageHeader));
		if (extra > untilStart)
			return false;
		from->fill(extra);
		untilStart -= extra;

		Nat used = header.colorsUsed;
		if (used == 0)
			used = 2;

		// The color table must fit before the pixel data. This also prevents 'untilStart' from
		// wrapping around below.
		if (used > untilStart / sizeof(ImageColor))
			return false;

		ImageColor *palette = read<ImageColor>(from, used);
		if (!palette)
			return false;

		untilStart -= used * sizeof(ImageColor);
		from->fill(untilStart);

		Nat stride = roundUp((w + 7)/8, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			src = from->fill(src);
			if (src.filled() != stride)
				return false;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				// The leftmost pixel is stored in the most significant bit.
				Nat color = (src[x / 8] >> (7 - (x & 0x7))) & 0x1;
				if (color >= used)
					color = used - 1; // The table may contain a single entry only.
				ImageColor *c = &palette[color];

				dest[4*x + 0] = c->r;
				dest[4*x + 1] = c->g;
				dest[4*x + 2] = c->b;
				dest[4*x + 3] = 255;
			}
		}

		return true;
	}

	// Write the "BM" signature followed by the file header to 'to'. 'contentSize' is the size of
	// the pixel data in bytes, and 'paletteEntries' is the number of ImageColor-sized entries that
	// will be written between the image header and the pixel data.
	static void putFileHeader(OStream *to, nat contentSize, nat paletteEntries) {
		// The signature is not a part of FileHeader, so it is written separately.
		to->write(buffer(to->engine(), (const Byte *)"BM", 2));

		// The pixel data starts after the signature, the two headers, and the palette.
		const nat dataStart = nat(2 + sizeof(FileHeader) + sizeof(ImageHeader) + paletteEntries * sizeof(ImageColor));

		FileHeader header;
		header.size = dataStart + contentSize;
		header.reserved = 0;
		header.pixelOffset = dataStart;

		const Byte *raw = (const Byte *)&header;
		to->write(buffer(to->engine(), raw, sizeof(FileHeader)));
	}

	// Create an image header for 'image' with default values: one plane, 1 bit per pixel, no
	// compression, no palette, and a resolution of 2835 pixels per meter. The encoders modify the
	// fields that differ for their format.
	static ImageHeader defaultImageHeader(Image *image) {
		ImageHeader result;
		result.size = sizeof(ImageHeader);
		result.width = image->width();
		result.height = image->height();
		result.planes = 1;
		result.pixelDepth = 1;
		result.compression = 0;
		// Size of the image data. We could compute this, but it is left as zero.
		result.imageSize = 0;
		result.xResolution = 2835;
		result.yResolution = 2835;
		result.colorsUsed = 0;
		result.colorsImportant = 0;
		return result;
	}

	// Write the raw bytes of 'header' to 'to'.
	static void putImageHeader(OStream *to, const ImageHeader &header) {
		const Byte *raw = (const Byte *)&header;
		to->write(buffer(to->engine(), raw, sizeof(ImageHeader)));
	}

	// Write the four channel masks to 'to', in the order red, green, blue, alpha.
	static void putBitmasks(OStream *to, Nat r, Nat g, Nat b, Nat a) {
		Nat masks[4];
		masks[0] = r;
		masks[1] = g;
		masks[2] = b;
		masks[3] = a;
		to->write(buffer(to->engine(), (const Byte *)masks, sizeof(masks)));
	}

	// Write a single color table entry with the color (r, g, b) to 'to'.
	static void putPalette(OStream *to, byte r, byte g, byte b) {
		// The members of ImageColor are declared in the order blue, green, red, padding, so the
		// initializer needs to list the components in that order.
		ImageColor c = { b, g, r, 0 };
		to->write(buffer(to->engine(), (const Byte *)&c, sizeof(ImageColor)));
	}

	// Write 'image' to 'to' as a monochrome image with 1 bit per pixel. The color table contains
	// black followed by white, and dithering decides which of them each pixel becomes.
	static void encode1(Image *image, OStream *to) {
		const Nat width = image->width();
		const Nat height = image->height();

		// Storage for one row, padded to a multiple of 4 bytes.
		const Nat stride = roundUp((width + 7) / 8, Nat(4));
		Buffer line = storm::buffer(image->engine(), stride);
		line.filled(line.count());

		putFileHeader(to, stride * height, 2);

		ImageHeader header = defaultImageHeader(image);
		header.colorsUsed = 2;
		header.pixelDepth = 1;
		putImageHeader(to, header);

		// Index 0 is black, index 1 is white.
		putPalette(to, 0, 0, 0);
		putPalette(to, 255, 255, 255);

		DitherState dither;

		// Rows are written from the bottom of the image and upwards.
		for (Nat row = height; row-- > 0; ) {
			// Start with all pixels black, then set the bits for the white ones.
			memset(line.dataPtr(), 0, line.count());
			for (Nat x = 0; x < width; x++) {
				// The leftmost pixel is stored in the most significant bit.
				if (dither.pixelValue(fromLinear(image->get(x, row).toLinear().brightness())))
					line[x / 8] |= 1 << (7 - x % 8);
			}

			to->write(line);
		}
	}

	// Write 'image' to 'to' with 8 bits per pixel and a fixed color table with 128 entries: 2 bits
	// of red, 3 bits of green and 2 bits of blue. Each pixel is mapped to the entry given by the
	// most significant bits of its channels.
	static void encode8(Image *image, OStream *to) {
		const Nat width = image->width();
		const Nat height = image->height();

		// Storage for one row, padded to a multiple of 4 bytes.
		const Nat stride = roundUp(width, Nat(4));
		Buffer line = storm::buffer(image->engine(), stride);
		line.filled(line.count());

		const Nat entries = 128;
		putFileHeader(to, stride * height, entries);

		ImageHeader header = defaultImageHeader(image);
		header.colorsUsed = entries;
		header.pixelDepth = 8;
		putImageHeader(to, header);

		// Entry 'id' has red in bits 5-6, green in bits 2-4, and blue in bits 0-1. Each channel is
		// expanded to 8 bits by repeating its bit pattern.
		for (Nat id = 0; id < entries; id++) {
			Nat r = (id >> 5) & 0x3;
			Nat g = (id >> 2) & 0x7;
			Nat b = id & 0x3;
			putPalette(to, r | (r << 2) | (r << 4) | (r << 6),
					(g << 5) | (g << 2) | (g >> 1),
					b | (b << 2) | (b << 4) | (b << 6));
		}

		// Rows are written from the bottom of the image and upwards.
		for (Nat row = height; row-- > 0; ) {
			for (Nat x = 0; x < width; x++) {
				Byte *px = image->buffer(x, row);

				// Combine the top bits of each channel into an index in the color table.
				Byte red = (px[0] >> 1) & 0x60;
				Byte green = (px[1] >> 3) & 0x1C;
				Byte blue = (px[2] >> 6) & 0x03;
				line[x] = Byte(red | green | blue);
			}

			to->write(line);
		}
	}

	// Write 'image' to 'to' with 16 bits per pixel. 'r', 'g', 'b' and 'a' are the bit masks that
	// describe where each channel is stored in a pixel. A mask of zero means that the channel is
	// not stored.
	static void encode16(Image *image, OStream *to, Nat r, Nat g, Nat b, Nat a) {
		// A single scanline.
		Nat stride = roundUp(image->width() * 2, Nat(4));
		Buffer buffer = storm::buffer(image->engine(), stride);
		buffer.filled(buffer.count());

		// The four channel masks are accounted for as four palette entries.
		putFileHeader(to, stride * image->height(), 4);

		ImageHeader header = defaultImageHeader(image);
		header.size += 4*sizeof(Nat);
		header.pixelDepth = 16;
		header.compression = 3;
		putImageHeader(to, header);
		putBitmasks(to, r, g, b, a);

		Bitfield rBit(r);
		Bitfield gBit(g);
		Bitfield bBit(b);
		Bitfield aBit(a);

		// Output scanlines:
		for (Nat y = image->height(); y > 0; y--) {
			for (Nat x = 0; x < image->width(); x++) {
				Byte *src = image->buffer(x, y - 1);
				Nat pixel = 0;
				pixel |= rBit.encode(src[0]);
				pixel |= gBit.encode(src[1]);
				pixel |= bBit.encode(src[2]);
				pixel |= aBit.encode(src[3]);

				// Pixels are stored in little-endian format.
				buffer[x*2] = Byte(pixel & 0xFF);
				buffer[x*2 + 1] = Byte(pixel >> 8);
			}

			to->write(buffer);
		}
	}

	// Write 'image' to 'to' as an uncompressed image with 24 bits per pixel. The alpha channel is
	// not stored.
	static void encode24(Image *image, OStream *to) {
		const Nat width = image->width();
		const Nat height = image->height();

		// Storage for one row, padded to a multiple of 4 bytes.
		const Nat stride = roundUp(width * 3, Nat(4));
		Buffer line = storm::buffer(image->engine(), stride);
		line.filled(line.count());

		putFileHeader(to, stride * height, 0);

		ImageHeader header = defaultImageHeader(image);
		header.compression = 0;
		header.pixelDepth = 24;
		putImageHeader(to, header);

		// Rows are written from the bottom of the image and upwards.
		for (Nat row = height; row-- > 0; ) {
			for (Nat x = 0; x < width; x++) {
				Byte *px = image->buffer(x, row);

				// Pixels are stored in the order blue, green, red.
				line[3*x + 0] = px[2];
				line[3*x + 1] = px[1];
				line[3*x + 2] = px[0];
			}

			to->write(line);
		}
	}

	// Write 'image' to 'to' with 32 bits per pixel, using 8 bits for each of the red, green, blue
	// and alpha channels. The layout is described by channel masks written after the header.
	static void encode32(Image *image, OStream *to) {
		const Nat width = image->width();
		const Nat height = image->height();

		// Storage for one row.
		const Nat stride = roundUp(width * 4, Nat(4));
		Buffer line = storm::buffer(image->engine(), stride);
		line.filled(line.count());

		// The four channel masks are accounted for as four palette entries.
		putFileHeader(to, stride * height, 4);

		ImageHeader header = defaultImageHeader(image);
		header.compression = 3;
		header.pixelDepth = 32;
		header.size += 4*sizeof(Nat);
		putImageHeader(to, header);
		putBitmasks(to, 0xFF0000, 0x00FF00, 0x0000FF, 0xFF000000);

		// Rows are written from the bottom of the image and upwards.
		for (Nat row = height; row-- > 0; ) {
			for (Nat x = 0; x < width; x++) {
				Byte *px = image->buffer(x, row);

				// With the masks above, the bytes of a pixel are blue, green, red, alpha.
				line[4*x + 0] = px[2];
				line[4*x + 1] = px[1];
				line[4*x + 2] = px[0];
				line[4*x + 3] = px[3];
			}

			to->write(line);
		}
	}

	// Write 'image' to 'to' in the format selected by 'mode'.
	//
	// Throws ImageSaveError if there is no encoder for the selected mode ('unknown' and
	// 'palette4').
	void BMPOptions::save(Image *image, OStream *to) {
		switch (mode) {
		case unknown:
		case palette4:
			throw new (this) ImageSaveError(TO_S(this, S("Unsupported output format: ") << *this));
		case mono1:
			return encode1(image, to);
		case palette8:
			return encode8(image, to);
		case color16alpha:
			// 5 bits each for red, green and blue, and 1 bit for alpha.
			return encode16(image, to, 0xF800, 0x07C0, 0x003E, 0x0001);
		case color16:
			// 5 bits for red, 6 bits for green, 5 bits for blue, and no alpha.
			return encode16(image, to, 0xF800, 0x07E0, 0x001F, 0x0000);
		case color24:
			return encode24(image, to);
		case color24alpha:
			return encode32(image, to);
		}
	}

}
