diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/ToolUtilities.h"
 #include "mlir/Tools/mlir-translate/Translation.h"
 #include "llvm/Support/InitLLVM.h"
@@ -56,9 +57,9 @@
   llvm::InitLLVM y(argc, argv);
 
   // Add flags for all the registered translations.
-  llvm::cl::opt<const Translation *, false, TranslationParser>
-      translationRequested("", llvm::cl::desc("Translation to perform"),
-                           llvm::cl::Required);
+  llvm::cl::list<const Translation *, bool, TranslationParser>
+      translationsRequested("", llvm::cl::desc("Translations to perform"),
+                            llvm::cl::Required);
   registerAsmPrinterCLOptions();
   registerMLIRContextCLOptions();
   registerTranslationCLOptions();
@@ -66,7 +67,7 @@
 
   std::string errorMessage;
   std::unique_ptr<llvm::MemoryBuffer> input;
-  if (auto inputAlignment = translationRequested->getInputAlignment())
+  if (auto inputAlignment = translationsRequested[0]->getInputAlignment())
     input = openInputFile(inputFilename, *inputAlignment, &errorMessage);
   else
     input = openInputFile(inputFilename, &errorMessage);
@@ -84,23 +85,55 @@
   // Processes the memory buffer with a new MLIRContext.
   auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
                            raw_ostream &os) {
-    MLIRContext context;
-    context.allowUnregisteredDialects(allowUnregisteredDialects);
-    context.printOpOnDiagnostic(!verifyDiagnostics);
-    auto sourceMgr = std::make_shared<llvm::SourceMgr>();
-    sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
-
-    if (!verifyDiagnostics) {
-      SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
-      return (*translationRequested)(sourceMgr, os, &context);
+    // Temporary buffers for chained translation processing.
+    std::string dataIn;
+    std::string dataOut;
+    LogicalResult result = LogicalResult::success();
+
+    for (size_t i = 0, e = translationsRequested.size(); i < e; ++i) {
+      llvm::raw_ostream *stream;
+      llvm::raw_string_ostream dataStream(dataOut);
+
+      if (i == e - 1) {
+        // Output last translation to output.
+        stream = &os;
+      } else {
+        // Output translation to temporary data buffer.
+        stream = &dataStream;
+      }
+
+      const Translation *translationRequested = translationsRequested[i];
+      MLIRContext context;
+      context.allowUnregisteredDialects(allowUnregisteredDialects);
+      context.printOpOnDiagnostic(!verifyDiagnostics);
+      auto sourceMgr = std::make_shared<llvm::SourceMgr>();
+      sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
+
+      if (verifyDiagnostics) {
+        // In the diagnostic verification flow, we ignore whether the
+        // translation failed (in most cases, it is expected to fail).
+        // Instead, we check if the diagnostics were produced as expected.
+        // Output still goes to `*stream` so chained translations get input.
+        SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr,
+                                                            &context);
+        (void)(*translationRequested)(sourceMgr, *stream, &context);
+        result = sourceMgrHandler.verify();
+      } else {
+        SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
+        result = (*translationRequested)(sourceMgr, *stream, &context);
+      }
+      if (failed(result))
+        return result;
+
+      if (i < e - 1) {
+        // If there are further translations, create a new buffer with the
+        // output data.
+        dataIn.swap(dataOut);
+        dataOut.clear();
+        ownedBuffer = llvm::MemoryBuffer::getMemBuffer(dataIn);
+      }
     }
-
-    // In the diagnostic verification flow, we ignore whether the translation
-    // failed (in most cases, it is expected to fail). Instead, we check if the
-    // diagnostics were produced as expected.
-    SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
-    (void)(*translationRequested)(sourceMgr, os, &context);
-    return sourceMgrHandler.verify();
+    return result;
   };
 
   if (failed(splitAndProcessBuffer(std::move(input), processBuffer,
diff --git a/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir b/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-translate -no-implicit-module -split-input-file -serialize-spirv -deserialize-spirv %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce {
+  spirv.func @array_stride(%arg0 : !spirv.ptr, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
+    // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr, stride=128>, StorageBuffer>, i32, i32
+    %2 = spirv.AccessChain %arg0[%arg1, %arg2] : !spirv.ptr, stride=128>, StorageBuffer>, i32, i32
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce {
+  // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr, StorageBuffer>
+  spirv.GlobalVariable @var0 : !spirv.ptr, StorageBuffer>
+  // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr>, Input>
+  spirv.GlobalVariable @var1 : !spirv.ptr>, Input>
+}