diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h --- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h @@ -21,7 +21,7 @@ /// Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V /// storage classes. The mapping is read from the command-line option. -std::unique_ptr> createMapMemRefStorageClassPass(); +std::unique_ptr> createMapMemRefStorageClassPass(); /// Creates a pass to convert MemRef ops to SPIR-V ops. std::unique_ptr> createConvertMemRefToSPIRVPass(); diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -538,7 +538,7 @@ // MemRefToSPIRV //===----------------------------------------------------------------------===// -def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class", "ModuleOp"> { +def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class"> { let summary = "Map numeric MemRef memory spaces to SPIR-V storage classes"; let constructor = "mlir::createMapMemRefStorageClassPass()"; let dependentDialects = ["spirv::SPIRVDialect"]; diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -14,10 +14,10 @@ #include "../PassDetail.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" @@ -86,15 +86,16 @@ Attribute spaceAttr = memRefType.getMemorySpace(); if (spaceAttr && !spaceAttr.isa()) { LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType - << " due to non-IntegerAttr memory space"); + << " due to non-IntegerAttr memory space\n"); return llvm::None; } unsigned space = memRefType.getMemorySpaceAsInt(); auto it = this->memorySpaceMap.find(space); if (it == this->memorySpaceMap.end()) { - LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType - << " due to unable to find memory space in map"); + LLVM_DEBUG(llvm::dbgs() + << "cannot convert " << memRefType + << " due to being unable to find memory space in map\n"); return llvm::None; } @@ -143,10 +144,9 @@ /// Returns true if the given `op` is considered as legal for SPIR-V conversion. static bool isLegalOp(Operation *op) { - if (auto funcOp = dyn_cast(op)) { - FunctionType funcType = funcOp.getFunctionType(); - return llvm::all_of(funcType.getInputs(), isLegalType) && - llvm::all_of(funcType.getResults(), isLegalType); + if (auto funcOp = dyn_cast(op)) { + return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) && + llvm::all_of(funcOp.getResultTypes(), isLegalType); } auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) { @@ -230,7 +230,18 @@ class MapMemRefStorageClassPass final : public MapMemRefStorageClassBase { public: - explicit MapMemRefStorageClassPass() = default; + explicit MapMemRefStorageClassPass() { + memorySpaceMap = spirv::getDefaultVulkanStorageClassMap(); + + LLVM_DEBUG({ + llvm::dbgs() << "memory space to storage class mapping:\n"; + if (memorySpaceMap.empty()) + llvm::dbgs() << " [empty]\n"; + for (auto kv : memorySpaceMap) + llvm::dbgs() << " " << kv.first << " -> " + << spirv::stringifyStorageClass(kv.second) << "\n"; + }); + } explicit MapMemRefStorageClassPass( const spirv::MemorySpaceToStorageClassMap &memorySpaceMap) : memorySpaceMap(memorySpaceMap) {} @@ -251,46 +262,23 @@ if (clientAPI != "vulkan") return failure(); - memorySpaceMap = spirv::getDefaultVulkanStorageClassMap(); - - LLVM_DEBUG({ - llvm::dbgs() << "memory space to storage class mapping:\n"; - if (memorySpaceMap.empty()) - llvm::dbgs() << " [empty]\n"; - for (auto kv : memorySpaceMap) - llvm::dbgs() << " " << kv.first << " -> " - << spirv::stringifyStorageClass(kv.second) << "\n"; - }); - return success(); } void MapMemRefStorageClassPass::runOnOperation() { MLIRContext *context = &getContext(); - ModuleOp module = getOperation(); + Operation *op = getOperation(); auto target = spirv::getMemorySpaceToStorageClassTarget(*context); - spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); - // Use UnrealizedConversionCast as the bridge so that we don't need to pull in - // patterns for other dialects. - auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) { - auto cast = builder.create(loc, type, inputs); - return Optional(cast.getResult(0)); - }; - converter.addSourceMaterialization(addUnrealizedCast); - converter.addTargetMaterialization(addUnrealizedCast); - target->addLegalOp(); RewritePatternSet patterns(context); spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns); - if (failed(applyPartialConversion(module, *target, std::move(patterns)))) + if (failed(applyFullConversion(op, *target, std::move(patterns)))) return signalPassFailure(); } -std::unique_ptr> -mlir::createMapMemRefStorageClassPass() { +std::unique_ptr> mlir::createMapMemRefStorageClassPass() { return std::make_unique(); } diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir @@ -41,7 +41,7 @@ // ----- -// VULKAN-LABEL: func @function_io +// VULKAN-LABEL: func.func @function_io func.func @function_io // VULKAN-SAME: (%{{.+}}: memref>, %{{.+}}: memref<4xi32, #spv.storage_class>) (%arg0: memref, %arg1: memref<4xi32, 3>) @@ -52,7 +52,15 @@ // ----- -// VULKAN: func @region +gpu.module @kernel { +// VULKAN-LABEL: gpu.func @function_io +// VULKAN-SAME: memref<8xi32, #spv.storage_class> +gpu.func @function_io(%arg0 : memref<8xi32>) kernel { gpu.return } +} + +// ----- + +// VULKAN-LABEL: func.func @region func.func @region(%cond: i1, %arg0: memref) { scf.if %cond { // VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref>}