diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -181,10 +181,13 @@
   static unsigned getNumUnpackedValues() { return 2; }

   /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
-  /// and appends the corresponding values into `sizes`.
+  /// and appends the corresponding values into `sizes`. `addressSpaces`,
+  /// which must have the same length as `values`, is needed to handle
+  /// layouts where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
   static void computeSizes(OpBuilder &builder, Location loc,
                            LLVMTypeConverter &typeConverter,
                            ArrayRef<UnrankedMemRefDescriptor> values,
+                           ArrayRef<unsigned> addressSpaces,
                            SmallVectorImpl<Value> &sizes);

   /// TODO: The following accessors don't take alignment rules between elements
@@ -210,6 +213,12 @@
                             Value memRefDescPtr, Type elemPtrPtrType,
                             Value alignedPtr);

+  /// Builds IR for getting the pointer to the offset's location.
+  /// Returns a pointer to a convertType(index), which points to the beginning
+  /// of a struct {index, index[rank], index[rank]}.
+  static Value offsetBasePtr(OpBuilder &builder, Location loc,
+                             LLVMTypeConverter &typeConverter,
+                             Value memRefDescPtr, Type elemPtrPtrType);
   /// Builds IR extracting the offset from the descriptor.
   static Value offset(OpBuilder &builder, Location loc,
                       LLVMTypeConverter &typeConverter, Value memRefDescPtr,
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -110,6 +110,60 @@
   let hasVerifier = 1;
 }

+//===----------------------------------------------------------------------===//
+// AddressSpaceCastOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_AddressSpaceCastOp : MemRef_Op<"address_space_cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    MemRefsNormalizable,
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape,
+    ViewLikeOpInterface
+  ]> {
+  let summary = "memref address space cast operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `memref.address_space_cast` ssa-use `:` type `to` type
+    ```
+
+    This operation casts memref values between address spaces.
+    The input and result will be memrefs of the same type and shape that alias
+    the same underlying memory, though, for some casts on some targets,
+    the underlying pointer value stored in the memref may be affected
+    by the cast.
+
+    The input and result must have the same shape, element type, rank, and
+    layout.
+
+    If the source and target address spaces are the same, this operation is
+    a no-op.
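For ranked memrefs, the lowering added in MemRefToLLVM.cpp below only rewrites the two pointer fields of the LLVM descriptor; the offset, sizes, and strides are carried over unchanged. A minimal sketch of the expected output for a rank-1 cast, using the pre-opaque-pointer type syntax this patch's tests assume (the SSA names are illustrative, not produced by the converter):

```mlir
// %desc holds the lowered descriptor of a memref<?xf32>.
%src_alloc = llvm.extractvalue %desc[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
%src_align = llvm.extractvalue %desc[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
// Only the pointers change representation; the index fields are reused as-is.
%dst_alloc = llvm.addrspacecast %src_alloc : !llvm.ptr<f32> to !llvm.ptr<f32, 1>
%dst_align = llvm.addrspacecast %src_align : !llvm.ptr<f32> to !llvm.ptr<f32, 1>
```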
+
+    Example:
+
+    ```mlir
+    // Cast a GPU private memory attribution into a generic pointer
+    %2 = memref.address_space_cast %1 : memref<?xf32, 5> to memref<?xf32>
+    // Cast a generic pointer to workgroup-local memory
+    %4 = memref.address_space_cast %3 : memref<5x4xi32> to memref<5x4xi32, 3>
+    // Cast between two non-default memory spaces
+    %6 = memref.address_space_cast %5
+      : memref<*xmemref<?xf32>, 5> to memref<*xmemref<?xf32>, 3>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
+  let results = (outs AnyRankedOrUnrankedMemRef:$dest);
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+
+  let extraClassDeclaration = [{
+    Value getViewSource() { return getSource(); }
+  }];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // AssumeAlignmentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -329,24 +329,26 @@
 void UnrankedMemRefDescriptor::computeSizes(
     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
-    ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
+    ArrayRef<UnrankedMemRefDescriptor> values,
+    ArrayRef<unsigned> addressSpaces, SmallVectorImpl<Value> &sizes) {
   if (values.empty())
     return;
-
+  assert(values.size() == addressSpaces.size() &&
+         "must provide address space for each descriptor");
   // Cache the index type.
   Type indexType = typeConverter.getIndexType();

   // Initialize shared constants.
   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
-  Value pointerSize = createIndexAttrConstant(
-      builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
   Value indexSize =
       createIndexAttrConstant(builder, loc, indexType,
                               ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));

   sizes.reserve(sizes.size() + values.size());
-  for (UnrankedMemRefDescriptor desc : values) {
+  for (auto pair : llvm::zip(values, addressSpaces)) {
+    UnrankedMemRefDescriptor desc = std::get<0>(pair);
+    unsigned addressSpace = std::get<1>(pair);
     // Emit IR computing the memory necessary to store the descriptor. This
     // assumes the descriptor to be
     //   { type*, type*, index, index[rank], index[rank] }
@@ -354,6 +356,9 @@
     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
     // TODO: consider including the actual size (including eventual padding due
     // to data layout) into the unranked descriptor.
+    Value pointerSize = createIndexAttrConstant(
+        builder, loc, indexType,
+        ceilDiv(typeConverter.getPointerBitwidth(addressSpace), 8));
     Value doublePointerSize =
         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
@@ -415,10 +420,10 @@
   builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
 }

-Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
-                                       LLVMTypeConverter &typeConverter,
-                                       Value memRefDescPtr,
-                                       Type elemPtrPtrType) {
+Value UnrankedMemRefDescriptor::offsetBasePtr(OpBuilder &builder, Location loc,
+                                              LLVMTypeConverter &typeConverter,
+                                              Value memRefDescPtr,
+                                              Type elemPtrPtrType) {
   Value elementPtrPtr =
       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);

@@ -426,21 +431,25 @@
       loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
   offsetGep = builder.create<LLVM::BitcastOp>(
       loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
-  return builder.create<LLVM::LoadOp>(loc, offsetGep);
+  return offsetGep;
+}
+
+Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter &typeConverter,
+                                       Value memRefDescPtr,
+                                       Type elemPtrPtrType) {
+  Value offsetPtr =
+      offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrPtrType);
+  return builder.create<LLVM::LoadOp>(loc, offsetPtr);
 }

 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
                                          LLVMTypeConverter &typeConverter,
                                          Value memRefDescPtr,
                                          Type elemPtrPtrType, Value offset) {
-  Value elementPtrPtr =
-      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
-
-  Value offsetGep = builder.create<LLVM::GEPOp>(
-      loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
-  offsetGep = builder.create<LLVM::BitcastOp>(
-      loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
-  builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
+  Value offsetPtr =
+      offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrPtrType);
+  builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
 }

 Value UnrankedMemRefDescriptor::sizeBasePtr(
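Concretely, for one rank-`r` descriptor the computed allocation size is `2 * sizeof(ptr addrspace(AS)) + (1 + 2 * r) * sizeof(index)`, and the pointer-size constant is now emitted once per descriptor instead of once for all descriptors. A sketch of the IR this produces for a single descriptor with 64-bit indices and 8-byte pointers (`%rank` stands for the descriptor's dynamic rank; names are illustrative):

```mlir
%one       = llvm.mlir.constant(1 : index) : i64
%two       = llvm.mlir.constant(2 : index) : i64
%idx_size  = llvm.mlir.constant(8 : index) : i64  // sizeof(index)
%ptr_size  = llvm.mlir.constant(8 : index) : i64  // sizeof(ptr addrspace(AS)), now per address space
%two_ptrs  = llvm.mul %two, %ptr_size : i64       // 2 * sizeof(pointer)
%two_rank  = llvm.mul %two, %rank : i64           // 2 * rank
%num_idx   = llvm.add %two_rank, %one : i64       // 1 + 2 * rank
%idx_bytes = llvm.mul %num_idx, %idx_size : i64
%total     = llvm.add %two_ptrs, %idx_bytes : i64 // bytes to allocate
```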
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -227,9 +227,13 @@
   // Find operands of unranked memref type and store them.
   SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i)
-    if (origTypes[i].isa<UnrankedMemRefType>())
+  SmallVector<unsigned> unrankedAddressSpaces;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    if (auto memRefType = origTypes[i].dyn_cast<UnrankedMemRefType>()) {
       unrankedMemrefs.emplace_back(operands[i]);
+      unrankedAddressSpaces.emplace_back(memRefType.getMemorySpaceAsInt());
+    }
+  }

   if (unrankedMemrefs.empty())
     return success();
@@ -237,7 +241,8 @@
   // Compute allocation sizes.
   SmallVector<Value> sizes;
   UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
-                                         unrankedMemrefs, sizes);
+                                         unrankedMemrefs, unrankedAddressSpaces,
+                                         sizes);

   // Get frequently used types.
   MLIRContext *context = builder.getContext();
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include <optional>

@@ -804,6 +805,136 @@
   }
 };

+struct AddressSpaceCastOpLowering
+    : public ConvertOpToLLVMPattern<memref::AddressSpaceCastOp> {
+  using ConvertOpToLLVMPattern<
+      memref::AddressSpaceCastOp>::ConvertOpToLLVMPattern;
+
+  FailureOr<unsigned> getMemorySpace(Type type) const {
+    Attribute memorySpace;
+    if (auto mR = type.dyn_cast<MemRefType>())
+      memorySpace = mR.getMemorySpace();
+    else if (auto mU = type.dyn_cast<UnrankedMemRefType>())
+      memorySpace = mU.getMemorySpace();
+    else
+      return failure();
+
+    // The default memory space is 0.
+    if (!memorySpace)
+      return 0u;
+    auto memorySpaceInt = memorySpace.dyn_cast<IntegerAttr>();
+    if (!memorySpaceInt)
+      return failure();
+    return static_cast<unsigned>(memorySpaceInt.getInt());
+  }
+
+  LogicalResult
+  matchAndRewrite(memref::AddressSpaceCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Type resultType = op.getDest().getType();
+    if (auto resultTypeR = resultType.dyn_cast<MemRefType>()) {
+      auto resultDescType =
+          typeConverter->convertType(resultTypeR).cast<LLVM::LLVMStructType>();
+      Type newPtrType = resultDescType.getBody()[0];
+
+      SmallVector<Value> descVals;
+      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
+                               descVals);
+      descVals[0] =
+          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
+      descVals[1] =
+          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
+      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
+                                            resultTypeR, descVals);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+    if (auto resultTypeU = resultType.dyn_cast<UnrankedMemRefType>()) {
+      // Since the type converter won't be doing this for us, get the address
+      // spaces.
+      auto sourceType = op.getSource().getType().cast<UnrankedMemRefType>();
+      FailureOr<unsigned> maybeSourceAddrSpace = getMemorySpace(sourceType);
+      if (failed(maybeSourceAddrSpace))
+        return rewriter.notifyMatchFailure(loc,
+                                           "non-integer source address space");
+      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
+      FailureOr<unsigned> maybeResultAddrSpace = getMemorySpace(resultTypeU);
+      if (failed(maybeResultAddrSpace))
+        return rewriter.notifyMatchFailure(loc,
+                                           "non-integer result address space");
+      unsigned resultAddrSpace = *maybeResultAddrSpace;
+
+      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
+      Value rank = sourceDesc.rank(rewriter, loc);
+      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
+
+      // Create and allocate storage for the new memref descriptor.
+      auto result = UnrankedMemRefDescriptor::undef(
+          rewriter, loc, typeConverter->convertType(resultTypeU));
+      result.setRank(rewriter, loc, rank);
+      SmallVector<Value, 1> sizes;
+      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
+                                             result, resultAddrSpace, sizes);
+      Value resultUnderlyingSize = sizes.front();
+      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
+          loc, getVoidPtrType(), resultUnderlyingSize, std::nullopt);
+      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
+
+      // Copy pointers, performing address space casts.
+      Type llvmElementType =
+          typeConverter->convertType(sourceType.getElementType());
+      auto sourceElemPtrPtrType = LLVM::LLVMPointerType::get(
+          LLVM::LLVMPointerType::get(llvmElementType, sourceAddrSpace));
+      auto resultElemPtrType =
+          LLVM::LLVMPointerType::get(llvmElementType, resultAddrSpace);
+      auto resultElemPtrPtrType = LLVM::LLVMPointerType::get(resultElemPtrType);
+
+      Value allocatedPtr = sourceDesc.allocatedPtr(
+          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrPtrType);
+      Value alignedPtr =
+          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
+                                sourceUnderlyingDesc, sourceElemPtrPtrType);
+      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+          loc, resultElemPtrType, allocatedPtr);
+      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+          loc, resultElemPtrType, alignedPtr);
+
+      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
+                             resultElemPtrPtrType, allocatedPtr);
+      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
+                           resultUnderlyingDesc, resultElemPtrPtrType,
+                           alignedPtr);
+
+      // Copy all the index-valued operands.
+      Value sourceIndexVals =
+          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
+                                   sourceUnderlyingDesc, sourceElemPtrPtrType);
+      Value resultIndexVals =
+          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
+                               resultUnderlyingDesc, resultElemPtrPtrType);
+
+      int64_t bytesToSkip =
+          2 *
+          ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
+      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
+          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
+      Value copySize = rewriter.create<LLVM::SubOp>(
+          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
+      Type llvmBool = typeConverter->convertType(rewriter.getI1Type());
+      Value nonVolatile = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmBool, rewriter.getBoolAttr(false));
+      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
+                                      copySize, nonVolatile);
+
+      rewriter.replaceOp(op, ValueRange{result});
+      return success();
+    }
+    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
+  }
+};
+
 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

@@ -1277,7 +1408,7 @@
   targetDesc.setRank(rewriter, loc, resultRank);
   SmallVector<Value, 1> sizes;
   UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                         targetDesc, sizes);
+                                         targetDesc, addressSpace, sizes);
   Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
       loc, getVoidPtrType(), sizes.front(), std::nullopt);
   targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
@@ -2002,6 +2133,7 @@
                                              RewritePatternSet &patterns) {
   // clang-format off
   patterns.add<
+      AddressSpaceCastOpLowering,
       AllocaOpLowering,
       AllocaScopeOpLowering,
       AtomicRMWOpLowering,
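In the unranked case above, only the two leading pointers of the underlying `{ptr, ptr, index, index[rank], index[rank]}` buffer need an `addrspacecast`; the index-valued tail is layout-compatible between source and result, so it is copied byte-for-byte starting at `offsetBasePtr`. A sketch of the tail copy, assuming a result in space 1 with 8-byte pointers (so `bytesToSkip` is 16; `%alloc_size`, `%src_idx`, and `%dst_idx` are illustrative names for the values computed above):

```mlir
%skip  = llvm.mlir.constant(16 : index) : i64  // 2 * sizeof(ptr addrspace(1))
%bytes = llvm.sub %alloc_size, %skip : i64     // (1 + 2 * rank) * sizeof(index)
%false = llvm.mlir.constant(false) : i1
"llvm.intr.memcpy"(%dst_idx, %src_idx, %bytes, %false)
    : (!llvm.ptr<i64>, !llvm.ptr<i64>, i64, i1) -> ()
```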
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -158,6 +158,17 @@

 namespace {

+/// Converts memref.address_space_cast to the appropriate spirv cast operations.
+class AddressSpaceCastOpPattern final
+    : public OpConversionPattern<memref::AddressSpaceCastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AddressSpaceCastOp addrCastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts memref.alloca to SPIR-V Function variables.
 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
 public:
@@ -234,6 +245,74 @@

 } // namespace

+//===----------------------------------------------------------------------===//
+// AddressSpaceCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AddressSpaceCastOpPattern::matchAndRewrite(
+    memref::AddressSpaceCastOp addrCastOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = addrCastOp.getLoc();
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  if (!typeConverter.allows(spirv::Capability::Kernel))
+    return rewriter.notifyMatchFailure(
+        loc, "address space casts require kernel capability");
+
+  auto sourceType = addrCastOp.getSource().getType().dyn_cast<MemRefType>();
+  if (!sourceType)
+    return rewriter.notifyMatchFailure(
+        loc, "SPIR-V lowering requires ranked memref types");
+  auto resultType = addrCastOp.getResult().getType().cast<MemRefType>();
+
+  auto sourceStorageClassAttr =
+      sourceType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!sourceStorageClassAttr)
+    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
+      diag << "source address space " << sourceType.getMemorySpace()
+           << " must be a SPIR-V storage class";
+    });
+  auto resultStorageClassAttr =
+      resultType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!resultStorageClassAttr)
+    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
+      diag << "result address space " << resultType.getMemorySpace()
+           << " must be a SPIR-V storage class";
+    });
+
+  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
+  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
+
+  Value result = adaptor.getSource();
+  Type resultPtrType = typeConverter.convertType(resultType);
+  Type genericPtrType = resultPtrType;
+  // SPIR-V doesn't have a general address space cast operation. Instead, it
+  // has conversions to and from generic pointers. To implement the general
+  // case, we use specific-to-generic conversions when the source class is not
+  // generic. Then, when the result storage class is not generic, we convert
+  // the generic pointer (either the input or an intermediate result) to that
+  // class. This also means that we'll need the intermediate generic pointer
+  // type if neither the source nor the destination storage class is generic.
+  if (sourceSc != spirv::StorageClass::Generic &&
+      resultSc != spirv::StorageClass::Generic) {
+    Type intermediateType =
+        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
+                        sourceType.getLayout(),
+                        rewriter.getAttr<spirv::StorageClassAttr>(
+                            spirv::StorageClass::Generic));
+    genericPtrType = typeConverter.convertType(intermediateType);
+  }
+  if (sourceSc != spirv::StorageClass::Generic) {
+    result =
+        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
+  }
+  if (resultSc != spirv::StorageClass::Generic) {
+    result =
+        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
+  }
+  rewriter.replaceOp(addrCastOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AllocaOp
 //===----------------------------------------------------------------------===//
@@ -576,9 +655,9 @@
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns
-      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
-           StoreOpPattern>(typeConverter, patterns.getContext());
+  patterns.add<AddressSpaceCastOpPattern, AllocaOpPattern, AllocOpPattern,
+               AtomicRMWOpPattern, DeallocOpPattern, IntLoadOpPattern,
+               IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
+      typeConverter, patterns.getContext());
 }
 } // namespace mlir
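Since SPIR-V only provides casts to and from the Generic storage class, a cast between two non-generic classes round-trips through a generic pointer, and a cast where one side is already Generic needs only one of the two ops. A sketch of the two-step case, matching the test added to memref-to-spirv.mlir below (a scalar pointee is shown for brevity):

```mlir
%generic = spirv.PtrCastToGeneric %src : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<f32, Generic>
%result  = spirv.GenericCastToPtr %generic : !spirv.ptr<f32, Generic> to !spirv.ptr<f32, Function>
```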
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -224,6 +224,55 @@
   return strides;
 }

+//===----------------------------------------------------------------------===//
+// AddressSpaceCastOp
+//===----------------------------------------------------------------------===//
+
+void AddressSpaceCastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "addrcast");
+}
+
+bool AddressSpaceCastOp::areCastCompatible(TypeRange inputs,
+                                           TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
+  auto aT = a.dyn_cast<MemRefType>();
+  auto bT = b.dyn_cast<MemRefType>();
+
+  auto uaT = a.dyn_cast<UnrankedMemRefType>();
+  auto ubT = b.dyn_cast<UnrankedMemRefType>();
+
+  if (aT && bT) {
+    if (aT.getElementType() != bT.getElementType())
+      return false;
+    if (aT.getLayout() != bT.getLayout())
+      return false;
+    if (aT.getRank() != bT.getRank())
+      return false;
+    if (aT.getShape() != bT.getShape())
+      return false;
+    return true;
+  }
+  if (uaT && ubT) {
+    // The early-return style used above trips a lint here, so use a single
+    // check.
+    bool areCompatible = uaT.getElementType() == ubT.getElementType();
+    return areCompatible;
+  }
+  return false;
+}
+
+OpFoldResult AddressSpaceCastOp::fold(ArrayRef<Attribute> operands) {
+  // address_space_cast(address_space_cast(v, t1), t2)
+  //   -> address_space_cast(v, t2)
+  if (auto parentCast = getSource().getDefiningOp<AddressSpaceCastOp>()) {
+    getSourceMutable().assign(parentCast.getSource());
+    return getResult();
+  }
+  return Value{};
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
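The fold rewrites a chain of casts in place so that only the outermost cast remains, anchored at the original source (and a cast whose source and result types are identical folds away entirely); the canonicalization tests below exercise both cases. Illustrated:

```mlir
%0 = memref.address_space_cast %arg : memref<?xf32> to memref<?xf32, 1>
%1 = memref.address_space_cast %0 : memref<?xf32, 1> to memref<?xf32, 2>
// After folding, %1's source is rewritten to %arg and %0 becomes dead:
%1 = memref.address_space_cast %arg : memref<?xf32> to memref<?xf32, 2>
```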
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -122,9 +122,9 @@
   // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr<i8>)>
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
@@ -154,13 +154,12 @@
   // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1]
   %0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32>

-  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -263,6 +263,66 @@

 // -----

+// FIXME: the *ToLLVM passes don't use information from data layouts
+// to set address spaces, so the constants below don't reflect the layout.
+// Update this test once that data layout attribute works how we'd expect it to.
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+    #dlti.dl_entry<!llvm.ptr, dense<[64, 64, 64]> : vector<3xi32>>,
+    #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 32, 32]> : vector<3xi32>>> } {
+  // CHECK-LABEL: @memref_address_space_cast
+  func.func @memref_address_space_cast(%input : memref<*xf32>) -> memref<*xf32, 1> {
+    %cast = memref.address_space_cast %input : memref<*xf32> to memref<*xf32, 1>
+    return %cast : memref<*xf32, 1>
+  }
+}
+// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %{{.*}} to !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[RANK:%.*]] = llvm.extractvalue [[INPUT]][0] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[SOURCE_DESC:%.*]] = llvm.extractvalue [[INPUT]][1]
+// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[RANK]], [[RESULT_0]][0] : !llvm.struct<(i64, ptr<i8>)>
+
+// Compute size in bytes to allocate result ranked descriptor
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]]
+// CHECK: [[DOUBLE_RANK:%.*]] = llvm.mul [[C2]], %{{.*}}
+// CHECK: [[NUM_INDEX_VALS:%.*]] = llvm.add [[DOUBLE_RANK]], [[C1]]
+// CHECK: [[INDEX_VALS_SIZE:%.*]] = llvm.mul [[NUM_INDEX_VALS]], [[INDEX_SIZE]]
+// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], [[INDEX_VALS_SIZE]]
+// CHECK: [[RESULT_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
+// CHECK: llvm.insertvalue [[RESULT_DESC]], [[RESULT_1]][1]
+
+// Cast pointers
+// CHECK: [[SOURCE_DESC_CAST_ALLOC:%.*]] = llvm.bitcast [[SOURCE_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32>>
+// CHECK: [[SOURCE_ALLOC:%.*]] = llvm.load [[SOURCE_DESC_CAST_ALLOC]]
+// CHECK: [[SOURCE_DESC_CAST_ALIGN:%.*]] = llvm.bitcast [[SOURCE_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32>>
+// CHECK: [[SOURCE_ALIGN_GEP:%.*]] = llvm.getelementptr [[SOURCE_DESC_CAST_ALIGN]][1]
+// CHECK: [[SOURCE_ALIGN:%.*]] = llvm.load [[SOURCE_ALIGN_GEP]] : !llvm.ptr<ptr<f32>>
+// CHECK: [[RESULT_ALLOC:%.*]] = llvm.addrspacecast [[SOURCE_ALLOC]] : !llvm.ptr<f32> to !llvm.ptr<f32, 1>
+// CHECK: [[RESULT_ALIGN:%.*]] = llvm.addrspacecast [[SOURCE_ALIGN]] : !llvm.ptr<f32> to !llvm.ptr<f32, 1>
+// CHECK: [[RESULT_DESC_CAST_ALLOC:%.*]] = llvm.bitcast [[RESULT_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32, 1>>
+// CHECK: llvm.store [[RESULT_ALLOC]], [[RESULT_DESC_CAST_ALLOC]] : !llvm.ptr<ptr<f32, 1>>
+// CHECK: [[RESULT_DESC_CAST_ALIGN:%.*]] = llvm.bitcast [[RESULT_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32, 1>>
+// CHECK: [[RESULT_ALIGN_GEP:%.*]] = llvm.getelementptr [[RESULT_DESC_CAST_ALIGN]][1]
+// CHECK: llvm.store [[RESULT_ALIGN]], [[RESULT_ALIGN_GEP]] : !llvm.ptr<ptr<f32, 1>>
+
+// Memcpy remaining values
+
+// CHECK: [[SOURCE_DESC_CAST_OFFSET:%.*]] = llvm.bitcast [[SOURCE_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32>>
+// CHECK: [[SOURCE_OFFSET_GEP:%.*]] = llvm.getelementptr [[SOURCE_DESC_CAST_OFFSET]][2]
+// CHECK: [[SOURCE_OFFSET_IDX:%.*]] = llvm.bitcast [[SOURCE_OFFSET_GEP]] : !llvm.ptr<ptr<f32>> to !llvm.ptr<i64>
+// CHECK: [[RESULT_DESC_CAST_OFFSET:%.*]] = llvm.bitcast [[RESULT_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32, 1>>
+// CHECK: [[RESULT_OFFSET_GEP:%.*]] = llvm.getelementptr [[RESULT_DESC_CAST_OFFSET]][2]
+// CHECK: [[RESULT_OFFSET_IDX:%.*]] = llvm.bitcast [[RESULT_OFFSET_GEP]] : !llvm.ptr<ptr<f32, 1>> to !llvm.ptr<i64>
+// CHECK: [[SIZEOF_TWO_RESULT_PTRS:%.*]] = llvm.mlir.constant(16 : index) : i64
+// CHECK: [[COPY_SIZE:%.*]] = llvm.sub [[DESC_ALLOC_SIZE]], [[SIZEOF_TWO_RESULT_PTRS]]
+// CHECK: [[FALSE:%.*]] = llvm.mlir.constant(false) : i1
+// CHECK: "llvm.intr.memcpy"([[RESULT_OFFSET_IDX]], [[SOURCE_OFFSET_IDX]], [[COPY_SIZE]], [[FALSE]])
+
+// -----
+
 // CHECK-LABEL: func @memref_cast_static_to_dynamic
 func.func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
   // CHECK-NOT: llvm.bitcast
@@ -481,8 +541,8 @@
 // Compute size in bytes to allocate result ranked descriptor
 // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
 // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -424,3 +424,27 @@
   %out = memref.realloc %in {alignment = 8} : memref<2xf32> to memref<4xf32>
   return %out : memref<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @memref_address_space_cast
+func.func @memref_address_space_cast(%input : memref<?xf32>) -> memref<?xf32, 1> {
+  %cast = memref.address_space_cast %input : memref<?xf32> to memref<?xf32, 1>
+  return %cast : memref<?xf32, 1>
+}
+// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %{{.*}}
+// CHECK: [[ALLOC:%.*]] = llvm.extractvalue [[INPUT]][0]
+// CHECK: [[ALIGN:%.*]] = llvm.extractvalue [[INPUT]][1]
+// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2]
+// CHECK: [[SIZE:%.*]] = llvm.extractvalue [[INPUT]][3, 0]
+// CHECK: [[STRIDE:%.*]] = llvm.extractvalue [[INPUT]][4, 0]
+// CHECK: [[CAST_ALLOC:%.*]] = llvm.addrspacecast [[ALLOC]] : !llvm.ptr<f32> to !llvm.ptr<f32, 1>
+// CHECK: [[CAST_ALIGN:%.*]] = llvm.addrspacecast [[ALIGN]] : !llvm.ptr<f32> to !llvm.ptr<f32, 1>
+// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef
+// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[CAST_ALLOC]], [[RESULT_0]][0]
+// CHECK: [[RESULT_2:%.*]] = llvm.insertvalue [[CAST_ALIGN]], [[RESULT_1]][1]
+// CHECK: [[RESULT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[RESULT_2]][2]
+// CHECK: [[RESULT_4:%.*]] = llvm.insertvalue [[SIZE]], [[RESULT_3]][3, 0]
+// CHECK: [[RESULT_5:%.*]] = llvm.insertvalue [[STRIDE]], [[RESULT_4]][4, 0]
+// CHECK: [[RESULT:%.*]] = builtin.unrealized_conversion_cast [[RESULT_5]] : {{.*}} to memref<?xf32, 1>
+// CHECK: return [[RESULT]]
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -212,6 +212,32 @@

 // -----

+// Check address space casts
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func.func @address_space_cast
+func.func @address_space_cast(%arg: memref<4xf32, #spirv.storage_class<CrossWorkgroup>>)
+    -> memref<4xf32, #spirv.storage_class<Function>> {
+  // CHECK: %[[ARG_CAST:.+]] = builtin.unrealized_conversion_cast {{.*}} to !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup>
+  // CHECK: %[[TO_GENERIC:.+]] = spirv.PtrCastToGeneric %[[ARG_CAST]] : !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup> to !spirv.ptr<!spirv.array<4 x f32>, Generic>
+  // CHECK: %[[TO_PRIVATE:.+]] = spirv.GenericCastToPtr %[[TO_GENERIC]] : !spirv.ptr<!spirv.array<4 x f32>, Generic> to !spirv.ptr<!spirv.array<4 x f32>, Function>
+  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[TO_PRIVATE]]
+  // CHECK: return %[[RET]]
+  %ret = memref.address_space_cast %arg : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
+      to memref<4xf32, #spirv.storage_class<Function>>
+  return %ret : memref<4xf32, #spirv.storage_class<Function>>
+}
+
+} // end module
+
+// -----
+
 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
 // TODO: Test i64 types.
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -894,3 +894,25 @@
     to memref>
   return %1 : memref>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_trivial_address_space_cast(
+//  CHECK-SAME:   %[[arg:.*]]: memref<?xf32, 5>
+//       CHECK:   return %[[arg]]
+func.func @fold_trivial_address_space_cast(%arg : memref<?xf32, 5>) -> memref<?xf32, 5> {
+  %0 = memref.address_space_cast %arg : memref<?xf32, 5> to memref<?xf32, 5>
+  return %0 : memref<?xf32, 5>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_multiple_address_space_cast(
+//  CHECK-SAME:   %[[arg:.*]]: memref<?xf32>
+//       CHECK:   %[[res:.*]] = memref.address_space_cast %[[arg]] : memref<?xf32> to memref<?xf32, 2>
+//       CHECK:   return %[[res]]
+func.func @fold_multiple_address_space_cast(%arg : memref<?xf32>) -> memref<?xf32, 2> {
+  %0 = memref.address_space_cast %arg : memref<?xf32> to memref<?xf32, 1>
+  %1 = memref.address_space_cast %0 : memref<?xf32, 1> to memref<?xf32, 2>
+  return %1 : memref<?xf32, 2>
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -380,3 +380,9 @@
   %0 = memref.extract_aligned_pointer_as_index %src : memref<?xf32> -> index
   return %0 : index
 }
+
+// CHECK-LABEL: func @memref_address_space_cast
+func.func @memref_address_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
+  %dst = memref.address_space_cast %src : memref<?xf32> to memref<?xf32, 1>
+  return %dst : memref<?xf32, 1>
+}