#include <stdio.h>
#include <string.h>
#include <inttypes.h>
#include <limits.h>
#include <assert.h>
#include <math.h>

#ifndef NDEBUG
	#define BPE_NDEBUG
	#define NDEBUG
#endif
#include <algorithm>
#ifdef BPE_NDEBUG
	#undef BPE_NDEBUG
	#undef NDEBUG
#endif

// Verbosity of diagnostics printed to stdout; 0 disables them. Each level adds
// to the ones below: >=1 per-block summaries and the sorted digraph list,
// >=2 the pair chosen in each cycle, >=3 every replaced pair, >=4 every byte
// copied, every count decrement, and the re-sorted "noisy" entries in wsort.
#if !defined(BPE_DEBUG)
	#define BPE_DEBUG 0
#endif

// Echo one byte for debug output: printable ASCII as itself, anything else as
// two lowercase hex digits.
static void print_char(uint8_t c) {
	const bool printable = (c >= ' ') && (c < 0x7F);
	if(!printable) {
		printf("%02x",c);
		return;
	}
	putchar(c);
}

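// Tuning notes, presumably in the form "BLOCK_SIZE (hex),MIN_FREQ = compressed
// size in bytes in run time"; 2800,4 (the current settings below) gave the
// smallest output of the runs listed.
// 21FF,6 = 55323338 in 35secs
// 21FF,7 = 55346806 in 35secs
// 21FF,5 = 55319589 in 37secs
// 21FF,4 = 55318241 in 39secs
// 2300,7 = 55315661 in 34secs
// 2500,7 = 55292688 in 35secs
// 2800,8 = 55312826 in 34secs
// 2800,4 = 55289065 in 38secs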

enum {
	// number of distinct byte values
	BYTE = 1 << CHAR_BIT,
	// maximum number of input bytes compressed as one independent block
	BLOCK_SIZE = 0x2800,
	// a digraph must occur at least this often to be tracked and replaced
	MIN_FREQ = 4,
	// width of the big-endian block-length prefix in the compressed stream
	BUF_BYTES = 2,
};

typedef uint16_t count_t;

// One tracked digraph (adjacent byte pair) together with its occurrence count.
struct digraph_t {
	uint16_t di;    // the pair, packed with the high byte first
	count_t count;  // occurrences in the current buffer
	uint8_t gen;    // cycle that last modified this entry (255 for the initial scan)
	// Ordering predicate for sorting: entries with larger counts come first.
	static inline bool cmp(const digraph_t& lhs,const digraph_t& rhs) {
		return rhs.count < lhs.count;
	}
};

// Split a packed digraph key into its first (high) and second (low) byte, and
// pack two bytes back into a 16-bit key.
static inline uint8_t HI(uint16_t di) { return static_cast<uint8_t>(di >> CHAR_BIT); }
static inline uint8_t LO(uint16_t di) { return static_cast<uint8_t>(di & 0xFFu); }
static inline uint16_t DI(uint8_t hi,uint8_t lo) { return static_cast<uint16_t>((static_cast<unsigned>(hi) << CHAR_BIT) | lo); }

static inline void inc_digraph(uint16_t di,count_t digraphs[],digraph_t* dilist,count_t& dilist_len,uint8_t gen) {
	if(MIN_FREQ > digraphs[di]) {
		digraphs[di]++;
		if(MIN_FREQ == digraphs[di]) {
			dilist[dilist_len].di = di;
			dilist[dilist_len].count = MIN_FREQ-1;
			digraphs[di] = MIN_FREQ+dilist_len;
			dilist_len++;
		} else
			return;
	}
	assert(dilist[digraphs[di]-MIN_FREQ].di == di);
	dilist[digraphs[di]-MIN_FREQ].count++;
	dilist[digraphs[di]-MIN_FREQ].gen = gen;
}

// Re-sort dilist[0..dilist_len) into descending count order.
//
// Entries whose gen equals `gen` were modified since the previous sort and may
// be out of place ("noisy"); every other entry keeps the relative order the
// previous sort left it in ("ordered"). The noisy entries are sorted on their
// own and merged back in; on equal counts the ordered entry comes first.
// gen == 255 is the stamp from the initial scan, where every entry is noisy, so
// a full std::sort is used. Otherwise a stable insertion sort is used.
static void wsort(digraph_t* dilist,count_t dilist_len,uint8_t gen) {
	static digraph_t ordered[BLOCK_SIZE], noisy[BLOCK_SIZE];
	// Split into the two lists, preserving the original relative order in each.
	digraph_t* ordered_end = ordered;
	digraph_t* noisy_end = noisy;
	const digraph_t* const dilist_end = dilist+dilist_len;
	for(const digraph_t* it = dilist; it != dilist_end; ++it) {
		if(it->gen == gen)
			*noisy_end++ = *it;
		else
			*ordered_end++ = *it;
	}
	const size_t noisy_len = static_cast<size_t>(noisy_end-noisy);
	const size_t ordered_len = static_cast<size_t>(ordered_end-ordered);
	assert((noisy_len+ordered_len) == dilist_len);
	if(gen == 255) {
		std::sort(noisy,noisy_end,digraph_t::cmp);
	} else {
		// Stable insertion sort: slide each element left past strictly smaller counts.
		for(size_t k=1; k<noisy_len; ++k) {
			const digraph_t moving = noisy[k];
			size_t hole = k;
			while((hole > 0) && digraph_t::cmp(moving,noisy[hole-1])) {
				noisy[hole] = noisy[hole-1];
				--hole;
			}
			noisy[hole] = moving;
		}
	}
#if BPE_DEBUG >= 4
	for(size_t i=0; i<noisy_len; i++) {
		printf("%3zu,%2u [",i,noisy[i].gen);
		print_char(HI(noisy[i].di));
		putchar(':');
		print_char(LO(noisy[i].di));
		printf("] = %d\n",(int)noisy[i].count);
	}
#endif
	// std::merge takes from the second (noisy) range only when it strictly
	// precedes the first, so ties keep the ordered entry first.
	digraph_t* const merged_end = std::merge(ordered,ordered_end,noisy,noisy_end,dilist,digraph_t::cmp);
	assert(merged_end == dilist+dilist_len);
	(void)merged_end;
}

static inline void dec_digraph(uint16_t di,count_t digraphs[],digraph_t* dilist,count_t& dilist_len,uint8_t gen) {
#if BPE_DEBUG >= 4
	printf("decrementing ");
	print_char(HI(di));
	putchar(':');
	print_char(LO(di));
	putchar('\n');
#endif
	if(MIN_FREQ <= digraphs[di]) {
		assert(dilist[digraphs[di]-MIN_FREQ].di == di);
		dilist[digraphs[di]-MIN_FREQ].count--;
		dilist[digraphs[di]-MIN_FREQ].gen = gen;
	} else {
		assert(digraphs[di]);
		digraphs[di]--;
	}
}

int main(int argc,char** args) {
	if(4!=argc) {
		fprintf(stderr,"usage: bpe [c|d] src dest\n");
		return -1;
	}
	FILE* srcf = fopen(args[2],"rb");
	FILE* destf = fopen(args[3],"wb");
	uint8_t buffer[2][BLOCK_SIZE], *src = buffer[0], *dest = buffer[1];
	// what to do?
	if('c'==*args[1]) {
		// compress; lets be greedy
		count_t digraphs[BYTE*BYTE];
		count_t used[BYTE];
		digraph_t dilist[BLOCK_SIZE];
		size_t block_num = 0, read = 0, written = 0;
		while(true) {
			size_t block_len = fread(src,1,BLOCK_SIZE,srcf);
			if(!block_len)
				break;
			read += block_len;
			block_num++;
			count_t dilist_len = 0;
			memset(digraphs,0,sizeof(digraphs));
			memset(used,0,sizeof(used));
			used[src[0]]++;
			for(size_t i=1; i<block_len; i++) {
				used[src[i]]++;
				inc_digraph(DI(src[i-1],src[i]),digraphs,dilist,dilist_len,-1);
			}
			size_t cycle;
			for(cycle=0; (cycle<BYTE); cycle++) {
				wsort(dilist,dilist_len,cycle-1);
#if BPE_DEBUG >= 1
				printf("%5zu,%2zu = %d\n",block_num,cycle,dilist_len);
				for(size_t i=0; i<dilist_len; i++) {
					printf("%3zu,%2u [",i,dilist[i].gen);
					print_char(HI(dilist[i].di));
					putchar(':');
					print_char(LO(dilist[i].di));
					printf("] = %d\n",(int)dilist[i].count);
				}
#endif
				for(size_t i=1; i<dilist_len; i++)
					assert(dilist[i].count <= dilist[i-1].count);
				const size_t prev_len = dilist_len;
				for(size_t i=0; i<dilist_len; i++) {
					if(dilist[i].count < MIN_FREQ) {
						dilist_len = i;
						break;
					}
					digraphs[dilist[i].di] = MIN_FREQ+i;
				}
				if(!dilist_len)
					break;
				for(size_t i=dilist_len; i<prev_len; i++) {
					assert(dilist[i].count < MIN_FREQ);
					digraphs[dilist[i].di] = dilist[i].count;
				}
				// work out what we're replacing the digraph with
				count_t replace_idx = 0;
				while((replace_idx < BYTE) && used[replace_idx])
					replace_idx++;
				if(replace_idx == BYTE)
					break;
				const uint8_t hi = HI(dilist->di), lo = LO(dilist->di);
#if BPE_DEBUG >= 2
				printf("%5zu,%3zu [",block_num,cycle);
				print_char(hi);
				putchar(':');
				print_char(lo);
				printf("] = %d ",(int)dilist->count);
				print_char(replace_idx);
				putchar('\n');
#endif
				// write header
				uint8_t* out = dest;
				*out++ = hi;
				*out++ = lo;
				*out++ = replace_idx;
				used[hi]++;
				used[lo]++;
				used[replace_idx]++;
				inc_digraph(dilist->di,digraphs,dilist,dilist_len,cycle);
				inc_digraph(DI(lo,replace_idx),digraphs,dilist,dilist_len,cycle);
				inc_digraph(DI(replace_idx,src[0]),digraphs,dilist,dilist_len,cycle);
				// write body
				for(size_t i=0; i<block_len; ) {
					if((src[i] == hi) && (i<(block_len-1)) && (src[i+1] == lo)) {
#if BPE_DEBUG>=3
						printf("%zu read ",out-dest);
						print_char(src[i]);
						putchar(':');
						print_char(src[i+1]);
						printf(", write ");
						print_char(replace_idx);
						putchar('\n');
#endif
						used[hi]--;
						used[lo]--;
						used[replace_idx]++;
						dec_digraph(dilist->di,digraphs,dilist,dilist_len,cycle);
						dec_digraph(DI(out[-1],hi),digraphs,dilist,dilist_len,cycle);
						inc_digraph(DI(out[-1],replace_idx),digraphs,dilist,dilist_len,cycle);
						if(i < (block_len-2)) {
							dec_digraph(DI(lo,src[i+2]),digraphs,dilist,dilist_len,cycle);
							inc_digraph(DI(replace_idx,src[i+2]),digraphs,dilist,dilist_len,cycle);
						}
						*out++ = replace_idx;
						i += 2;
					} else {
#if BPE_DEBUG>=4
						printf("%zu writing ",out-dest);
						print_char(src[i]);
						putchar('\n');
#endif
						*out++ = src[i++];
					}
				}
				assert(dilist->count == 1);
				block_len = (out-dest);
				std::swap(src,dest);
			}
			
			for(int i=0; i<BUF_BYTES; i++) {
				fputc((block_len>>((BUF_BYTES-i-1)*8))&0xff,destf);
				written++;
			}
			fputc(cycle,destf);
			written++;
			fwrite(src,1,block_len,destf);
			written += block_len;
#if BPE_DEBUG>=1
			printf("%zu = %zu, %zu\n",block_num,block_len,cycle);
#endif
		}
		printf("compressed %zu to %zu in %zu blocks (%0.0f%%)\n",read,written,block_num,((double)written/(double)read)*100.0);
	} else {
		// decompress
		size_t read = 0, written = 0;
		size_t block_num = 0;
		for(; ; block_num++) {
			size_t block_len = 0;
			bool eof = false;
			for(int i=0; i<BUF_BYTES; i++) {
				block_len <<= 8;
				const int r = fgetc(srcf);
				if(EOF == r) {
					eof = true;
					break;
				}
				block_len |= r;
				read++;
			}
			if(eof)
				break;
			if(block_len > BLOCK_SIZE) {
				fprintf(stderr,"illegal block_len %zu\n",block_len);
				return 1;
			}
			const size_t cycles = fgetc(srcf);
			read ++;
#if BPE_DEBUG>=1
			printf("%zu = %zu, %zu\n",block_num,block_len,cycles);
#endif
			read += fread(src,1,block_len,srcf);
			for(size_t cycle=0; cycle<cycles; cycle++) {
				size_t in = 0;
				const uint8_t hi = src[in++], lo = src[in++], replace = src[in++];
#if BPE_DEBUG>=2
				printf("%3zu %2zu [",block_num,cycle);
				print_char(hi);
				putchar(':');
				print_char(lo);
				printf("] = ");
				print_char(replace);
#endif
				size_t out = 0;
				for(; in < block_len; in++)
					if(src[in]==replace) {
						dest[out++] = hi;
						dest[out++] = lo;
					} else
						dest[out++] = src[in];
#if BPE_DEBUG>=2
				printf(" = %zu saved, %zu left\n",out-block_len,out);
#endif
				block_len = out;
				std::swap(src,dest);
			}
			written += fwrite(src,1,block_len,destf);
		}
		printf("decompressed %zu to %zu in %zu blocks (%0.0f%%)\n",read,written,block_num,((double)read/(double)written)*100.0);		
	}
	return 0;
}

