// implements the image editing operations described in "Intrinsic Decompositions for Image Editing"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "CImg.h"



// Reports whether the file `name` can be opened for reading.
// Returns false when the stream fails to open (missing file, unreadable path, empty name).
inline bool file_exists(const std::string& name) {
	std::ifstream probe;
	probe.open(name.c_str());
	const bool readable = probe.good();
	return readable;
}

// Loads an image file into an interleaved buffer of doubles, optionally resizing it.
//   filename     : path of the image, decoded by CImg as 8-bit samples.
//   result       : resized to W*H*3 and filled as R,G,B,R,G,B,... (pixel after pixel).
//   W, H         : receive the width and height of the image stored in `result`.
//   destW, destH : when both are non-zero, the image is resized to this size before conversion.
//   mode         : CImg interpolation type used for that resize (1 for nearest neighbor, 5 for bicubic).
// Images with fewer than three channels (grayscale, grayscale + alpha) are expanded to color by
// replicating channel 0; for images with three or more channels only the first three are read.
void load_photo(const char* filename, std::vector<double> &result, int &W, int &H, int destW = 0, int destH = 0, int mode = 1) {

	std::string fi(filename);

	// All CImg file access goes through an (unnamed) OpenMP critical section, shared with save_photo.
#pragma omp critical
	{
		cimg_library::CImg<unsigned char> cimg(fi.c_str());

		if (destW != 0 && destH != 0) {
			cimg.resize(destW, destH, -100, -100, mode);
		}

		// Read the size back from the image so W and H always describe the buffer that is converted below.
		W = cimg.width();
		H = cimg.height();

		const size_t npix = static_cast<size_t>(W) * static_cast<size_t>(H);
		result.resize(npix * 3, 0.);

		// CImg stores channels as consecutive planes of npix samples each. With fewer than three
		// planes there is no green/blue plane to read, so channel 0 is used for all three outputs.
		const bool gray = cimg.spectrum() < 3;
		const unsigned char* red = cimg.data();
		const unsigned char* green = gray ? red : red + npix;
		const unsigned char* blue = gray ? red : red + 2 * npix;

		for (size_t p = 0; p < npix; p++) {
			result[p * 3] = red[p];
			result[p * 3 + 1] = green[p];
			result[p * 3 + 2] = blue[p];
		}
	}
}


void save_photo(const char* filename, const std::vector<double> &val, int W, int H) {

	std::vector<unsigned char> deinterleaved(W*H * 3);
	for (int i = 0; i < W*H; i++) {
		deinterleaved[i] = std::min(255., std::max(0., val[i * 3]));
		deinterleaved[i + W*H] = std::min(255., std::max(0., val[i * 3 + 1]));
		deinterleaved[i + 2 * W*H] = std::min(255., std::max(0., val[i * 3 + 2]));
	}

#pragma omp critical
	{
		cimg_library::CImg<unsigned char> cimg(&deinterleaved[0], W, H, 1, 3);
		cimg.save(filename);
	}

}

// In many cases the albedo and shading are scaled before being written to 8-bit files, so that
// R*S = cte*I. This recovers that constant and reapplies it to the shading layer (in place).
//   imgAlbedo, imginput, imgshad : interleaved buffers of W*H*3 samples each, on a 0..255 scale.
//   imgshad is multiplied by sum(input) / sum(albedo*shading/255), both sums taken over the
//   samples whose albedo and input values are not (nearly) saturated.
// The shading is left untouched when the ratio is within 0.1 of 1 (the decomposition is assumed
// exact) or when no usable sample exists (the ratio would be undefined).
void adjustShading(double* imgAlbedo, double* imginput, double* imgshad, int W, int H) {
	if (W <= 0 || H <= 0) return;

	const size_t count = static_cast<size_t>(W) * static_cast<size_t>(H) * 3;
	double sumprod = 0;
	double suminput = 0;

	for (size_t i = 0; i < count; i++) {
		// don't use fully saturated values as they are not reliable
		const bool albedo_ok = imgAlbedo[i] > 3 && imgAlbedo[i] < 252;
		const bool input_ok = imginput[i] > 3 && imginput[i] < 252;
		if (albedo_ok && input_ok) {
			sumprod += imgAlbedo[i] * imgshad[i] / 255.;
			suminput += imginput[i];
		}
	}

	// No reliable sample (or an all-black product): dividing would give inf/NaN and wipe out the shading.
	if (!(sumprod > 0.)) return;

	const double ratio = suminput / sumprod;
	// std::fabs keeps the comparison in floating point; an integer abs() would truncate the difference.
	if (std::fabs(ratio - 1.0) < 0.1) return; // too close to 1: assume the original decomposition was exact.

	for (size_t i = 0; i < count; i++) {
		imgshad[i] *= ratio;
	}
}


// Fills the masked region of an interleaved 3-channel image by diffusion.
//   albedo : source image (W*H*3); copied verbatim wherever the mask is zero.
//   mask   : W*H*3 buffer; only the first channel of each pixel is tested, non-zero = pixel to fill.
//   result : W*H*3 buffer; on entry the initial guess, on return the filled image.
// Runs 1000 sweeps in which every masked pixel becomes the average of its four neighbours taken
// from the previous sweep (two buffers are alternated). Neighbours that fall outside the image
// are mirrored back inside; along an axis of length 1 the pixel itself is used.
void poisson_albedo(double* albedo, double* mask, int W, int H, double* result) {

	if (W <= 0 || H <= 0) return; // nothing to solve, and the scratch buffer below must not be empty

	const size_t count = static_cast<size_t>(W) * static_cast<size_t>(H) * 3;
	std::vector<double> tmp(result, result + count);
	double* ping = &tmp[0]; // previous sweep (read)
	double* pong = result;  // current sweep (written)

	for (int iter = 0; iter < 1000; iter++) {

		for (int i = 0; i < H; i++) {
			const int up = (i > 0) ? i - 1 : ((H > 1) ? i + 1 : i);
			const int down = (i < H - 1) ? i + 1 : ((H > 1) ? i - 1 : i);

			for (int j = 0; j < W; j++) {
				const int here = (i*W + j) * 3;

				if (mask[here] != 0) {
					const int left = (j > 0) ? j - 1 : ((W > 1) ? j + 1 : j);
					const int right = (j < W - 1) ? j + 1 : ((W > 1) ? j - 1 : j);

					for (int k = 0; k < 3; k++) {
						const double s = ping[(i*W + right) * 3 + k] + ping[(i*W + left) * 3 + k] + ping[(down*W + j) * 3 + k] + ping[(up*W + j) * 3 + k];
						pong[here + k] = s / 4.;
					}
				}
				else {
					pong[here + 0] = albedo[here + 0];
					pong[here + 1] = albedo[here + 1];
					pong[here + 2] = albedo[here + 2];
				}
			}

		}
		std::swap(ping, pong);
	}

	// After the final swap `ping` is the buffer written by the last sweep. If that is the
	// scratch buffer, copy it back so the caller always receives the last sweep.
	if (ping != result) std::copy(ping, ping + count, result);
}

// Coarse-to-fine driver for poisson_albedo.
//   albedo, mask, result : interleaved W*H*3 buffers with the same meaning as in poisson_albedo;
//   `result` holds the initial guess on entry and the filled image on return.
// For level = 4 down to 0 the three buffers are resampled to (W >> level) x (H >> level), the
// solver is run at that size, and the solution is resampled back to W x H to seed the next level.
void poisson_albedo_multiscale(double* albedo, double* mask, int W, int H, double* result) {

	const int nlevels = 4;
	const size_t count = static_cast<size_t>(W) * static_cast<size_t>(H) * 3;

	for (int level = nlevels; level >= 0; level--) {
		const int Wdst = W >> level;
		const int Hdst = H >> level;

		// For small images the coarse levels shrink to 0 or 1 pixel along an axis, which is not a
		// usable grid for the solver; skip them. The full-resolution level always runs.
		if (level > 0 && (Wdst < 2 || Hdst < 2)) continue;

		// The interleaved buffers are wrapped as (3, W, H) volumes: x is the channel, y the column,
		// z the row. Resizing y and z therefore rescales the picture and keeps the 3 channels.
		// Non-shared copies are used because shared images pose problem for resize.
		cimg_library::CImg<double> res_down(result, 3, W, H, 1, false);
		cimg_library::CImg<double> albedo_down(albedo, 3, W, H, 1, false);
		cimg_library::CImg<double> mask_down(mask, 3, W, H, 1, false);

		res_down.resize(3, Wdst, Hdst, 1, 3);  // 3: linear ; 2:moving average ; 5: bicubic
		albedo_down.resize(3, Wdst, Hdst, 1, 3);
		mask_down.resize(3, Wdst, Hdst, 1, 1); // 1: nearest neighbor

		poisson_albedo(albedo_down.data(), mask_down.data(), Wdst, Hdst, res_down.data());

		res_down.resize(3, W, H, 1, 3);
		memcpy(result, res_down.data(), count * sizeof(double));
	}
}


// Blends `new_albedo` into the masked region of an interleaved 3-channel image (gradient-domain paste).
//   albedo     : original image (W*H*3); copied verbatim at interior pixels where the mask is zero.
//   new_albedo : image whose discrete Laplacian is imposed inside the mask.
//   mask       : W*H*3 buffer; only the first channel of each pixel is tested, non-zero = replace.
//   result     : W*H*3 buffer; initial guess on entry, blended image on return.
// Runs 1000 sweeps over the interior pixels only, each reading the previous sweep (two buffers
// are alternated). The one-pixel border of `result` is never written and serves as boundary values.
void poisson_replace_albedo(double* albedo, double* new_albedo, double* mask, int W, int H, double* result) {

	// Without at least one interior pixel there is nothing to update.
	if (W < 3 || H < 3) return;

	const size_t count = static_cast<size_t>(W) * static_cast<size_t>(H) * 3;
	std::vector<double> tmp(result, result + count);
	double* ping = &tmp[0]; // previous sweep (read)
	double* pong = result;  // current sweep (written)

	for (int iter = 0; iter < 1000; iter++) {

		for (int i = 1; i < H-1; i++) {
			for (int j = 1; j < W-1; j++) {
				const int here = (i*W + j) * 3;
				const int up = here - W * 3;
				const int down = here + W * 3;
				const int left = here - 3;
				const int right = here + 3;

				if (mask[here] != 0) {
					for (int k = 0; k < 3; k++) {
						const double s = ping[right + k] + ping[left + k] + ping[down + k] + ping[up + k];
						// Laplacian of the replacement image at this pixel.
						// NOTE(review): the previous code chose each neighbour through a mask test whose two
						// branches were identical (always new_albedo); it may have been meant to fall back to
						// `albedo` outside the mask - confirm. The computed value is unchanged here.
						const double rhs = new_albedo[up + k] + new_albedo[down + k] + new_albedo[left + k] + new_albedo[right + k] - 4 * new_albedo[here + k];
						pong[here + k] = (s - rhs) / 4.;
					}
				}
				else {
					pong[here + 0] = albedo[here + 0];
					pong[here + 1] = albedo[here + 1];
					pong[here + 2] = albedo[here + 2];
				}
			}

		}
		std::swap(ping, pong);

	}

	// After the final swap `ping` is the buffer written by the last sweep. If that is the
	// scratch buffer, copy it back so the caller always receives the last sweep.
	if (ping != result) std::copy(ping, ping + count, result);
}
// Coarse-to-fine driver for poisson_replace_albedo.
//   albedo, new_albedo, mask, result : interleaved W*H*3 buffers with the same meaning as in
//   poisson_replace_albedo; `result` holds the initial guess on entry and the blend on return.
// For level = 4 down to 0 the buffers are resampled to (W >> level) x (H >> level), the solver is
// run at that size, and the solution is resampled back to W x H to seed the next level.
void poisson_replace_multiscale(double* albedo, double* new_albedo, double* mask, int W, int H, double* result) {

	const int nlevels = 4;
	const size_t count = static_cast<size_t>(W) * static_cast<size_t>(H) * 3;

	for (int level = nlevels; level >= 0; level--) {
		const int Wdst = W >> level;
		const int Hdst = H >> level;

		// The solver only updates interior pixels, so a level smaller than 3x3 cannot improve the
		// estimate (and a zero-sized level is not a valid image); skip it. Level 0 always runs.
		if (level > 0 && (Wdst < 3 || Hdst < 3)) continue;

		// The interleaved buffers are wrapped as (3, W, H) volumes: x is the channel, y the column,
		// z the row. Resizing y and z therefore rescales the picture and keeps the 3 channels.
		// Non-shared copies are used because shared images pose problem for resize.
		cimg_library::CImg<double> res_down(result, 3, W, H, 1, false);
		cimg_library::CImg<double> albedo_down(albedo, 3, W, H, 1, false);
		cimg_library::CImg<double> new_albedo_down(new_albedo, 3, W, H, 1, false);
		cimg_library::CImg<double> mask_down(mask, 3, W, H, 1, false);

		res_down.resize(3, Wdst, Hdst, 1, 3);  // 3: linear ; 2:moving average ; 5: bicubic
		albedo_down.resize(3, Wdst, Hdst, 1, 3);
		new_albedo_down.resize(3, Wdst, Hdst, 1, 3);
		mask_down.resize(3, Wdst, Hdst, 1, 1); // 1: nearest neighbor

		poisson_replace_albedo(albedo_down.data(), new_albedo_down.data(), mask_down.data(), Wdst, Hdst, res_down.data());

		res_down.resize(3, W, H, 1, 3);
		memcpy(result, res_down.data(), count * sizeof(double));
	}
}


// Removes a texture from the albedo layer inside a mask while keeping the shading.
//   input      : original photograph (used only to recover the albedo*shading scale factor).
//   albedo     : albedo layer; its size defines the working resolution.
//   mask       : non-zero where the albedo must be filled in from its surroundings.
//   shading    : shading layer.
//   outProduct : output file receiving filled_albedo * shading / 255.
void remove_texture(const char* input, const char* albedo, const char* mask, const char* shading, const char* outProduct) {
	std::vector<double> imgAlbedo, imgshad, imgmask, imginput;
	int W, H, tmpW, tmpH;

	// Every other image is resampled to the albedo's size so that all buffers hold W*H*3 samples;
	// a shading file of a different size would otherwise make the loops below run out of bounds.
	load_photo(albedo, imgAlbedo, W, H);
	load_photo(shading, imgshad, tmpW, tmpH, W, H, 5);
	load_photo(mask, imgmask, tmpW, tmpH, W, H, 1);
	load_photo(input, imginput, tmpW, tmpH, W, H, 5);

	adjustShading(&imgAlbedo[0], &imginput[0], &imgshad[0], W, H);

	// The unmodified albedo is the initial guess of the diffusion.
	std::vector<double> result = imgAlbedo;
	poisson_albedo_multiscale(&imgAlbedo[0], &imgmask[0], W, H, &result[0]);

	std::vector<double> product(result.size());
	for (size_t i = 0; i < product.size(); i++) {
		product[i] = result[i] * imgshad[i]/255.;
	}
	save_photo(outProduct, product, W, H);
}



// Fills in the shading layer inside a mask (e.g. to remove a shadow) while keeping the albedo.
//   input      : original photograph (used only to recover the albedo*shading scale factor).
//   albedo     : albedo layer; its size defines the working resolution.
//   mask       : non-zero where the shading must be filled in from its surroundings.
//   shading    : shading layer.
//   outProduct : output file receiving albedo * filled_shading / 255.
void process_shading(const char* input, const char* albedo, const char* mask, const char* shading, const char* outProduct) {
	std::vector<double> imgAlbedo, imgshad, imgmask, imginput;
	int W, H, tmpW, tmpH;

	// Every other image is resampled to the albedo's size so that all buffers hold W*H*3 samples;
	// a shading file of a different size would otherwise make the loops below run out of bounds.
	load_photo(albedo, imgAlbedo, W, H);
	load_photo(shading, imgshad, tmpW, tmpH, W, H, 5);
	load_photo(mask, imgmask, tmpW, tmpH, W, H, 1);
	load_photo(input, imginput, tmpW, tmpH, W, H, 5);

	adjustShading(&imgAlbedo[0], &imginput[0], &imgshad[0], W, H);

	// The solver is layer-agnostic: here it diffuses the shading instead of the albedo.
	std::vector<double> result = imgshad;
	poisson_albedo_multiscale(&imgshad[0], &imgmask[0], W, H, &result[0]);

	std::vector<double> product(result.size());
	for (size_t i = 0; i < product.size(); i++) {
		product[i] = result[i] * imgAlbedo[i] / 255.;
	}
	save_photo(outProduct, product, W, H);
}


// Fills in both layers: the shading inside `mask` and the albedo inside `mask2`.
//   input      : original photograph (used only to recover the albedo*shading scale factor).
//   albedo     : albedo layer; its size defines the working resolution.
//   mask       : non-zero where the shading must be filled in.
//   mask2      : non-zero where the albedo must be filled in.
//   shading    : shading layer.
//   outProduct : output file receiving filled_shading * filled_albedo / 255.
void process_both(const char* input, const char* albedo, const char* mask, const char* mask2, const char* shading, const char* outProduct) {
	std::vector<double> imgAlbedo, imgshad, imgmask, imgmask2, imginput;
	int W, H, tmpW, tmpH;

	// Every other image is resampled to the albedo's size so that all buffers hold W*H*3 samples;
	// a shading file of a different size would otherwise make the loops below run out of bounds.
	load_photo(albedo, imgAlbedo, W, H);
	load_photo(shading, imgshad, tmpW, tmpH, W, H, 5);
	load_photo(mask, imgmask, tmpW, tmpH, W, H, 1);
	load_photo(mask2, imgmask2, tmpW, tmpH, W, H, 1);
	load_photo(input, imginput, tmpW, tmpH, W, H, 5);

	adjustShading(&imgAlbedo[0], &imginput[0], &imgshad[0], W, H);

	std::vector<double> result = imgshad;
	poisson_albedo_multiscale(&imgshad[0], &imgmask[0], W, H, &result[0]);
	std::vector<double> result2 = imgAlbedo;
	poisson_albedo_multiscale(&imgAlbedo[0], &imgmask2[0], W, H, &result2[0]);

	std::vector<double> product(result.size());
	for (size_t i = 0; i < product.size(); i++) {
		product[i] = result[i] * result2[i] / 255.;
	}
	save_photo(outProduct, product, W, H);
}

// Replaces the albedo inside a mask with another image while keeping the shading.
//   input      : original photograph (scale recovery, and source of the occluding pixels).
//   albedo     : albedo layer; its size defines the working resolution.
//   new_albedo : replacement reflectance blended into the masked region.
//   mask       : non-zero where the albedo is replaced.
//   occlusions : optional file; where it is bright the original input is put back over the result
//                (weight = value / 255). Ignored when the file cannot be opened.
//   shading    : shading layer.
//   outProduct : output file.
void replace_albedo(const char* input, const char* albedo, const char* new_albedo, const char* mask, const char* occlusions, const char* shading, const char* outProduct) {
	std::vector<double> imgAlbedo, imgshad, imgmask, imginput, imgnewalbedo, imgocc;
	int W, H, tmpW, tmpH;

	// Every other image is resampled to the albedo's size so that all buffers hold W*H*3 samples;
	// a shading file of a different size would otherwise make the loops below run out of bounds.
	load_photo(albedo, imgAlbedo, W, H);
	load_photo(shading, imgshad, tmpW, tmpH, W, H, 5);
	load_photo(mask, imgmask, tmpW, tmpH, W, H, 1);
	load_photo(input, imginput, tmpW, tmpH, W, H, 5);
	load_photo(new_albedo, imgnewalbedo, tmpW, tmpH, W, H, 5);

	adjustShading(&imgAlbedo[0], &imginput[0], &imgshad[0], W, H); // the decomposition is often such that there is a scaling constant

	std::vector<double> result = imgAlbedo;
	poisson_replace_multiscale(&imgAlbedo[0], &imgnewalbedo[0], &imgmask[0], W, H, &result[0]);

	std::vector<double> product(result.size());
	for (size_t i = 0; i < product.size(); i++) {
		product[i] = result[i] * imgshad[i] / 255.;
	}
	if (file_exists(occlusions)) {
		load_photo(occlusions, imgocc, tmpW, tmpH, W, H, 1);
		for (size_t i = 0; i < product.size(); i++) {
			const double alpha = imgocc[i] / 255.;
			product[i] = product[i] * (1 - alpha) + alpha * imginput[i];
		}
	}

	save_photo(outProduct, product, W, H);
}

// Folds any integer coordinate back into [0, n) by mirroring around the first and last samples
// (the end samples are not repeated). Unlike a single reflection it stays in range for any offset.
static inline int blur_mirror_index(int idx, int n) {
	if (n <= 1) return 0;
	const int period = 2 * n - 2;
	int m = std::abs(idx) % period;
	if (m >= n) m = period - m;
	return m;
}

// Separable Gaussian blur of an interleaved image.
//   img_in  : W*H*nbchans samples, channels interleaved per pixel.
//   img_out : receives the blurred image (same layout); may be the same buffer as img_in.
//   sigma   : standard deviation in pixels; the kernel spans (int)(2.5*sigma) samples on each side.
// The kernel is normalised to sum to 1 so that flat regions keep their value. When the kernel
// radius is 0 (sigma below 0.4, zero, negative or NaN) the image is copied unchanged.
// Samples outside the image are taken by mirroring.
template<typename T>
void blur(const T* img_in, T* img_out, int W, int H, int nbchans, T sigma) {

	if (W <= 0 || H <= 0 || nbchans <= 0) return;
	const size_t total = static_cast<size_t>(W) * static_cast<size_t>(H) * static_cast<size_t>(nbchans);

	const int ker = (sigma > 0) ? static_cast<int>(sigma*2.5) : 0;
	if (ker <= 0) {
		// A one-tap normalised kernel is the identity.
		if (img_out != img_in) std::copy(img_in, img_in + total, img_out);
		return;
	}

	// Truncated Gaussian weights, normalised by their sum (the analytic 1/(sigma*sqrt(2*pi)) factor
	// does not account for the truncation and would darken the image).
	const int taps = ker * 2 + 1;
	std::vector<double> gauss1d(taps);
	double wsum = 0;
	for (int l = 0; l < taps; l++) {
		const double d = l - ker;
		gauss1d[l] = std::exp(-d*d / (2. * sigma*sigma));
		wsum += gauss1d[l];
	}
	for (int l = 0; l < taps; l++) {
		gauss1d[l] /= wsum;
	}

	// Vertical pass: img_in -> result. Rows are independent, so the row loop is the parallel one.
	std::vector<T> result(total);
#pragma omp parallel for schedule(static)
	for (int i = 0; i<H; i++) {
		for (int j = 0; j<W; j++) {
			for (int k = 0; k<nbchans; k++) {
				double s = 0;
				for (int l = 0; l<taps; l++) {
					const int id = blur_mirror_index(i + l - ker, H);
					s += img_in[(id*W + j)*nbchans + k] * gauss1d[l];
				}
				result[(i*W + j)*nbchans + k] = static_cast<T>(s);
			}
		}
	}

	// Horizontal pass: result -> img_out.
#pragma omp parallel for schedule(static)
	for (int i = 0; i<H; i++) {
		for (int j = 0; j<W; j++) {
			for (int k = 0; k<nbchans; k++) {
				double s = 0;
				for (int l = 0; l<taps; l++) {
					const int id = blur_mirror_index(j + l - ker, W);
					s += result[(i*W + id)*nbchans + k] * gauss1d[l];
				}
				img_out[(i*W + j)*nbchans + k] = static_cast<T>(s);
			}
		}
	}
}


// Blurs the shading layer inside a (feathered) mask, e.g. to remove wrinkles, and recomposes the image.
//   input      : original photograph (used only to recover the albedo*shading scale factor).
//   albedo     : albedo layer; its size defines the working resolution.
//   mask       : bright where the blurred shading should replace the original shading.
//   shading    : shading layer.
//   outProduct : output file receiving albedo * blended_shading / 255.
//   blursigma  : Gaussian sigma, expressed for a 1280-pixel-wide image and scaled with W.
void blur_shading(const char* input, const char* albedo, const char* mask, const char* shading, const char* outProduct, double blursigma) {
	std::vector<double> imgAlbedo, imgshad, imgmask, imginput;
	int W, H, tmpW, tmpH;

	// Every other image is resampled to the albedo's size so that all buffers hold W*H*3 samples;
	// a shading file of a different size would otherwise make the loops below run out of bounds.
	load_photo(albedo, imgAlbedo, W, H);
	load_photo(shading, imgshad, tmpW, tmpH, W, H, 5);
	load_photo(mask, imgmask, tmpW, tmpH, W, H, 1);
	load_photo(input, imginput, tmpW, tmpH, W, H, 5);

	adjustShading(&imgAlbedo[0], &imginput[0], &imgshad[0], W, H);

	const double sigma = blursigma*W / 1280.; // makes the sigma resolution dependent
	std::vector<double> result(imgshad.size());
	std::vector<double> resultmask(imgshad.size());
	blur<double>(&imgshad[0], &result[0], W, H, 3, sigma);
	blur<double>(&imgmask[0], &resultmask[0], W, H, 3, sigma); // blurring the mask feathers the transition

	// Blend: blurred shading where the feathered mask is bright, original shading elsewhere.
	for (size_t i = 0; i < result.size(); i++) {
		const double alpha = resultmask[i] / 255.;
		result[i] = imgshad[i] * (1 - alpha) + alpha * result[i];
	}

	std::vector<double> product(result.size());
	for (size_t i = 0; i < product.size(); i++) {
		product[i] = result[i] * imgAlbedo[i] / 255.;
	}
	save_photo(outProduct, product, W, H);
}

// Runs the four example edits on the bundled images; every file name is relative to the working directory.
int main()
{
	// Remove a logo on a shirt, or a texture in general, while keeping the shading.
	remove_texture(
		"3912994232_dac125a0d3_o.jpg",
		"3912994232_dac125a0d3_o_albedo.png",
		"3912994232_dac125a0d3_o_mask.png",
		"3912994232_dac125a0d3_o_shading.jpg",
		"3912994232_dac125a0d3_o_Application.jpg");

	// Blur the shading layer to remove wrinkles.
	blur_shading(
		"7158418_2215ef4a6e_o.jpg",
		"7158418_2215ef4a6e_o_albedo.png",
		"7158418_2215ef4a6e_o_mask.png",
		"7158418_2215ef4a6e_o_shading.png",
		"7158418_2215ef4a6e_o_Application.jpg",
		3.);

	// Replace the reflectance component with another, with respect to the given mask.
	replace_albedo(
		"5540131495_059e17f4f0_o.jpg",
		"5540131495_059e17f4f0_o_albedo.png",
		"5540131495_059e17f4f0_o_replace.png",
		"5540131495_059e17f4f0_o_mask.png",
		"5540131495_059e17f4f0_o_occlusion.png",
		"5540131495_059e17f4f0_o_shading.png",
		"5540131495_059e17f4f0_o_Application.jpg");

	// Remove a shadow *and* a texture in an image (one mask for each; the first one is for the
	// shadow to be removed, the second for the texture).
	process_both(
		"7390090600_033e3c9c93_o.jpg",
		"7390090600_033e3c9c93_o_albedo.png",
		"7390090600_033e3c9c93_o_mask.png",
		"7390090600_033e3c9c93_o_mask2.png",
		"7390090600_033e3c9c93_o_shading.png",
		"7390090600_033e3c9c93_o_Application.jpg");

	return 0;
}