HLSL Shader Build Task

This solution is a build task that adds the ability to compile HLSL files referenced in MSBuild projects. It integrates with Visual Studio to make compiling shaders seamless.

C++ (8.3 MB)
 
 
 
 
 
3.7 Star
(3)
5,593 times
Add to favorites
6/25/2011
E-mail Twitter del.icio.us Digg Facebook

Solution explorer

C++
// This is the main DLL file.

#include "stdafx.h"
#include <atlstr.h>
#include <vector>

#include "ShaderBuildTask.h"
#include <msclr/marshal.h>

using namespace System;
using namespace System::IO;
using namespace System::Xml;
using namespace System::Runtime::InteropServices;
using namespace System::Text::RegularExpressions;
using namespace msclr::interop; 


namespace ShaderBuildTask {

typedef HRESULT (WINAPI *ShaderCompilerType)
   (
		LPCWSTR							pSrcFile,
        CONST D3DXMACRO*                pDefines,
        LPD3DXINCLUDE                   pInclude,
        LPCSTR                          pFunctionName,
        LPCSTR                          pProfile,
        DWORD                           Flags,
        LPD3DXBUFFER*                   ppShader,
        LPD3DXBUFFER*                   ppErrorMsgs,
        LPD3DXCONSTANTTABLE*            ppConstantTable);

// Returns the file name of the newest D3DX9 runtime DLL ("d3dx9_NN.dll", probing NN
// from 60 down to 36) that exports D3DXCompileShaderFromFileW.
// The probe runs only once; afterwards the cached _dxLibraryToUse is returned.
// If no candidate qualifies, _dxLibraryToUse is left unset (presumably nullptr; it is
// declared in the header) and callers fall back to the statically linked compiler.
String^ ShaderCompile::GetDxLibraryToUse()
{
	if (!_calculatedDxLibraryToUse)
	{
		const int maxIndex = 60;
		const int minIndex = 36;

		// DirectX SDKs install new D3DX libraries by appending a version number.  Start from a
		// high number and count down until we find a library that exports the function we need;
		// the first hit is therefore the most recently installed SDK.  If no SDK is installed,
		// _dxLibraryToUse remains unset and the statically linked compiler is used.
		for (int index = maxIndex; index >= minIndex; index--)
		{
			String^ libName = "d3dx9_" + index.ToString() + ".dll";
			CString libNameAsCString(libName);

			HMODULE dxLibrary = ::LoadLibrary(static_cast<LPCWSTR>(libNameAsCString));
			if (dxLibrary == NULL)
				continue;

			// Only probing here; the library is loaded again by the caller when it is used.
			bool exportsCompiler = ::GetProcAddress(dxLibrary, "D3DXCompileShaderFromFileW") != NULL;
			::FreeLibrary(dxLibrary);

			if (exportsCompiler)
			{
				_dxLibraryToUse = libName;
				break;
			}
		}

		_calculatedDxLibraryToUse = true;
	}

	return _dxLibraryToUse;
}

bool ShaderCompile::Execute()
{
	marshal_context^ context = gcnew marshal_context();
	this->_outputs = gcnew List<ITaskItem ^>();

	// Continue through all sources, even if some fail.  Keep track if we failed.
	bool anyFailed = false;

	for (int i = 0; i < this->Sources->Length; i++)
	{
		ITaskItem^ ti = this->Sources[i];
		
		if (!File::Exists(ti->ItemSpec))
		{
			Log->LogError("ResourceNotFound {0}:", ti->ItemSpec);
			anyFailed = true;
		}
		else
		{			
			LPCWSTR shaderFile = context->marshal_as<LPCWSTR>(ti->ItemSpec);
			LPD3DXBUFFER compiledShader;
			LPD3DXBUFFER errorMessages;
			LPD3DXCONSTANTTABLE constantTable;

			ShaderCompilerType shaderCompiler = ::D3DXCompileShaderFromFile;

			// Try to get the latest if the DX SDK is installed.  Otherwise, back up to the statically linked version.
			String^ libraryToLoad = GetDxLibraryToUse();
			CString libraryToLoadAsCString(libraryToLoad);

			HMODULE dxLibrary = ::LoadLibrary((LPCWSTR)libraryToLoadAsCString); 
			bool gotDynamicOne = false;
			if (dxLibrary != NULL)
			{
				FARPROC sc = ::GetProcAddress(dxLibrary, "D3DXCompileShaderFromFileW");
				shaderCompiler = (ShaderCompilerType)sc;
				gotDynamicOne = true;
			}

			LPCSTR entryPoint = context->marshal_as<LPCSTR>(EntryPoint);
			LPCSTR shaderProfile = context->marshal_as<LPCSTR>(ShaderProfile);

			// initialize flags
			DWORD compilerFlags = 0;

			if (PackMatrixRowMajor)
				compilerFlags |= D3DXSHADER_PACKMATRIX_ROWMAJOR;
			else
				compilerFlags |= D3DXSHADER_PACKMATRIX_COLUMNMAJOR;

			if (Debug)
				compilerFlags |= D3DXSHADER_DEBUG;

			// Not supported on original DX9 compiler
			switch (OptimizationLevel)
			{
			case 0:
				compilerFlags |= D3DXSHADER_OPTIMIZATION_LEVEL0;
				break;
			case 1:
				compilerFlags |= D3DXSHADER_OPTIMIZATION_LEVEL1;
				break;
			case 2:
				compilerFlags |= D3DXSHADER_OPTIMIZATION_LEVEL2;
				break;
			case 3:
				compilerFlags |= D3DXSHADER_OPTIMIZATION_LEVEL3;
				break;
			}

			HRESULT compileResult = 
				shaderCompiler(
					shaderFile,
					NULL, // pDefines
					NULL, // pIncludes
					entryPoint, // entrypoint
					shaderProfile, // "ps_2_0", "vs_2_0", etc.
					compilerFlags, // compiler flags
					&compiledShader,
					&errorMessages,
					&constantTable   // constant table output
					);

			if (!SUCCEEDED(compileResult))
			{
				Log->LogMessage("Compile error {0} on file {1}", compileResult, gcnew String(shaderFile));
				
				char *nativeErrorString = NULL;
				if(errorMessages != NULL)
					nativeErrorString = (char *)(errorMessages->GetBufferPointer());

				String^ managedErrorString = context->marshal_as<String^>(nativeErrorString == NULL ? "Unknown compile error (check flags against DX version)" : nativeErrorString);

				// Need to build up our own error information, since error string from the compiler
				// doesn't identify the source file.

				// Pull the error string from the shader compiler apart.
				// Note that the backslashes are escaped, since C++ needs an escaping of them.  
				String^ subcategory = "Shader";
				String^ dir;
				String^ line;
				String^ col;
				String^ descrip;
				String^ file;
				String^ errorCode = "";
				String^ helpKeyword = "";
				int     lineNum = 0;
				int     colNum = 0;
				bool    parsedLineNum = false;

				if (gotDynamicOne)
				{
					String^ regexString = "(?<directory>[^@]+)memory\\((?<line>[^@]+),(?<col>[^@]+)\\): (?<descrip>[^@]+)";
					Regex^ errorRegex = gcnew Regex(regexString);
					Match^ m = errorRegex->Match(managedErrorString);

					dir     = m->Groups["directory"]->Value;
					line    = m->Groups["line"]->Value;
					col     = m->Groups["col"]->Value;
					descrip = m->Groups["descrip"]->Value;
					file    = String::Concat(dir, ti->ItemSpec);

					parsedLineNum = Int32::TryParse(line, lineNum);
					Int32::TryParse(col, colNum);
				}
				else
				{
					// Statically linked d3dx9.lib's error string is a different format, need to parse that.

					// Example string: (16): error X3018: invalid subscript 'U'
					String^ regexString = "\\((?<line>[^@]+)\\): (?<descrip>[^@]+)";
					Regex^ errorRegex = gcnew Regex(regexString);
					Match^ m = errorRegex->Match(managedErrorString);

					line    = m->Groups["line"]->Value;
					descrip = m->Groups["descrip"]->Value;
					file = ti->ItemSpec;

					parsedLineNum = Int32::TryParse(line, lineNum);

					int colNum = 0;  // no column information supplied
				}

				if (!parsedLineNum)
				{
					// Just use the whole string as the description.
					descrip = managedErrorString;
				}
				Log->LogError(subcategory, errorCode, helpKeyword, file, lineNum, colNum, lineNum, colNum, "{0}", descrip);

				anyFailed = true;

			}
			else
			{
				// Derive output filename from input
				String^ outputFileName = GetOutputFileName(ti->ItemSpec);

				// Create the output task item
				TaskItem^ output = gcnew TaskItem(outputFileName);

    			char *nativeBytestream = (char *)(compiledShader->GetBufferPointer());
				array<unsigned char>^ arr = gcnew array<unsigned char>(compiledShader->GetBufferSize());

				// TODO: Really ugly way to copy from a uchar* to a managed array, but I can't easily figure out the
				// "right" way to do it.
				for (unsigned int i = 0; i < compiledShader->GetBufferSize(); i++)
				{
					arr[i] = nativeBytestream[i];
				}

				File::WriteAllBytes(output->ItemSpec, arr);
				
				this->_outputs->Add(output);
				Log->LogMessage("Source: {0} Target: {1}", ti->ItemSpec, output->ItemSpec);

				if (ExportConstants)
				{
					// Create the output task item for the constant table
					output = gcnew TaskItem(String::Concat(outputFileName, ".constants"));
					WriteConstants(output->ItemSpec, constantTable);
					this->_outputs->Add(output);
					Log->LogMessage("Source: {0} Target: {1}", ti->ItemSpec, output->ItemSpec);
				}
			}

			if (dxLibrary != NULL)
			{
				::FreeLibrary(dxLibrary);
			}
		}
	}  

	return !anyFailed;
}

// Maps a source file path to the path its compiled shader is written to, and makes
// sure the directory of that path exists.
//  - When IntermediateOutputPath is set, the output goes under it.  Relative sources
//    keep their relative path; rooted sources, or paths containing "..", are reduced
//    to just their file name.
//  - A trailing ".ps.hlsl" / ".vs.hlsl" double extension is collapsed.  The extension
//    then becomes ".vs" if ShaderProfile contains "vs", otherwise ".ps".
String^ ShaderCompile::GetOutputFileName(String^ inputFileName)
{
	// Derive output filename from input
	String^ outputFileName = inputFileName;

	// Change path if provided
	if (!String::IsNullOrEmpty(IntermediateOutputPath))
	{
		// Use filename if path contains ".." or is rooted, otherwise use relative path
		if (outputFileName->Contains("..") || Path::IsPathRooted(outputFileName))
			outputFileName = Path::Combine(IntermediateOutputPath, Path::GetFileName(outputFileName));
		else
			outputFileName = Path::Combine(IntermediateOutputPath, outputFileName);
	}

	// Handle conversion of .ps.hlsl / .vs.hlsl format by dropping the trailing ".hlsl".
	// ChangeExtension below then replaces the remaining ".ps" / ".vs".  Only the end of
	// the path is inspected, so directory names are never rewritten.
	if (outputFileName->EndsWith(".ps.hlsl", StringComparison::OrdinalIgnoreCase) ||
		outputFileName->EndsWith(".vs.hlsl", StringComparison::OrdinalIgnoreCase))
	{
		outputFileName = outputFileName->Substring(0, outputFileName->Length - 5); // 5 == length of ".hlsl"
	}

	// Change extension to .ps or .vs based on profile
	String^ targetExtension = ".ps";
	if (ShaderProfile != nullptr && ShaderProfile->Contains("vs"))
		targetExtension = ".vs";
	outputFileName = Path::ChangeExtension(outputFileName, targetExtension);

	// Make sure output directory exists.  Path::GetDirectoryName understands both '\\' and '/'
	// (checking only for DirectorySeparatorChar missed paths like "obj/foo.ps").  It returns an
	// empty string for a bare file name, which needs no directory.
	String^ outputDirectory = Path::GetDirectoryName(outputFileName);
	if (!String::IsNullOrEmpty(outputDirectory))
		Directory::CreateDirectory(outputDirectory);

	return outputFileName;
}

void ShaderCompile::WriteConstants(String^ fileName, LPD3DXCONSTANTTABLE constantTable)
{	
	// Get constant table description
	D3DXCONSTANTTABLE_DESC desc;
	HRESULT hr = constantTable->GetDesc(&desc);

	// Create XML writer
	XmlWriterSettings^ settings = gcnew XmlWriterSettings();
	settings->Indent = true;
	XmlWriter^ writer = XmlWriter::Create(fileName, settings);

	// Write root element
	writer->WriteStartElement("ShaderConstants");
	Version^ version = gcnew Version(D3DSHADER_VERSION_MAJOR(desc.Version), D3DSHADER_VERSION_MINOR(desc.Version));
	writer->WriteAttributeString("FileFormatVersion", "1.0");
	writer->WriteAttributeString("Version", version->ToString());
	writer->WriteAttributeString("Constants", (gcnew UInt32(desc.Constants))->ToString());	
	writer->WriteAttributeString("Creator", gcnew String(desc.Creator));

	for(unsigned int c = 0; c < desc.Constants; c++)
	{
		// get constant
		D3DXHANDLE constant = constantTable->GetConstant(NULL, c);
		if(constant == NULL)
		{
			Log->LogMessage("Unable to get constant: {0}", c);
			continue;
		}

		// get constant desc count		
		UINT descCount = 0;
		hr = constantTable->GetConstantDesc(constant, NULL, &descCount);
		if (!SUCCEEDED(hr) || descCount == 0)
		{
			Log->LogMessage("Unable to get description count for constant: {0}", c);
			continue;
		}

		// get constant desc
		D3DXCONSTANT_DESC *constantDesc = new D3DXCONSTANT_DESC[descCount];		
		hr = constantTable->GetConstantDesc(constant, &constantDesc[0], &descCount);
		if (!SUCCEEDED(hr))
		{
			delete [] constantDesc;
			Log->LogMessage("Unable to get description for constant: {0}", c);
			continue;
		}

		writer->WriteStartElement("Constant");
		writer->WriteAttributeString("Index", (gcnew UInt32(c))->ToString());
		writer->WriteAttributeString("Descriptions", (gcnew UInt32(descCount))->ToString());

		for(unsigned int d = 0; d < descCount; d++)
		{
			writer->WriteStartElement("Description");

			writer->WriteElementString("Name", gcnew String(constantDesc[d].Name));
			
			switch(constantDesc[d].RegisterSet)
			{
			case D3DXRS_BOOL:
				writer->WriteElementString("RegisterSet", "Bool");
				break;
			case D3DXRS_INT4:
				writer->WriteElementString("RegisterSet", "Int4");
				break;
			case D3DXRS_FLOAT4:
				writer->WriteElementString("RegisterSet", "Float4");
				break;
			case D3DXRS_SAMPLER:
				writer->WriteElementString("RegisterSet", "Sampler");
				break;
			}
			writer->WriteElementString("RegisterIndex", (gcnew UInt32(constantDesc[d].RegisterIndex))->ToString());
			writer->WriteElementString("RegisterCount", (gcnew UInt32(constantDesc[d].RegisterCount))->ToString());

			writer->WriteElementString("Rows", (gcnew UInt32(constantDesc[d].Rows))->ToString());
			writer->WriteElementString("Columns", (gcnew UInt32(constantDesc[d].Columns))->ToString());
			writer->WriteElementString("Elements", (gcnew UInt32(constantDesc[d].Elements))->ToString());
			writer->WriteElementString("StructMembers", (gcnew UInt32(constantDesc[d].StructMembers))->ToString());

			writer->WriteElementString("Bytes", (gcnew UInt32(constantDesc[d].Bytes))->ToString());

			switch(constantDesc[d].Class)
			{
			case D3DXPC_SCALAR:
				writer->WriteElementString("Class", "Scalar");
				break;
			case D3DXPC_VECTOR:
				writer->WriteElementString("Class", "Vector");
				break;
			case D3DXPC_MATRIX_ROWS:
				writer->WriteElementString("Class", "Rows");
				break;
			case D3DXPC_MATRIX_COLUMNS:
				writer->WriteElementString("Class", "Columns");
				break;
			case D3DXPC_OBJECT:
				writer->WriteElementString("Class", "Object");
				break;
			case D3DXPC_STRUCT:
				writer->WriteElementString("Class", "Struct");
				break;
			}

			switch(constantDesc[d].Type)
			{
			case D3DXPT_VOID:
				writer->WriteElementString("Type", "Void");
				break;
			case D3DXPT_BOOL:
				writer->WriteElementString("Type", "Bool");
				break;
			case D3DXPT_INT:
				writer->WriteElementString("Type", "Int");
				break;
			case D3DXPT_FLOAT:
				writer->WriteElementString("Type", "Float");
				break;
			case D3DXPT_STRING:
				writer->WriteElementString("Type", "String");
				break;
			case D3DXPT_TEXTURE:
				writer->WriteElementString("Type", "Texture");
				break;
			case D3DXPT_TEXTURE1D:
				writer->WriteElementString("Type", "Texture1D");
				break;
			case D3DXPT_TEXTURE2D:
				writer->WriteElementString("Type", "Texture2D");
				break;
			case D3DXPT_TEXTURE3D:
				writer->WriteElementString("Type", "Texture3D");
				break;
			case D3DXPT_TEXTURECUBE:
				writer->WriteElementString("Type", "TextureCube");
				break;
			case D3DXPT_SAMPLER:
				writer->WriteElementString("Type", "Sampler");
				break;
			case D3DXPT_SAMPLER1D:
				writer->WriteElementString("Type", "Sampler1D");
				break;
			case D3DXPT_SAMPLER2D:
				writer->WriteElementString("Type", "Sampler2D");
				break;
			case D3DXPT_SAMPLER3D:
				writer->WriteElementString("Type", "Sampler3D");
				break;
			case D3DXPT_SAMPLERCUBE:
				writer->WriteElementString("Type", "SamplerCube");
				break;
			case D3DXPT_PIXELSHADER:
				writer->WriteElementString("Type", "PixelShader");
				break;
			case D3DXPT_VERTEXSHADER:
				writer->WriteElementString("Type", "VertexShader");
				break;
			case D3DXPT_PIXELFRAGMENT:
				writer->WriteElementString("Type", "PixelFragment");
				break;
			case D3DXPT_VERTEXFRAGMENT:
				writer->WriteElementString("Type", "VertexFragment");
				break;
			case D3DXPT_UNSUPPORTED:
				writer->WriteElementString("Type", "Unsupported");
				break;
			}

			writer->WriteEndElement(); // Description
		}

		writer->WriteEndElement(); // Constant
	}

	writer->WriteEndElement(); // ShaderConstants
	writer->Flush();
	writer->Close();
}

}