diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -132,6 +132,11 @@ unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout); + /// Return the LLVM address space corresponding to the memory space of the + /// memref type `type` or failure if the memory space cannot be converted to + /// an integer. + FailureOr getMemRefAddressSpace(BaseMemRefType type); + /// Check if a memref type can be converted to a bare pointer. static bool canConvertToBarePtr(BaseMemRefType type); diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td --- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td @@ -25,7 +25,8 @@ let dependentDialects = [ - "arith::ArithDialect" + "arith::ArithDialect", + "gpu::GPUDialect" ]; let useDefaultAttributePrinterParser = 1; let useFoldAPI = kEmitFoldAdaptorFolder; diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h @@ -61,23 +61,6 @@ } namespace gpu { -/// A function that maps a MemorySpace enum to a target-specific integer value. -using MemorySpaceMapping = - std::function; - -/// Populates type conversion rules for lowering memory space attributes to -/// numeric values. -void populateMemorySpaceAttributeTypeConversions( - TypeConverter &typeConverter, const MemorySpaceMapping &mapping); - -/// Populates patterns to lower memory space attributes to numeric values. -void populateMemorySpaceLoweringPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns); - -/// Populates legality rules for lowering memory space attriutes to numeric -/// values.
-void populateLowerMemorySpaceOpLegality(ConversionTarget &target); - /// Returns the default annotation name for GPU binary blobs. std::string getDefaultGpuBinaryAnnotation(); diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td @@ -37,23 +37,4 @@ let dependentDialects = ["mlir::gpu::GPUDialect"]; } -def GPULowerMemorySpaceAttributesPass - : Pass<"gpu-lower-memory-space-attributes"> { - let summary = "Assign numeric values to memref memory space symbolic placeholders"; - let description = [{ - Updates all memref types that have a memory space attribute - that is a `gpu::AddressSpaceAttr`. These attributes are - changed to `IntegerAttr`'s using a mapping that is given in the - options. - }]; - let options = [ - Option<"privateAddrSpace", "private", "unsigned", "5", - "private address space numeric value">, - Option<"workgroupAddrSpace", "workgroup", "unsigned", "3", - "workgroup address space numeric value">, - Option<"globalAddrSpace", "global", "unsigned", "1", - "global address space numeric value"> - ]; -} - #endif // MLIR_DIALECT_GPU_PASSES diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -21,6 +21,7 @@ namespace mlir { // Forward declarations. +class Attribute; class Block; class ConversionPatternRewriter; class MLIRContext; @@ -87,6 +88,34 @@ SmallVector argTypes; }; + /// The general result of a type attribute conversion callback, allowing + /// for early termination. A default-constructed result is `na()`.
+ class AttributeConversionResult { + public: + constexpr AttributeConversionResult() : impl() {} + AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {} + + static AttributeConversionResult result(Attribute attr); + static AttributeConversionResult na(); + static AttributeConversionResult abort(); + + bool hasResult() const; + bool isNa() const; + bool isAbort() const; + + Attribute getResult() const; + + private: + AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {} + + llvm::PointerIntPair impl; + // Note that na is 0 so that we can use PointerIntPair's default + // constructor. + static constexpr unsigned naTag = 0; + static constexpr unsigned resultTag = 1; + static constexpr unsigned abortTag = 2; + }; + /// Register a conversion function. A conversion function must be convertible /// to any of the following forms(where `T` is a class derived from `Type`: /// * std::optional(T) @@ -156,6 +185,34 @@ wrapMaterialization(std::forward(callback))); } + /// Register a conversion function for attributes within types. Type + /// converters may call this function in order to allow hooking into the + /// translation of attributes that exist within types. For example, a type + /// converter for the `memref` type could use these conversions to convert + /// memory spaces or layouts in an extensible way. + /// + /// The conversion functions take a non-null Type or subclass of Type and a + /// non-null Attribute (or subclass of Attribute), and return an + /// `AttributeConversionResult`. This result can contain an `Attribute`, + /// which may be `nullptr`, representing the conversion's success, + /// `AttributeConversionResult::na()` (the default empty value), indicating + /// that the conversion function did not apply and that further conversion + /// functions should be checked, or `AttributeConversionResult::abort()` + /// indicating that the conversion process should be aborted.
+ /// + /// Registered conversion functions are called in the reverse of the order in + /// which they were registered. + template < + typename FnT, + typename T = + typename llvm::function_traits>::template arg_t<0>, + typename A = + typename llvm::function_traits>::template arg_t<1>> + void addTypeAttributeConversion(FnT &&callback) { + registerTypeAttributeConversion( + wrapTypeAttributeConversion(std::forward(callback))); + } + /// Convert the given type. This function should return failure if no valid /// conversion exists, success otherwise. If the new set of types is empty, /// the type is removed and any usages of the existing value are expected to @@ -226,6 +283,12 @@ resultType, inputs); } + /// Convert the attribute `attr` found within the type `type` using + /// the registered conversion functions. If no applicable conversion has been + /// registered, return std::nullopt. Note that the empty attribute/`nullptr` + /// is a valid return value for this function. + std::optional convertTypeAttribute(Type type, Attribute attr); + private: /// The signature of the callback used to convert a type. If the new set of /// types is empty, the type is removed and any usages of the existing value @@ -237,6 +300,10 @@ using MaterializationCallbackFn = std::function( OpBuilder &, Type, ValueRange, Location)>; + /// The signature of the callback used to convert a type attribute. + using TypeAttributeConversionCallbackFn = + std::function; + /// Attempt to materialize a conversion using one of the provided /// materialization functions. Value materializeConversion( @@ -313,6 +380,32 @@ }; } + /// Generate a wrapper for the given type attribute conversion callback. The + /// callback may take any subclass of `Type` and of `Attribute`; the wrapper + /// checks that the type and attribute are of the expected classes before + /// calling the callback, and returns `na()` otherwise.
+ template + TypeAttributeConversionCallbackFn + wrapTypeAttributeConversion(FnT &&callback) { + return [callback = std::forward(callback)]( + Type type, Attribute attr) -> AttributeConversionResult { + if (T derivedType = type.dyn_cast()) { + if (A derivedAttr = attr.dyn_cast_or_null()) + return callback(derivedType, derivedAttr); + } + return AttributeConversionResult::na(); + }; + } + + /// Register a type attribute conversion and clear the conversion caches. + void + registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) { + typeAttributeConversions.emplace_back(std::move(callback)); + // Cached type conversions may embed attributes this conversion changes. + cachedDirectConversions.clear(); + cachedMultiConversions.clear(); + } + /// The set of registered conversion functions. SmallVector conversions; @@ -321,6 +414,9 @@ SmallVector sourceMaterializations; SmallVector targetMaterializations; + /// The list of registered type attribute conversion functions. + SmallVector typeAttributeConversions; + /// A set of cached conversions to avoid recomputing in the common case. /// Direct 1-1 conversions are the most common, so this cache stores the /// successful 1-1 conversions as well as all failed conversions. diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -112,6 +112,14 @@ } }; +/// A function that maps a MemorySpace enum to a target-specific integer value. +using MemorySpaceMapping = + std::function; + +/// Populates memory space attribute conversion rules for lowering +/// gpu.address_space to integer values.
+void populateGpuMemorySpaceAttributeConversions( + TypeConverter &typeConverter, const MemorySpaceMapping &mapping); } // namespace mlir #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -8,6 +8,7 @@ #include "GPUOpsLowering.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/STLExtras.h" @@ -474,3 +475,18 @@ rewriter.replaceOp(op, result); return success(); } + +static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) { + return IntegerAttr::get(IntegerType::get(ctx, 64), space); +} + +void mlir::populateGpuMemorySpaceAttributeConversions( + TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { + typeConverter.addTypeAttributeConversion( + [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) { + gpu::AddressSpace memorySpace = memorySpaceAttr.getValue(); + unsigned addressSpace = mapping(memorySpace); + return wrapNumericMemorySpace(memorySpaceAttr.getContext(), + addressSpace); + }); +} diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -185,38 +185,26 @@ return signalPassFailure(); } - // MemRef conversion for GPU to NVVM lowering. - { - RewritePatternSet patterns(m.getContext()); - TypeConverter typeConverter; - typeConverter.addConversion([](Type t) { return t; }); - // NVVM uses alloca in the default address space to represent private - // memory allocations, so drop private annotations. NVVM uses address - // space 3 for shared memory.
NVVM uses the default address space to - // represent global memory. - gpu::populateMemorySpaceAttributeTypeConversions( - typeConverter, [](gpu::AddressSpace space) -> unsigned { - switch (space) { - case gpu::AddressSpace::Global: - return static_cast( - NVVM::NVVMMemorySpace::kGlobalMemorySpace); - case gpu::AddressSpace::Workgroup: - return static_cast( - NVVM::NVVMMemorySpace::kSharedMemorySpace); - case gpu::AddressSpace::Private: - return 0; - } - llvm_unreachable("unknown address space enum value"); - return 0; - }); - gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns); - ConversionTarget target(getContext()); - gpu::populateLowerMemorySpaceOpLegality(target); - if (failed(applyFullConversion(m, target, std::move(patterns)))) - return signalPassFailure(); - } - LLVMTypeConverter converter(m.getContext(), options); + // NVVM uses alloca in the default address space to represent private + // memory allocations, so drop private annotations. NVVM uses address + // space 3 for shared memory. NVVM uses the default address space to + // represent global memory. + populateGpuMemorySpaceAttributeConversions( + converter, [](gpu::AddressSpace space) -> unsigned { + switch (space) { + case gpu::AddressSpace::Global: + return static_cast( + NVVM::NVVMMemorySpace::kGlobalMemorySpace); + case gpu::AddressSpace::Workgroup: + return static_cast( + NVVM::NVVMMemorySpace::kSharedMemorySpace); + case gpu::AddressSpace::Private: + return 0; + } + llvm_unreachable("unknown address space enum value"); + return 0; + }); // Lowering for MMAMatrixType.
converter.addConversion([&](gpu::MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -130,33 +130,23 @@ (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); } - // Apply memory space lowering. The target uses 3 for workgroup memory and 5 - // for private memory. - { - RewritePatternSet patterns(ctx); - TypeConverter typeConverter; - typeConverter.addConversion([](Type t) { return t; }); - gpu::populateMemorySpaceAttributeTypeConversions( - typeConverter, [](gpu::AddressSpace space) { - switch (space) { - case gpu::AddressSpace::Global: - return 1; - case gpu::AddressSpace::Workgroup: - return 3; - case gpu::AddressSpace::Private: - return 5; - } - llvm_unreachable("unknown address space enum value"); - return 0; - }); - ConversionTarget target(getContext()); - gpu::populateLowerMemorySpaceOpLegality(target); - gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns); - if (failed(applyFullConversion(m, target, std::move(patterns)))) - return signalPassFailure(); - } - LLVMTypeConverter converter(ctx, options); + // The target uses 1 for global memory, 3 for workgroup memory and 5 for + // private memory. + populateGpuMemorySpaceAttributeConversions( + converter, [](gpu::AddressSpace space) -> unsigned { + switch (space) { + case gpu::AddressSpace::Global: + return 1; + case gpu::AddressSpace::Workgroup: + return 3; + case gpu::AddressSpace::Private: + return 5; + } + llvm_unreachable("unknown address space enum value"); + return 0; + }); + RewritePatternSet llvmPatterns(ctx); mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -109,8 +109,10 @@ Type
ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); auto structElementType = typeConverter->convertType(elementType); - return LLVM::LLVMPointerType::get(structElementType, - type.getMemorySpaceAsInt()); + auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type); + if (failed(addressSpace)) + return {}; + return LLVM::LLVMPointerType::get(structElementType, *addressSpace); } void ConvertToLLVMPattern::getMemRefDescriptorSizes( diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -158,6 +158,10 @@ return builder.create(loc, resultType, inputs) .getResult(0); }); + + // Integer memory spaces map to themselves. + addTypeAttributeConversion( + [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; }); } /// Returns the MLIR context. @@ -311,8 +315,16 @@ Type elementType = convertType(type.getElementType()); if (!elementType) return {}; - auto ptrTy = - LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); + FailureOr addressSpace = getMemRefAddressSpace(type); + if (failed(addressSpace)) { + emitError(UnknownLoc::get(type.getContext()), + "conversion of memref memory space ") + << type.getMemorySpace() + << " to integer address space " + "failed. Consider adding memory space conversions."; + return {}; + } + auto ptrTy = LLVM::LLVMPointerType::get(elementType, *addressSpace); auto indexTy = getIndexType(); SmallVector results = {ptrTy, ptrTy, indexTy}; @@ -330,7 +342,7 @@ unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) { // Compute the descriptor size given that of its components indicated above.
- unsigned space = type.getMemorySpaceAsInt(); + unsigned space = *getMemRefAddressSpace(type); return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) + (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType()); } @@ -363,7 +375,7 @@ LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) { // Compute the descriptor size given that of its components indicated above. - unsigned space = type.getMemorySpaceAsInt(); + unsigned space = *getMemRefAddressSpace(type); return layout.getTypeSize(getIndexType()) + llvm::divideCeil(getPointerBitwidth(space), 8); } @@ -375,6 +387,21 @@ getUnrankedMemRefDescriptorFields()); } +FailureOr +LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) { + if (!type.getMemorySpace()) // Default memory space -> 0. + return 0; + std::optional converted = + convertTypeAttribute(type, type.getMemorySpace()); + if (!converted) + return failure(); + if (!(*converted)) // Converted to the default memory space -> 0. + return 0; + if (auto explicitSpace = converted->dyn_cast()) + return explicitSpace.getInt(); + return failure(); +} + // Check if a memref type can be converted to a bare pointer.
bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) { if (type.isa()) @@ -406,7 +433,10 @@ Type elementType = convertType(type.getElementType()); if (!elementType) return {}; - return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); + FailureOr addressSpace = getMemRefAddressSpace(type); + if (failed(addressSpace)) + return {}; + return LLVM::LLVMPointerType::get(elementType, *addressSpace); } /// Convert an n-D vector type to an LLVM vector type: diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -390,10 +390,11 @@ ConversionPatternRewriter &rewriter) const override { Type operandType = dimOp.getSource().getType(); if (operandType.isa()) { - rewriter.replaceOp( - dimOp, {extractSizeOfUnrankedMemRef( - operandType, dimOp, adaptor.getOperands(), rewriter)}); - + FailureOr extractedSize = extractSizeOfUnrankedMemRef( + operandType, dimOp, adaptor.getOperands(), rewriter); + if (failed(extractedSize)) + return failure(); + rewriter.replaceOp(dimOp, {*extractedSize}); return success(); } if (operandType.isa()) { @@ -406,15 +407,23 @@ } private: - Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + FailureOr + extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); auto unrankedMemRefType = operandType.cast(); auto scalarMemRefType = MemRefType::get({}, unrankedMemRefType.getElementType()); - unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt(); + FailureOr maybeAddressSpace = + getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType); + if (failed(maybeAddressSpace)) { + dimOp.emitOpError("memref memory space must be convertible to an integer "
"address space"); + return failure(); + } + unsigned addressSpace = *maybeAddressSpace; // Extract pointer to the underlying ranked descriptor and bitcast it to a // memref descriptor pointer to minimize the number of GEP @@ -439,7 +448,7 @@ loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex()); Value sizePtr = rewriter.create(loc, indexPtrTy, offsetPtr, idxPlusOne); - return rewriter.create(loc, sizePtr); + return rewriter.create(loc, sizePtr).getResult(); } std::optional getConstantDimIndex(memref::DimOp dimOp) const { @@ -656,10 +665,14 @@ } uint64_t alignment = global.getAlignment().value_or(0); - + FailureOr addressSpace = + getTypeConverter()->getMemRefAddressSpace(type); + if (failed(addressSpace)) + return global.emitOpError( + "memory space cannot be converted to an integer address space"); auto newGlobal = rewriter.replaceOpWithNewOp( global, arrayTy, global.getConstant(), linkage, global.getSymName(), - initialValue, alignment, type.getMemorySpaceAsInt()); + initialValue, alignment, *addressSpace); if (!global.isExternal() && global.isUninitialized()) { Block *blk = new Block(); newGlobal.getInitializerRegion().push_back(blk); @@ -687,7 +700,10 @@ Operation *op) const override { auto getGlobalOp = cast(op); MemRefType type = getGlobalOp.getResult().getType().cast(); - unsigned memSpace = type.getMemorySpaceAsInt(); + + // This is called after a type conversion, which would have failed if this + // call fails. + unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type); Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto addressOf = rewriter.create( @@ -1066,8 +1082,9 @@ return; } - unsigned memorySpace = - operandType.cast().getMemorySpaceAsInt(); + // The dereference below is only valid for convertible memory spaces.
+ unsigned memorySpace = *typeConverter.getMemRefAddressSpace( + operandType.cast()); Type elementType = operandType.cast().getElementType(); Type llvmElementType = typeConverter.convertType(elementType); Type elementPtrPtrType = LLVM::LLVMPointerType::get( @@ -1267,7 +1284,8 @@ // Extract address space and element type. auto targetType = reshapeOp.getResult().getType().cast(); - unsigned addressSpace = targetType.getMemorySpaceAsInt(); + unsigned addressSpace = + *getTypeConverter()->getMemRefAddressSpace(targetType); Type elementType = targetType.getElementType(); // Create the unranked memref descriptor that holds the ranked one. The @@ -1825,10 +1843,10 @@ // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = viewOp.getSource().getType().cast(); + unsigned sourceMemorySpace = + *getTypeConverter()->getMemRefAddressSpace(srcMemRefType); Value bitcastPtr = rewriter.create( - loc, - LLVM::LLVMPointerType::get(targetElementTy, - srcMemRefType.getMemorySpaceAsInt()), + loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace), allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -1837,9 +1855,7 @@ alignedPtr = rewriter.create( loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift()); bitcastPtr = rewriter.create( - loc, - LLVM::LLVMPointerType::get(targetElementTy, - srcMemRefType.getMemorySpaceAsInt()), + loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace), alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -572,16 +572,24 @@ Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(), adaptor.getDstIndices(), rewriter); auto i8Ty = IntegerType::get(op.getContext(), 8);
- auto dstPointerType = - LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt()); + FailureOr dstAddressSpace = + getTypeConverter()->getMemRefAddressSpace(dstMemrefType); + if (failed(dstAddressSpace)) + return rewriter.notifyMatchFailure( + loc, "destination memref address space not convertible to integer"); + auto dstPointerType = LLVM::LLVMPointerType::get(i8Ty, *dstAddressSpace); dstPtr = rewriter.create(loc, dstPointerType, dstPtr); auto srcMemrefType = op.getSrc().getType().cast(); + FailureOr srcAddressSpace = + getTypeConverter()->getMemRefAddressSpace(srcMemrefType); + if (failed(srcAddressSpace)) + return rewriter.notifyMatchFailure( + loc, "source memref address space not convertible to integer"); Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(), adaptor.getSrcIndices(), rewriter); - auto srcPointerType = - LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt()); + auto srcPointerType = LLVM::LLVMPointerType::get(i8Ty, *srcAddressSpace); scrPtr = rewriter.create(loc, srcPointerType, scrPtr); // Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = LLVM::LLVMPointerType::get( diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -682,8 +682,17 @@ if (failed(warpMatrixInfo)) return failure(); + Attribute memorySpace = + op.getSource().getType().cast().getMemorySpace(); + bool isSourceWorkgroupShared = false; + if (auto integerAttr = memorySpace.dyn_cast_or_null()) + isSourceWorkgroupShared = integerAttr.getInt() == 3; + else if (auto gpuAddrSpaceAttr = + memorySpace.dyn_cast_or_null()) + isSourceWorkgroupShared = + gpuAddrSpaceAttr.getValue() == gpu::AddressSpace::Workgroup; bool isLdMatrixCompatible = - op.getSource().getType().cast().getMemorySpaceAsInt() == 3 && + isSourceWorkgroupShared && nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; VectorType vecTy = op.getVectorType(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -8,6 +8,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -92,21 +93,25 @@ } -// Check if the last stride is non-unit or the memory space is not zero. +// Fail on rank 0, non-unit last stride, or non-zero/unconvertible memory space.
-static LogicalResult isMemRefTypeSupported(MemRefType memRefType) { +static LogicalResult isMemRefTypeSupported(MemRefType memRefType, + LLVMTypeConverter &converter) { int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); - if (failed(successStrides) || strides.back() != 1 || - memRefType.getMemorySpaceAsInt() != 0) + FailureOr addressSpace = + converter.getMemRefAddressSpace(memRefType); + if (failed(successStrides) || strides.empty() || strides.back() != 1 || + failed(addressSpace) || *addressSpace != 0) return failure(); return success(); } // Add an index vector component to a base pointer. static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, + LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, uint64_t vLen) { - assert(succeeded(isMemRefTypeSupported(memRefType)) && + assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && "unsupported memref type"); auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); auto ptrsType = LLVM::getFixedVectorType(pType, vLen); @@ -116,8 +121,10 @@ // Casts a strided element pointer to a vector pointer. The vector pointer // will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, MemRefType memRefType, Type vt) { - auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); + Value ptr, MemRefType memRefType, Type vt, + LLVMTypeConverter &converter) { + unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); + auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); return rewriter.create(loc, pType, ptr); } @@ -245,7 +252,8 @@ .template cast(); Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), adaptor.getIndices(), rewriter); - Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); + Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, + *this->getTypeConverter()); replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); return success(); @@ -264,7 +272,7 @@ MemRefType memRefType = gather.getBaseType().dyn_cast(); assert(memRefType && "The base should be bufferized"); - if (failed(isMemRefTypeSupported(memRefType))) + if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return failure(); auto loc = gather->getLoc(); @@ -283,8 +291,8 @@ if (!llvmNDVectorTy.isa()) { auto vType = gather.getVectorType(); // Resolve address. - Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr, - adaptor.getIndexVec(), + Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), + memRefType, base, ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp( @@ -293,11 +301,14 @@ return success(); } - auto callback = [align, memRefType, base, ptr, loc, &rewriter]( - Type llvm1DVectorTy, ValueRange vectorOperands) { + LLVMTypeConverter &typeConverter = *this->getTypeConverter(); + auto callback = [align, memRefType, base, ptr, loc, &rewriter, + &typeConverter](Type llvm1DVectorTy, + ValueRange vectorOperands) { // Resolve address. 
Value ptrs = getIndexedPtrs( - rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0], + rewriter, loc, typeConverter, memRefType, base, ptr, + /*index=*/vectorOperands[0], LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); // Create the gather intrinsic. return rewriter.create( @@ -323,7 +334,7 @@ auto loc = scatter->getLoc(); MemRefType memRefType = scatter.getMemRefType(); - if (failed(isMemRefTypeSupported(memRefType))) + if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return failure(); // Resolve alignment. @@ -335,9 +346,9 @@ VectorType vType = scatter.getVectorType(); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), adaptor.getIndices(), rewriter); - Value ptrs = - getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr, - adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); + Value ptrs = getIndexedPtrs( + rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(), + ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); // Replace with the scatter intrinsic. 
rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" @@ -48,7 +49,16 @@ template static LogicalResult verifyRawBufferOp(T &op) { MemRefType bufferType = op.getMemref().getType().template cast(); - if (bufferType.getMemorySpaceAsInt() != 0) + Attribute memorySpace = bufferType.getMemorySpace(); + bool isGlobal = false; + if (!memorySpace) + isGlobal = true; + else if (auto intMemorySpace = memorySpace.dyn_cast()) + isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; + else if (auto gpuMemorySpace = memorySpace.dyn_cast()) + isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global; + + if (!isGlobal) return op.emitOpError( "Buffer ops must operate on a memref in global memory"); if (!bufferType.hasRank()) diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt --- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt @@ -11,6 +11,8 @@ LINK_LIBS PUBLIC MLIRArithDialect + # Needed for GPU address space enum definition + MLIRGPUOps MLIRIR MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -52,7 +52,6 @@ Transforms/SerializeToBlob.cpp Transforms/SerializeToCubin.cpp Transforms/SerializeToHsaco.cpp - Transforms/LowerMemorySpaceAttributes.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU diff --git a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp 
b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp +++ /dev/null @@ -1,184 +0,0 @@ -//===- LowerMemorySpaceAttributes.cpp ------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// Implementation of a pass that rewrites the IR so that uses of -/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced -/// with caller-specified numeric values. -/// -//===----------------------------------------------------------------------===// -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/Support/Debug.h" - -namespace mlir { -#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS -#include "mlir/Dialect/GPU/Transforms/Passes.h.inc" -} // namespace mlir - -using namespace mlir; -using namespace mlir::gpu; - -//===----------------------------------------------------------------------===// -// Conversion Target -//===----------------------------------------------------------------------===// - -/// Returns true if the given `type` is considered as legal during memory space -/// attribute lowering. -static bool isLegalType(Type type) { - if (auto memRefType = type.dyn_cast()) { - return !memRefType.getMemorySpace() - .isa_and_nonnull(); - } - return true; -} - -/// Returns true if the given `attr` is considered legal during memory space -/// attribute lowering. 
-static bool isLegalAttr(Attribute attr) { - if (auto typeAttr = attr.dyn_cast()) - return isLegalType(typeAttr.getValue()); - return true; -} - -/// Returns true if the given `op` is legal during memory space attribute -/// lowering. -static bool isLegalOp(Operation *op) { - if (auto funcOp = dyn_cast(op)) { - return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) && - llvm::all_of(funcOp.getResultTypes(), isLegalType) && - llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(), - isLegalType); - } - - auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) { - return attr.getValue(); - }); - - return llvm::all_of(op->getOperandTypes(), isLegalType) && - llvm::all_of(op->getResultTypes(), isLegalType) && - llvm::all_of(attrs, isLegalAttr); -} - -void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) { - target.markUnknownOpDynamicallyLegal(isLegalOp); -} - -//===----------------------------------------------------------------------===// -// Type Converter -//===----------------------------------------------------------------------===// - -IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) { - return IntegerAttr::get(IntegerType::get(ctx, 64), space); -} - -void mlir::gpu::populateMemorySpaceAttributeTypeConversions( - TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { - typeConverter.addConversion([mapping](Type type) -> std::optional { - auto subElementType = type.dyn_cast_or_null(); - if (!subElementType) - return type; - Type newType = subElementType.replaceSubElements( - [mapping](Attribute attr) -> std::optional { - auto memorySpaceAttr = attr.dyn_cast_or_null(); - if (!memorySpaceAttr) - return std::nullopt; - auto newValue = wrapNumericMemorySpace( - attr.getContext(), mapping(memorySpaceAttr.getValue())); - return newValue; - }); - return newType; - }); -} - -namespace { - -/// Converts any op that has operands/results/attributes with numeric MemRef -/// memory spaces. 
-struct LowerMemRefAddressSpacePattern final : public ConversionPattern { - LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter) - : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - SmallVector newAttrs; - newAttrs.reserve(op->getAttrs().size()); - for (auto attr : op->getAttrs()) { - if (auto typeAttr = attr.getValue().dyn_cast()) { - auto newAttr = getTypeConverter()->convertType(typeAttr.getValue()); - newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr)); - } else { - newAttrs.push_back(attr); - } - } - - SmallVector newResults; - (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults); - - OperationState state(op->getLoc(), op->getName().getStringRef(), operands, - newResults, newAttrs, op->getSuccessors()); - - for (Region ®ion : op->getRegions()) { - Region *newRegion = state.addRegion(); - rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); - TypeConverter::SignatureConversion result(newRegion->getNumArguments()); - (void)getTypeConverter()->convertSignatureArgs( - newRegion->getArgumentTypes(), result); - rewriter.applySignatureConversion(newRegion, result); - } - - Operation *newOp = rewriter.create(state); - rewriter.replaceOp(op, newOp->getResults()); - return success(); - } -}; -} // namespace - -void mlir::gpu::populateMemorySpaceLoweringPatterns( - TypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add(patterns.getContext(), - typeConverter); -} - -namespace { -class LowerMemorySpaceAttributesPass - : public mlir::impl::GPULowerMemorySpaceAttributesPassBase< - LowerMemorySpaceAttributesPass> { -public: - using Base::Base; - void runOnOperation() override { - MLIRContext *context = &getContext(); - Operation *op = getOperation(); - - ConversionTarget target(getContext()); - populateLowerMemorySpaceOpLegality(target); - - 
TypeConverter typeConverter; - typeConverter.addConversion([](Type t) { return t; }); - populateMemorySpaceAttributeTypeConversions( - typeConverter, [this](AddressSpace space) -> unsigned { - switch (space) { - case AddressSpace::Global: - return globalAddrSpace; - case AddressSpace::Workgroup: - return workgroupAddrSpace; - case AddressSpace::Private: - return privateAddrSpace; - } - llvm_unreachable("unknown address space enum value"); - return 0; - }); - RewritePatternSet patterns(context); - populateMemorySpaceLoweringPatterns(typeConverter, patterns); - if (failed(applyFullConversion(op, target, std::move(patterns)))) - return signalPassFailure(); - } -}; -} // namespace diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3055,6 +3055,54 @@ return conversion; } +//===----------------------------------------------------------------------===// +// Type attribute conversion +//===----------------------------------------------------------------------===// +TypeConverter::AttributeConversionResult +TypeConverter::AttributeConversionResult::result(Attribute attr) { + return AttributeConversionResult(attr, resultTag); +} + +TypeConverter::AttributeConversionResult +TypeConverter::AttributeConversionResult::na() { + return AttributeConversionResult(nullptr, naTag); +} + +TypeConverter::AttributeConversionResult +TypeConverter::AttributeConversionResult::abort() { + return AttributeConversionResult(nullptr, abortTag); +} + +bool TypeConverter::AttributeConversionResult::hasResult() const { + return impl.getInt() == resultTag; +} + +bool TypeConverter::AttributeConversionResult::isNa() const { + return impl.getInt() == naTag; +} + +bool TypeConverter::AttributeConversionResult::isAbort() const { + return impl.getInt() == abortTag; +} + +Attribute TypeConverter::AttributeConversionResult::getResult() 
const { + assert(hasResult() && "Cannot get result from N/A or abort"); + return impl.getPointer(); +} + +Optional TypeConverter::convertTypeAttribute(Type type, + Attribute attr) { + for (TypeAttributeConversionCallbackFn &fn : + llvm::reverse(typeAttributeConversions)) { + AttributeConversionResult res = fn(type, attr); + if (res.hasResult()) + return res.getResult(); + if (res.isAbort()) + return std::nullopt; + } + return std::nullopt; +} + //===----------------------------------------------------------------------===// // FunctionOpInterfaceSignatureConversion //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir b/mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s -split-input-file -convert-gpu-to-rocdl | FileCheck %s --check-prefixes=CHECK,ROCDL +// RUN: mlir-opt %s -split-input-file -convert-gpu-to-nvvm | FileCheck %s --check-prefixes=CHECK,NVVM + +gpu.module @kernel { + gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space> + gpu.return + } +} + +// CHECK-LABEL: llvm.func @private +// CHECK: llvm.store +// ROCDL-SAME: : !llvm.ptr +// NVVM-SAME: : !llvm.ptr + + +// ----- + +gpu.module @kernel { + gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space> + gpu.return + } +} + +// CHECK-LABEL: llvm.func @workgroup +// CHECK: llvm.store +// CHECK-SAME: : !llvm.ptr + +// ----- + +gpu.module @kernel { + gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space>, #gpu.address_space>) -> f32 { + %c0 = arith.constant 0 : index + %inner = memref.load 
%arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space>, #gpu.address_space> + %value = memref.load %inner[%c0] : memref<4xf32, #gpu.address_space> + gpu.return %value : f32 + } +} + +// CHECK-LABEL: llvm.func @nested_memref +// CHECK: llvm.load +// CHECK-SAME: : !llvm.ptr<{{.*}}, 1> +// CHECK: [[value:%.+]] = llvm.load +// CHECK-SAME: : !llvm.ptr +// CHECK: llvm.return [[value]] diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -convert-memref-to-llvm -split-input-file 2>&1 | FileCheck %s +// Since the error is at an unknown location, we use FileCheck instead of +// -veri-y-diagnostics here + +// CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions. +// CHECK-LABEL: @bad_address_space +func.func @bad_address_space(%a: memref<2xindex, "foo">) { + %c0 = arith.constant 0 : index + // CHECK: memref.store + memref.store %c0, %a[%c0] : memref<2xindex, "foo"> + return +} + +// ----- diff --git a/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir b/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir deleted file mode 100644 --- a/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes | FileCheck %s -// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes="private=0 global=0" | FileCheck %s --check-prefix=CUDA - -gpu.module @kernel { - gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space>) { - %c0 = arith.constant 0 : index - memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space> - gpu.return - } -} - -// CHECK: gpu.func @private -// CHECK-SAME: private(%{{.+}}: memref<4xf32, 5>) -// CHECK: memref.store -// CHECK-SAME: : memref<4xf32, 5> - -// CUDA: gpu.func @private -// 
CUDA-SAME: private(%{{.+}}: memref<4xf32>) -// CUDA: memref.store -// CUDA-SAME: : memref<4xf32> - -// ----- - -gpu.module @kernel { - gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space>) { - %c0 = arith.constant 0 : index - memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space> - gpu.return - } -} - -// CHECK: gpu.func @workgroup -// CHECK-SAME: workgroup(%{{.+}}: memref<4xf32, 3>) -// CHECK: memref.store -// CHECK-SAME: : memref<4xf32, 3> - -// ----- - -gpu.module @kernel { - gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space>, #gpu.address_space>) { - %c0 = arith.constant 0 : index - memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space>, #gpu.address_space> - gpu.return - } -} - -// CHECK: gpu.func @nested_memref -// CHECK-SAME: (%{{.+}}: memref<4xmemref<4xf32, 1>, 1>) -// CHECK: memref.load -// CHECK-SAME: : memref<4xmemref<4xf32, 1>, 1> - -// CUDA: gpu.func @nested_memref -// CUDA-SAME: (%{{.+}}: memref<4xmemref<4xf32>>) -// CUDA: memref.load -// CUDA-SAME: : memref<4xmemref<4xf32>>