diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRVPASS_H
 #define MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRVPASS_H
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -21,6 +22,15 @@
 /// Creates a pass to convert MemRef ops to SPIR-V ops.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertMemRefToSPIRVPass();
 
+/// Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V
+/// storage classes. The mapping is given as `memorySpaceMap`.
+std::unique_ptr<OperationPass<ModuleOp>> createMapMemRefStorageClassPass(
+    const DenseMap<unsigned, spirv::StorageClass> &memorySpaceMap);
+
+/// Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V
+/// storage classes. The mapping is read from the command-line option.
+std::unique_ptr<OperationPass<ModuleOp>> createMapMemRefStorageClassPass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -529,6 +529,17 @@
 // MemRefToSPIRV
 //===----------------------------------------------------------------------===//
 
+def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class", "ModuleOp"> {
+  let summary = "Map numeric MemRef memory spaces to SPIR-V storage classes";
+  let constructor = "mlir::createMapMemRefStorageClassPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
+  let options = [
+    Option<"mappings", "mappings", "std::string", /*default=*/"",
+           "A comma-separated list of memory space to storage class mappings; "
+           "for example, '0=StorageBuffer,1=Uniform,2=Workgroup'">
+  ];
+}
+
 def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
   let summary = "Convert MemRef dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertMemRefToSPIRVPass()";
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/MemRefToSPIRV/CMakeLists.txt
--- a/mlir/lib/Conversion/MemRefToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToSPIRV/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_conversion_library(MLIRMemRefToSPIRV
+  MapMemRefStorageClassPass.cpp
   MemRefToSPIRV.cpp
   MemRefToSPIRVPass.cpp
 
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -0,0 +1,266 @@
+//===- MapMemRefStorageClassPass.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to map numeric MemRef memory spaces to
+// symbolic ones defined in the SPIR-V specification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-map-memref-storage-class"
+
+using namespace mlir;
+
+using MemorySpaceToStorageClassMap = DenseMap<unsigned, spirv::StorageClass>;
+
+/// Parses the memory space mapping string `memorySpaceMapStr` and writes the
+/// mappings encoded inside to `memorySpaceMap`.
+///
+/// The string is a comma-separated list of `<memory-space>=<storage-class>`
+/// entries, e.g. "0=StorageBuffer,1=Uniform". Whitespace around each key and
+/// value is ignored. A later entry for the same memory space overrides an
+/// earlier one. Returns false, leaving `memorySpaceMap` empty, if any key is
+/// not an integer or any value is not a known SPIR-V storage class name.
+static bool parseMappingStr(StringRef memorySpaceMapStr,
+                            MemorySpaceToStorageClassMap &memorySpaceMap) {
+  memorySpaceMap.clear();
+  SmallVector<StringRef> mappings;
+  llvm::SplitString(memorySpaceMapStr, mappings, ",");
+  for (StringRef mapping : mappings) {
+    StringRef key, value;
+    std::tie(key, value) = mapping.split('=');
+    // Accept user-friendly spellings such as "0 = StorageBuffer, 1 = Uniform";
+    // to_integer/symbolizeStorageClass both reject surrounding whitespace.
+    key = key.trim();
+    value = value.trim();
+    unsigned space;
+    if (!llvm::to_integer(key, space)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "failed to parse mapping string key: " << key << "\n");
+      memorySpaceMap.clear();
+      return false;
+    }
+    Optional<spirv::StorageClass> storage = spirv::symbolizeStorageClass(value);
+    if (!storage) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "failed to parse mapping string value: " << value << "\n");
+      memorySpaceMap.clear();
+      return false;
+    }
+    memorySpaceMap[space] = *storage;
+  }
+  return true;
+}
+
+/// Returns true if the given `type` is considered as legal for SPIR-V
+/// conversion.
+static bool isLegalType(Type type) {
+  if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
+    Attribute spaceAttr = memRefType.getMemorySpace();
+    return spaceAttr && spaceAttr.isa<spirv::StorageClassAttr>();
+  }
+  // A function type is legal only if all of its inputs and results are. This
+  // makes signatures carried in a TypeAttr (e.g. the `function_type` attribute
+  // of a function-like op) count as needing conversion, matching the
+  // FunctionType conversion registered in MemRefTypeConverter.
+  if (auto funcType = type.dyn_cast<FunctionType>()) {
+    return llvm::all_of(funcType.getInputs(), isLegalType) &&
+           llvm::all_of(funcType.getResults(), isLegalType);
+  }
+  return true;
+}
+
+/// Returns true if the given `attr` is considered as legal for SPIR-V
+/// conversion.
+static bool isLegalAttr(Attribute attr) {
+  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
+    return isLegalType(typeAttr.getValue());
+  return true;
+}
+
+/// Returns true if the given `op` is considered as legal for SPIR-V conversion.
+static bool isLegalOp(Operation *op) {
+  // Function ops are judged by their signature alone.
+  if (auto funcOp = dyn_cast<func::FuncOp>(op))
+    return isLegalType(funcOp.getFunctionType());
+
+  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
+    return attr.getValue();
+  });
+
+  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
+         llvm::all_of(op->getResultTypes(), isLegalType) &&
+         llvm::all_of(attrs, isLegalAttr);
+}
+
+namespace {
+
+/// Type converter for converting numeric MemRef memory spaces into SPIR-V
+/// symbolic ones.
+class MemRefTypeConverter final : public TypeConverter {
+public:
+  explicit MemRefTypeConverter(
+      const MemorySpaceToStorageClassMap &memorySpaceMap);
+
+private:
+  /// Not owned: the referenced map must outlive this converter.
+  const MemorySpaceToStorageClassMap &memorySpaceMap;
+};
+
+/// Converts any op that has operands/results/attributes with numeric MemRef
+/// memory spaces.
+struct MapMemRefStoragePattern final : public ConversionPattern {
+  MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pass that rewrites numeric MemRef memory spaces into SPIR-V storage classes
+/// using a memory space -> storage class map that is either supplied to the
+/// constructor or parsed from the `mappings` option in initializeOptions.
+class MapMemRefStorageClassPass final
+    : public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
+public:
+  MapMemRefStorageClassPass() = default;
+  explicit MapMemRefStorageClassPass(
+      const MemorySpaceToStorageClassMap &memorySpaceMap)
+      : memorySpaceMap(memorySpaceMap) {}
+
+  LogicalResult initializeOptions(StringRef options) override;
+
+  void runOnOperation() override;
+
+private:
+  MemorySpaceToStorageClassMap memorySpaceMap;
+};
+
+} // namespace
+
+MemRefTypeConverter::MemRefTypeConverter(
+    const MemorySpaceToStorageClassMap &memorySpaceMap)
+    : memorySpaceMap(memorySpaceMap) {
+  // Pass through for all other types. Conversions are tried in reverse order
+  // of registration, so this is the fallback when the ones below return None.
+  addConversion([](Type type) { return type; });
+
+  addConversion([this](BaseMemRefType memRefType) -> Optional<Type> {
+    Attribute spaceAttr = memRefType.getMemorySpace();
+    // Already a SPIR-V storage class: keep it. Otherwise its enum value would
+    // be re-interpreted as a numeric memory space and mapped a second time
+    // whenever the owning op is rewritten for some other illegal type.
+    if (spaceAttr && spaceAttr.isa<spirv::StorageClassAttr>())
+      return memRefType;
+
+    // Expect IntegerAttr memory spaces. The attribute can be missing for the
+    // case of memory space == 0.
+    if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) {
+      LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
+                              << " due to non-IntegerAttr memory space\n");
+      return llvm::None;
+    }
+
+    unsigned space = memRefType.getMemorySpaceAsInt();
+    auto it = this->memorySpaceMap.find(space);
+    if (it == this->memorySpaceMap.end()) {
+      LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
+                              << " due to unable to find memory space in map\n");
+      return llvm::None;
+    }
+
+    auto storageAttr =
+        spirv::StorageClassAttr::get(memRefType.getContext(), it->second);
+    if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
+      return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
+                             rankedType.getLayout(), storageAttr);
+    }
+    return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
+  });
+
+  addConversion([this](FunctionType type) -> Optional<Type> {
+    SmallVector<Type> inputs, results;
+    // Never build a FunctionType from null component types; report failure.
+    if (failed(convertTypes(type.getInputs(), inputs)) ||
+        failed(convertTypes(type.getResults(), results)))
+      return Type();
+    return FunctionType::get(type.getContext(), inputs, results);
+  });
+}
+
+/// Recreates `op` with converted result types, converted TypeAttr attributes,
+/// the already-converted `operands`, and regions whose entry block signatures
+/// are converted. All conversions are computed before the IR is touched so
+/// that the pattern can fail without having mutated anything.
+LogicalResult MapMemRefStoragePattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  TypeConverter *converter = getTypeConverter();
+
+  llvm::SmallVector<NamedAttribute, 4> newAttrs;
+  newAttrs.reserve(op->getAttrs().size());
+  for (NamedAttribute attr : op->getAttrs()) {
+    if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
+      Type newType = converter->convertType(typeAttr.getValue());
+      if (!newType)
+        return rewriter.notifyMatchFailure(op,
+                                           "failed to convert type attribute");
+      newAttrs.emplace_back(attr.getName(), TypeAttr::get(newType));
+    } else {
+      newAttrs.push_back(attr);
+    }
+  }
+
+  llvm::SmallVector<Type, 4> newResults;
+  if (failed(converter->convertTypes(op->getResultTypes(), newResults)))
+    return rewriter.notifyMatchFailure(op, "failed to convert result types");
+
+  // One signature conversion per region, for its entry block arguments.
+  SmallVector<TypeConverter::SignatureConversion, 2> signatureConversions;
+  signatureConversions.reserve(op->getNumRegions());
+  for (Region &region : op->getRegions()) {
+    signatureConversions.emplace_back(region.getNumArguments());
+    if (failed(converter->convertSignatureArgs(region.getArgumentTypes(),
+                                               signatureConversions.back())))
+      return rewriter.notifyMatchFailure(
+          op, "failed to convert region argument types");
+  }
+
+  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                       newResults, newAttrs, op->getSuccessors());
+
+  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
+    Region *newRegion = state.addRegion();
+    rewriter.inlineRegionBefore(op->getRegion(i), *newRegion,
+                                newRegion->begin());
+    rewriter.applySignatureConversion(newRegion, signatureConversions[i]);
+  }
+
+  Operation *newOp = rewriter.create(state);
+  rewriter.replaceOp(op, newOp->getResults());
+  return success();
+}
+
+LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
+  if (failed(Pass::initializeOptions(options)))
+    return failure();
+
+  if (!parseMappingStr(mappings, memorySpaceMap))
+    return failure();
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "memory space to storage class mapping:\n";
+    if (memorySpaceMap.empty())
+      llvm::dbgs() << "  [empty]\n";
+    for (auto kv : memorySpaceMap)
+      llvm::dbgs() << "  " << kv.first << " -> "
+                   << spirv::stringifyStorageClass(kv.second) << "\n";
+  });
+
+  return success();
+}
+
+void MapMemRefStorageClassPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getOperation();
+
+  ConversionTarget target(*context);
+  target.markUnknownOpDynamicallyLegal(isLegalOp);
+
+  MemRefTypeConverter converter(memorySpaceMap);
+  // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
+  // patterns for other dialects.
+  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return Optional<Value>(cast.getResult(0));
+  };
+  converter.addSourceMaterialization(addUnrealizedCast);
+  converter.addTargetMaterialization(addUnrealizedCast);
+  target.addLegalOp<UnrealizedConversionCastOp>();
+
+  RewritePatternSet patterns(context);
+  patterns.add<MapMemRefStoragePattern>(context, converter);
+
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    return signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createMapMemRefStorageClassPass(
+    const MemorySpaceToStorageClassMap &memorySpaceMap) {
+  return std::make_unique<MapMemRefStorageClassPass>(memorySpaceMap);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createMapMemRefStorageClassPass() {
+  return std::make_unique<MapMemRefStorageClassPass>();
+}
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -126,8 +126,8 @@
     return spirv::Scope::Device;
   case spirv::StorageClass::Workgroup:
     return spirv::Scope::Workgroup;
-  default: {
-  }
+  default:
+    break;
   }
   return {};
 }
diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='mappings=0=StorageBuffer,1=Uniform,2=Workgroup,3=PushConstant' -verify-diagnostics %s -o - | FileCheck %s
+
+// Mappings:
+//   0 -> StorageBuffer (12)
+//   1 -> Uniform (2)
+//   2 -> Workgroup (4)
+//   3 -> PushConstant (9)
+// TODO: create a StorageClass wrapper class so we can print the symbolic
+//       storage class (instead of the backing IntegerAttr) and be able to
+//       round trip the IR.
+
+// NOTE(review): several memref types in this chunk read as a bare `memref`
+// (no shape, element type or memory space), which is not valid MLIR type
+// syntax; their `<...>` parameters look like they were lost. Confirm against
+// the upstream file before relying on these lines.
+
+// Memory spaces on op result and operand types are rewritten to storage
+// classes, for both ranked and unranked memrefs.
+// CHECK-LABEL: func @operand_result
+func.func @operand_result() {
+  // CHECK: memref
+  %0 = "dialect.memref_producer"() : () -> (memref)
+  // CHECK: memref<4xi32, 2 : i32>
+  %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
+  // CHECK: memref
+  %2 = "dialect.memref_producer"() : () -> (memref)
+  // CHECK: memref<*xf16, 9 : i32>
+  %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 3>)
+
+
+  "dialect.memref_consumer"(%0) : (memref) -> ()
+  // CHECK: memref<4xi32, 2 : i32>
+  "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
+  // CHECK: memref
+  "dialect.memref_consumer"(%2) : (memref) -> ()
+  // CHECK: memref<*xf16, 9 : i32>
+  "dialect.memref_consumer"(%3) : (memref<*xf16, 3>) -> ()
+
+  return
+}
+
+// -----
+
+// Memref types held in a TypeAttr attribute are converted as well.
+// CHECK-LABEL: func @type_attribute
+func.func @type_attribute() {
+  // CHECK: attr = memref
+  "dialect.memref_producer"() { attr = memref } : () -> () 
+  return
+}
+
+// -----
+
+// Function argument and result types are converted; memory space 3 becomes
+// PushConstant (9).
+// CHECK-LABEL: func @function_io
+func.func @function_io
+  // CHECK-SAME: (%{{.+}}: memref, %{{.+}}: memref<4xi32, 9 : i32>)
+  (%arg0: memref, %arg1: memref<4xi32, 3>)
+  // CHECK-SAME: -> (memref, memref<4xi32, 9 : i32>)
+  -> (memref, memref<4xi32, 3>) {
+  return %arg0, %arg1: memref, memref<4xi32, 3>
+}
+
+// -----
+
+// Ops nested in another op's region (here `scf.if`) are converted too,
+// including memref types in their attributes.
+// CHECK: func @region
+func.func @region(%cond: i1, %arg0: memref) {
+  scf.if %cond {
+    // CHECK: "dialect.memref_consumer"(%{{.+}}) {attr = memref}
+    // CHECK-SAME: (memref) -> memref
+    %0 = "dialect.memref_consumer"(%arg0) { attr = memref } : (memref) -> (memref)
+  }
+  return
+}
+
+// -----
+
+// Ops that involve no memref types are left unchanged.
+// CHECK-LABEL: func @non_memref_types
+func.func @non_memref_types(%arg: f32) -> f32 {
+  // CHECK: "dialect.op"(%{{.+}}) {attr = 16 : i64} : (f32) -> f32
+  %0 = "dialect.op"(%arg) { attr = 16 } : (f32) -> (f32)
+  return %0 : f32
+}
+
+// -----
+
+// Legalization must fail here. Going by the test name, the produced memref
+// presumably uses a memory space with no entry in `mappings` (its full type
+// text is not visible in this chunk).
+func.func @missing_mapping() {
+  // expected-error @+1 {{failed to legalize}}
+  %0 = "dialect.memref_producer"() : () -> (memref)
+  return
+}